Can now handle read only classifier data as well
This commit is contained in:
@@ -7,8 +7,6 @@ body belongs to the signature.
|
|||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
|
|
||||||
import pickle
|
|
||||||
|
|
||||||
from numpy import genfromtxt
|
from numpy import genfromtxt
|
||||||
from sklearn.externals import joblib
|
from sklearn.externals import joblib
|
||||||
from sklearn.svm import LinearSVC
|
from sklearn.svm import LinearSVC
|
||||||
@@ -32,19 +30,37 @@ def train(classifier, train_data_filename, save_classifier_filename=None):
|
|||||||
|
|
||||||
def load(saved_classifier_filename, train_data_filename):
|
def load(saved_classifier_filename, train_data_filename):
|
||||||
"""Loads saved classifier. """
|
"""Loads saved classifier. """
|
||||||
|
import sys
|
||||||
|
if sys.version_info > (3, 0):
|
||||||
|
return load_compat(saved_classifier_filename)
|
||||||
|
|
||||||
|
return joblib.load(saved_classifier_filename)
|
||||||
|
|
||||||
|
|
||||||
|
def load_compat(saved_classifier_filename):
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import tempfile
|
||||||
|
|
||||||
|
# we need to switch to the data path to properly load the related _xx.npy files
|
||||||
|
cwd = os.getcwd()
|
||||||
|
os.chdir(os.path.dirname(saved_classifier_filename))
|
||||||
|
|
||||||
|
# convert encoding using pick.load and write to temp file which we'll tell joblib to use
|
||||||
|
pickle_file = open(saved_classifier_filename, 'rb')
|
||||||
|
classifier = pickle.load(pickle_file, encoding='latin1')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return joblib.load(saved_classifier_filename)
|
# save our conversion if permissions allow
|
||||||
except ValueError:
|
joblib.dump(classifier, saved_classifier_filename)
|
||||||
# load python 2 pickle format with python 3, and save it permissions allowing
|
except Exception:
|
||||||
import sys
|
# can't write to classifier, use a temp file
|
||||||
kwargs = {}
|
tmp = tempfile.SpooledTemporaryFile()
|
||||||
if sys.version_info > (3, 0):
|
joblib.dump(classifier, tmp)
|
||||||
kwargs["encoding"] = "latin1"
|
saved_classifier_filename = tmp
|
||||||
|
|
||||||
loaded = pickle.load(open(saved_classifier_filename, 'rb'), **kwargs)
|
# important, use joblib.load before switching back to original cwd
|
||||||
try:
|
jb_classifier = joblib.load(saved_classifier_filename)
|
||||||
joblib.dump(loaded, saved_classifier_filename, compress=True)
|
os.chdir(cwd)
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return joblib.load(saved_classifier_filename)
|
return jb_classifier
|
||||||
|
|||||||
Reference in New Issue
Block a user