From f5f726407717b6f7c56536c3f75638df0d28f5b6 Mon Sep 17 00:00:00 2001 From: Yacine Filali Date: Wed, 24 May 2017 13:22:24 -0700 Subject: [PATCH] Can now handle read only classifier data as well --- talon/signature/learning/classifier.py | 46 +++++++++++++++++--------- 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/talon/signature/learning/classifier.py b/talon/signature/learning/classifier.py index 9267db0..4e1e886 100644 --- a/talon/signature/learning/classifier.py +++ b/talon/signature/learning/classifier.py @@ -7,8 +7,6 @@ body belongs to the signature. from __future__ import absolute_import -import pickle - from numpy import genfromtxt from sklearn.externals import joblib from sklearn.svm import LinearSVC @@ -32,19 +30,37 @@ def train(classifier, train_data_filename, save_classifier_filename=None): def load(saved_classifier_filename, train_data_filename): """Loads saved classifier. """ + import sys + if sys.version_info > (3, 0): + return load_compat(saved_classifier_filename) + + return joblib.load(saved_classifier_filename) + + +def load_compat(saved_classifier_filename): + import os + import pickle + import tempfile + + # we need to switch to the data path to properly load the related _xx.npy files + cwd = os.getcwd() + os.chdir(os.path.dirname(saved_classifier_filename)) + + # convert encoding using pick.load and write to temp file which we'll tell joblib to use + pickle_file = open(saved_classifier_filename, 'rb') + classifier = pickle.load(pickle_file, encoding='latin1') + try: - return joblib.load(saved_classifier_filename) - except ValueError: - # load python 2 pickle format with python 3, and save it permissions allowing - import sys - kwargs = {} - if sys.version_info > (3, 0): - kwargs["encoding"] = "latin1" + # save our conversion if permissions allow + joblib.dump(classifier, saved_classifier_filename) + except Exception: + # can't write to classifier, use a temp file + tmp = tempfile.SpooledTemporaryFile() + joblib.dump(classifier, tmp) + saved_classifier_filename = tmp - loaded = pickle.load(open(saved_classifier_filename, 'rb'), **kwargs) - try: - joblib.dump(loaded, saved_classifier_filename, compress=True) - except Exception: - pass + # important, use joblib.load before switching back to original cwd + jb_classifier = joblib.load(saved_classifier_filename) + os.chdir(cwd) - return joblib.load(saved_classifier_filename) + return jb_classifier