Can now handle read-only classifier data as well

This commit is contained in:
Yacine Filali
2017-05-24 13:22:24 -07:00
parent 4364bebf38
commit f5f7264077

View File

@@ -7,8 +7,6 @@ body belongs to the signature.
from __future__ import absolute_import from __future__ import absolute_import
import pickle
from numpy import genfromtxt from numpy import genfromtxt
from sklearn.externals import joblib from sklearn.externals import joblib
from sklearn.svm import LinearSVC from sklearn.svm import LinearSVC
@@ -32,19 +30,37 @@ def train(classifier, train_data_filename, save_classifier_filename=None):
def load(saved_classifier_filename, train_data_filename): def load(saved_classifier_filename, train_data_filename):
"""Loads saved classifier. """ """Loads saved classifier. """
import sys
if sys.version_info > (3, 0):
return load_compat(saved_classifier_filename)
return joblib.load(saved_classifier_filename)
def load_compat(saved_classifier_filename):
import os
import pickle
import tempfile
# we need to switch to the data path to properly load the related _xx.npy files
cwd = os.getcwd()
os.chdir(os.path.dirname(saved_classifier_filename))
# convert encoding using pickle.load and write to a temp file which we'll tell joblib to use
pickle_file = open(saved_classifier_filename, 'rb')
classifier = pickle.load(pickle_file, encoding='latin1')
try: try:
return joblib.load(saved_classifier_filename) # save our conversion if permissions allow
except ValueError: joblib.dump(classifier, saved_classifier_filename)
# load python 2 pickle format with python 3, and save it permissions allowing except Exception:
import sys # can't write to classifier, use a temp file
kwargs = {} tmp = tempfile.SpooledTemporaryFile()
if sys.version_info > (3, 0): joblib.dump(classifier, tmp)
kwargs["encoding"] = "latin1" saved_classifier_filename = tmp
loaded = pickle.load(open(saved_classifier_filename, 'rb'), **kwargs) # important, use joblib.load before switching back to original cwd
try: jb_classifier = joblib.load(saved_classifier_filename)
joblib.dump(loaded, saved_classifier_filename, compress=True) os.chdir(cwd)
except Exception:
pass
return joblib.load(saved_classifier_filename) return jb_classifier