Can now handle read-only classifier data as well

This commit is contained in:
Yacine Filali
2017-05-24 13:22:24 -07:00
parent 4364bebf38
commit f5f7264077

View File

@@ -7,8 +7,6 @@ body belongs to the signature.
from __future__ import absolute_import
import pickle
from numpy import genfromtxt
from sklearn.externals import joblib
from sklearn.svm import LinearSVC
@@ -32,19 +30,37 @@ def train(classifier, train_data_filename, save_classifier_filename=None):
def load(saved_classifier_filename, train_data_filename):
"""Loads saved classifier. """
# NOTE(review): these lines appear to come from a flattened diff hunk that
# interleaves removed and added lines with indentation stripped, so the real
# nesting cannot be recovered from here -- confirm against the actual module
# before changing any logic.
# train_data_filename is not referenced anywhere in the visible lines.
try:
# First attempt: let joblib read the saved classifier directly.
return joblib.load(saved_classifier_filename)
except ValueError:
# load python 2 pickle format with python 3, and save it, permissions allowing
import sys
kwargs = {}
# Only Python 3's pickle.load accepts an ``encoding`` keyword, hence the
# version check before adding it.
if sys.version_info > (3, 0):
kwargs["encoding"] = "latin1"
# NOTE(review): the file object opened here is never explicitly closed.
loaded = pickle.load(open(saved_classifier_filename, 'rb'), **kwargs)
try:
# Best-effort rewrite of the classifier file in joblib's own format.
joblib.dump(loaded, saved_classifier_filename, compress=True)
except Exception:
# Any failure to write is ignored on purpose (e.g. read-only data --
# presumably the case the commit message refers to; TODO confirm).
pass
# Delegates to load_compat(), defined right after this function.
return load_compat(saved_classifier_filename)
return joblib.load(saved_classifier_filename)
def load_compat(saved_classifier_filename):
    """Load a classifier pickled with an older encoding and return it via joblib.

    The file is first read with ``pickle.load(..., encoding='latin1')``.
    The resulting object is then written back with ``joblib.dump`` -- to the
    original path when that is writable, otherwise to a temporary file -- and
    finally re-read with ``joblib.load``.

    The working directory is switched to the classifier's directory for the
    duration of the call and is always restored afterwards, even on error.

    Args:
        saved_classifier_filename: path to the saved classifier file; may be
            relative to the current working directory.

    Returns:
        The object produced by ``joblib.load``.

    Raises:
        Whatever ``open``, ``pickle.load``, ``os.chdir`` or ``joblib`` raise
        when the file cannot be read or decoded. Failures while writing the
        converted classifier back to its original path are not propagated.
    """
    import os
    import pickle
    import tempfile

    # Resolve the path up front: a relative path would stop being valid once
    # we chdir, and dirname() of a bare filename is '' which chdir rejects.
    classifier_path = os.path.abspath(saved_classifier_filename)
    cwd = os.getcwd()
    tmp = None
    try:
        # we need to switch to the data path to properly load the related
        # _xx.npy files
        os.chdir(os.path.dirname(classifier_path))

        # convert encoding using pickle.load
        # NOTE: unpickling executes code embedded in the file; only use this
        # on classifier files from a trusted source.
        with open(classifier_path, 'rb') as pickle_file:
            classifier = pickle.load(pickle_file, encoding='latin1')

        load_target = classifier_path
        try:
            # save our conversion if permissions allow
            joblib.dump(classifier, classifier_path)
        except Exception:
            # can't write to classifier (deliberate best effort), so hand
            # joblib an in-memory/temp file instead
            tmp = tempfile.SpooledTemporaryFile()
            joblib.dump(classifier, tmp)
            # rewind explicitly so the load below starts at the beginning
            # instead of depending on joblib to reposition the stream
            tmp.seek(0)
            load_target = tmp

        # important, use joblib.load before switching back to original cwd
        return joblib.load(load_target)
    finally:
        if tmp is not None:
            tmp.close()
        os.chdir(cwd)