diff --git a/talon/signature/learning/classifier.py b/talon/signature/learning/classifier.py index 9267db0..4e1e886 100644 --- a/talon/signature/learning/classifier.py +++ b/talon/signature/learning/classifier.py @@ -7,8 +7,6 @@ body belongs to the signature. from __future__ import absolute_import -import pickle - from numpy import genfromtxt from sklearn.externals import joblib from sklearn.svm import LinearSVC @@ -32,19 +30,37 @@ def train(classifier, train_data_filename, save_classifier_filename=None): def load(saved_classifier_filename, train_data_filename): """Loads saved classifier. """ + import sys + if sys.version_info > (3, 0): + return load_compat(saved_classifier_filename) + + return joblib.load(saved_classifier_filename) + + +def load_compat(saved_classifier_filename): + import os + import pickle + import tempfile + + # we need to switch to the data path to properly load the related _xx.npy files + cwd = os.getcwd() + os.chdir(os.path.dirname(saved_classifier_filename)) + + # convert encoding using pick.load and write to temp file which we'll tell joblib to use + pickle_file = open(saved_classifier_filename, 'rb') + classifier = pickle.load(pickle_file, encoding='latin1') + try: - return joblib.load(saved_classifier_filename) - except ValueError: - # load python 2 pickle format with python 3, and save it permissions allowing - import sys - kwargs = {} - if sys.version_info > (3, 0): - kwargs["encoding"] = "latin1" + # save our conversion if permissions allow + joblib.dump(classifier, saved_classifier_filename) + except Exception: + # can't write to classifier, use a temp file + tmp = tempfile.SpooledTemporaryFile() + joblib.dump(classifier, tmp) + saved_classifier_filename = tmp - loaded = pickle.load(open(saved_classifier_filename, 'rb'), **kwargs) - try: - joblib.dump(loaded, saved_classifier_filename, compress=True) - except Exception: - pass + # important, use joblib.load before switching back to original cwd + jb_classifier = joblib.load(saved_classifier_filename) + os.chdir(cwd) - return joblib.load(saved_classifier_filename) + return jb_classifier