talon/talon/signature/learning/classifier.py
Matt Dietz d37c4fd551 Drops Python 2 support
REP-1030

In addition to some Python 2 => 3 fixes, this change bumps the scikit-learn
version to the latest release. The previously pinned version of scikit-learn failed
to compile all of its C modules under Python 3.7+ because the header files it
included were not compatible with the C API implemented in Python 3.7+.

Given how narrow scikit-learn's range of supported Python versions is, it
seemed prudent to drop Python 2 support altogether. Otherwise, we'd be stuck
with Python 3.4 as the newest version we could support.

With this change, tests currently pass under Python 3.9.2.

Lastly, this change imports the original training data. At some point, a new
version of the training data was committed to the repo, but no classifier was
ever trained from it. A classifier trained from this new data caused most of
the tests to fail.
2021-06-10 14:03:25 -05:00

# -*- coding: utf-8 -*-
"""The module's functions can init, train, save and load a classifier.
The classifier can be used to detect whether a certain line of the message
body belongs to the signature.
"""
from __future__ import absolute_import

from numpy import genfromtxt
import joblib
from sklearn.svm import LinearSVC


def init():
    """Inits classifier with optimal options."""
    return LinearSVC(C=10.0)


def train(classifier, train_data_filename, save_classifier_filename=None):
    """Trains and saves classifier so that it could be easily loaded later."""
    # Each CSV row holds the feature values followed by the label in the last column.
    file_data = genfromtxt(train_data_filename, delimiter=",")
    train_data, labels = file_data[:, :-1], file_data[:, -1]
    classifier.fit(train_data, labels)

    if save_classifier_filename:
        joblib.dump(classifier, save_classifier_filename)
    return classifier


def load(saved_classifier_filename, train_data_filename):
    """Loads saved classifier."""
    try:
        return joblib.load(saved_classifier_filename)
    except Exception:
        import sys
        if sys.version_info > (3, 0):
            # The saved file may have been pickled under Python 2.
            return load_compat(saved_classifier_filename)

        raise


def load_compat(saved_classifier_filename):
    """Loads a classifier pickled under Python 2 and re-saves it with joblib."""
    import os
    import pickle
    import tempfile

    # Resolve the path first so the chdir below cannot break a relative path
    # (and so dirname() is never empty).
    saved_classifier_filename = os.path.abspath(saved_classifier_filename)

    # we need to switch to the data path to properly load the related _xx.npy files
    cwd = os.getcwd()
    os.chdir(os.path.dirname(saved_classifier_filename))
    try:
        # convert encoding using pickle.load and write to temp file which we'll tell joblib to use
        with open(saved_classifier_filename, 'rb') as pickle_file:
            classifier = pickle.load(pickle_file, encoding='latin1')

        try:
            # save our conversion if permissions allow
            joblib.dump(classifier, saved_classifier_filename)
        except Exception:
            # can't write to classifier, use a temp file
            tmp = tempfile.SpooledTemporaryFile()
            joblib.dump(classifier, tmp)
            # rewind the temp file before reading it back
            tmp.seek(0)
            saved_classifier_filename = tmp

        # important, use joblib.load before switching back to original cwd
        return joblib.load(saved_classifier_filename)
    finally:
        # restore the original working directory even if loading fails
        os.chdir(cwd)
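Taken together, `init`, `train` and `load` amount to a fit, dump and load round trip over a CSV whose last column is the label. The sketch below reproduces that round trip by calling scikit-learn, NumPy and joblib directly rather than importing talon; the toy data, the file names `train.data` and `classifier`, and the helper `round_trip` are hypothetical and only illustrate the data layout `train` expects.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.svm import LinearSVC

# Hypothetical toy data: two feature columns, label in the last column,
# matching the slicing train() does (file_data[:, :-1], file_data[:, -1]).
ROWS = [
    (0.0, 0.0, 0),
    (0.1, 0.2, 0),
    (0.2, 0.1, 0),
    (1.0, 1.0, 1),
    (0.9, 1.1, 1),
    (1.1, 0.9, 1),
]


def round_trip(workdir):
    """Writes toy training data, trains, saves and reloads a classifier."""
    data_path = os.path.join(workdir, "train.data")   # hypothetical name
    model_path = os.path.join(workdir, "classifier")  # hypothetical name
    np.savetxt(data_path, np.array(ROWS, dtype=float), delimiter=",")

    # Same steps as init() and train() in the module.
    classifier = LinearSVC(C=10.0)
    file_data = np.genfromtxt(data_path, delimiter=",")
    train_data, labels = file_data[:, :-1], file_data[:, -1]
    classifier.fit(train_data, labels)
    joblib.dump(classifier, model_path)

    # Same step as the joblib path of load().
    loaded = joblib.load(model_path)
    return loaded.predict(np.array([[0.05, 0.05], [1.05, 1.0]]))


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        print(round_trip(tmp))
```

Because `genfromtxt` parses every column as a float, the predicted labels come back as floats (`0.0`/`1.0`) rather than ints. Only the joblib path of `load` is exercised here; `load_compat` is reached only when `joblib.load` cannot read the saved file.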
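`load_compat` passes `encoding='latin1'` to `pickle.load` because Python 3 decodes Python 2 `str` objects as ASCII by default, which fails on any non-ASCII byte, such as the raw buffers inside pickled NumPy arrays. Latin-1 maps each byte to exactly one code point, so nothing is lost. A minimal sketch, assuming a hand-built protocol-2 pickle of the Python 2 byte string `'caf\xe9'` (illustrative bytes, not the talon classifier file); `load_py2` and `try_default` are hypothetical helpers.

```python
import pickle

# Hand-built protocol-2 pickle of a Python 2 str:
# PROTO 2, SHORT_BINSTRING of 4 bytes ('caf\xe9'), BINPUT 0, STOP.
PY2_PICKLE = b"\x80\x02U\x04caf\xe9q\x00."


def load_py2(data, encoding):
    """Unpickles Python-2-era bytes, decoding Python 2 str with `encoding`."""
    return pickle.loads(data, encoding=encoding)


def try_default(data):
    """Returns the error raised by the default ASCII decoding, if any."""
    try:
        return pickle.loads(data)
    except UnicodeDecodeError as exc:
        return exc


if __name__ == "__main__":
    print(repr(try_default(PY2_PICKLE)))   # UnicodeDecodeError: byte 0xe9 is not ASCII
    print(load_py2(PY2_PICKLE, "latin1"))  # every byte kept as one code point
    print(load_py2(PY2_PICKLE, "bytes"))   # raw bytes, no decoding at all
```

Once converted, `load_compat` re-saves the object with `joblib.dump` when it has write permission, so later calls to `load` succeed on the plain `joblib.load` path.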