REP-1030 In addition to some Python 2 => 3 fixes, this change bumps the scikit-learn version to the latest release. The previously pinned version of scikit-learn failed to compile all of its necessary C modules under Python 3.7+ because it included header files that weren't compatible with the C API implemented in Python 3.7+. At the same time, given the restrictive compatibility range supported by scikit-learn, it seemed prudent to drop Python 2 support altogether; otherwise, we'd be stuck with Python 3.4 as the newest version we could support. With this change, tests are currently passing under 3.9.2. Lastly, this change imports the original training data. At some point, a new version of the training data was committed to the repo, but no classifier was trained from it. Using a classifier trained from this new data resulted in most of the tests failing.
70 lines
2.1 KiB
Python
70 lines
2.1 KiB
Python
# -*- coding: utf-8 -*-
|
|
|
|
"""The module's functions could init, train, save and load a classifier.
|
|
The classifier could be used to detect if a certain line of the message
|
|
body belongs to the signature.
|
|
"""
|
|
|
|
from __future__ import absolute_import
|
|
|
|
from numpy import genfromtxt
|
|
import joblib
|
|
from sklearn.svm import LinearSVC
|
|
|
|
|
|
def init():
    """Build a fresh, untrained classifier.

    Returns:
        A new ``LinearSVC`` instance created with ``C=10.0``.
    """
    svc = LinearSVC(C=10.0)
    return svc
|
|
|
|
|
|
def train(classifier, train_data_filename, save_classifier_filename=None):
    """Fit ``classifier`` on comma-separated data and optionally persist it.

    Args:
        classifier: Object exposing ``fit(data, labels)``.
        train_data_filename: Path to a comma-delimited file read with
            ``genfromtxt``.
        save_classifier_filename: When truthy, the fitted classifier is
            dumped to this path with ``joblib.dump``.

    Returns:
        The same ``classifier`` object, after fitting.
    """
    samples = genfromtxt(train_data_filename, delimiter=",")

    # Every column except the last is passed as training data; the last
    # column is passed as the labels.
    features = samples[:, :-1]
    targets = samples[:, -1]
    classifier.fit(features, targets)

    if not save_classifier_filename:
        return classifier

    joblib.dump(classifier, save_classifier_filename)
    return classifier
|
|
|
|
|
|
def load(saved_classifier_filename, train_data_filename):
    """Load a previously saved classifier.

    ``joblib.load`` is tried first. If it raises and the interpreter is
    Python 3, the file is handed to ``load_compat`` instead; on any other
    interpreter the original exception is re-raised.

    Args:
        saved_classifier_filename: Path of the saved classifier.
        train_data_filename: Not used by this function; kept in the
            signature for existing callers.

    Returns:
        The loaded classifier object.
    """
    import sys

    try:
        loaded = joblib.load(saved_classifier_filename)
    except Exception:
        if not sys.version_info > (3, 0):
            raise
        # Fall back to the compatibility loader when the direct load fails.
        loaded = load_compat(saved_classifier_filename)
    return loaded
|
|
|
|
|
|
def load_compat(saved_classifier_filename):
    """Load a classifier via plain ``pickle`` and re-save it with joblib.

    The file is unpickled with ``encoding='latin1'``. The resulting object
    is dumped again with joblib, over the original file when that is
    possible and into a temporary file otherwise, and is then read back
    with ``joblib.load``. The working directory is switched to the
    classifier's directory for the duration of the call and is always
    restored afterwards, even when loading fails.

    Args:
        saved_classifier_filename: Path of the pickled classifier. A
            relative path is resolved against the working directory in
            effect when the function is called.

    Returns:
        The classifier object returned by ``joblib.load``.

    Raises:
        OSError: If the classifier's directory cannot be entered or the
            file cannot be opened.
        Exception: Errors from ``pickle.load`` / ``joblib.load`` propagate
            unchanged.

    NOTE: unpickling can execute code embedded in the file, so only pass
    classifier files that come from a trusted source.
    """
    import os
    import pickle
    import tempfile

    # Resolve the path before changing directory: a relative path would no
    # longer point at the file afterwards, and a bare file name has an empty
    # dirname, which os.chdir() rejects.
    classifier_path = os.path.abspath(saved_classifier_filename)

    # we need to switch to the data path to properly load the related
    # _xx.npy files
    original_cwd = os.getcwd()
    os.chdir(os.path.dirname(classifier_path))
    try:
        # convert encoding using pickle.load; the handle is closed before
        # the file is rewritten below
        with open(classifier_path, 'rb') as pickle_file:
            classifier = pickle.load(pickle_file, encoding='latin1')

        try:
            # save our conversion if permissions allow
            joblib.dump(classifier, classifier_path)
        except Exception:
            # can't write to classifier, use a temp file
            with tempfile.SpooledTemporaryFile() as tmp:
                joblib.dump(classifier, tmp)
                # Rewind explicitly so the load starts at the beginning of
                # the buffer that was just written.
                tmp.seek(0)
                return joblib.load(tmp)

        # important, use joblib.load before switching back to original cwd
        return joblib.load(classifier_path)
    finally:
        # Restore the caller's working directory on success and on failure.
        os.chdir(original_cwd)
|