Can now handle read only classifier data as well
This commit is contained in:
@@ -7,8 +7,6 @@ body belongs to the signature.
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
import pickle
|
||||
|
||||
from numpy import genfromtxt
|
||||
from sklearn.externals import joblib
|
||||
from sklearn.svm import LinearSVC
|
||||
@@ -32,19 +30,37 @@ def train(classifier, train_data_filename, save_classifier_filename=None):
|
||||
|
||||
def load(saved_classifier_filename, train_data_filename):
|
||||
"""Loads saved classifier. """
|
||||
try:
|
||||
return joblib.load(saved_classifier_filename)
|
||||
except ValueError:
|
||||
# load python 2 pickle format with python 3, and save it permissions allowing
|
||||
import sys
|
||||
kwargs = {}
|
||||
if sys.version_info > (3, 0):
|
||||
kwargs["encoding"] = "latin1"
|
||||
|
||||
loaded = pickle.load(open(saved_classifier_filename, 'rb'), **kwargs)
|
||||
try:
|
||||
joblib.dump(loaded, saved_classifier_filename, compress=True)
|
||||
except Exception:
|
||||
pass
|
||||
return load_compat(saved_classifier_filename)
|
||||
|
||||
return joblib.load(saved_classifier_filename)
|
||||
|
||||
|
||||
def load_compat(saved_classifier_filename):
|
||||
import os
|
||||
import pickle
|
||||
import tempfile
|
||||
|
||||
# we need to switch to the data path to properly load the related _xx.npy files
|
||||
cwd = os.getcwd()
|
||||
os.chdir(os.path.dirname(saved_classifier_filename))
|
||||
|
||||
# convert encoding using pick.load and write to temp file which we'll tell joblib to use
|
||||
pickle_file = open(saved_classifier_filename, 'rb')
|
||||
classifier = pickle.load(pickle_file, encoding='latin1')
|
||||
|
||||
try:
|
||||
# save our conversion if permissions allow
|
||||
joblib.dump(classifier, saved_classifier_filename)
|
||||
except Exception:
|
||||
# can't write to classifier, use a temp file
|
||||
tmp = tempfile.SpooledTemporaryFile()
|
||||
joblib.dump(classifier, tmp)
|
||||
saved_classifier_filename = tmp
|
||||
|
||||
# important, use joblib.load before switching back to original cwd
|
||||
jb_classifier = joblib.load(saved_classifier_filename)
|
||||
os.chdir(cwd)
|
||||
|
||||
return jb_classifier
|
||||
|
||||
Reference in New Issue
Block a user