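I have built a classifier and want to save it for future use. The classifier compares several algorithms (logistic regression, naive Bayes, support vector machine):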
```python
import pandas as pd
from sklearn.model_selection import train_test_split

# tfidf, under_sample and the training_* functions are defined elsewhere in my code
X, y = tfidf(df, ngrams=1)
X, y = under_sample.fit_resample(X, y)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=40)

# DataFrame.append was removed in pandas 2.0; pd.concat replaces it
df_result = pd.concat([
    df_result,
    training_naive(X_train, X_test, y_train, y_test),
    training_logreg(X_train, X_test, y_train, y_test),
    training_svm(X_train, X_test, y_train, y_test),
], ignore_index=True)
```
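This is the last step of my code, where I compare the different algorithms. training_svm, training_logreg and training_naive are all functions. For example, training_svm is defined as follows: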
```python
import pandas as pd
from sklearn import svm
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import StratifiedKFold


def training_svm(X_train_log, X_test_log, y_train_log, y_test_log):
    folds = StratifiedKFold(n_splits=3, shuffle=True, random_state=40)  # defined but not used below
    clf = svm.SVC(kernel='linear')  # linear kernel
    clf.fit(X_train_log, y_train_log)

    y_pred = clf.predict(X_test_log)
    # sklearn metrics expect (y_true, y_pred) in that order
    f1 = f1_score(y_test_log, y_pred, average='weighted')
    pres = precision_score(y_test_log, y_pred, average='weighted')
    rec = recall_score(y_test_log, y_pred, average='weighted')
    acc = accuracy_score(y_test_log, y_pred)

    res = pd.DataFrame(
        [{'Model': 'SVM', 'Precision': pres, 'Recall': rec, 'F1-score': f1, 'Accuracy': acc}],
        columns=['Preprocessing', 'Model', 'Precision', 'Recall', 'F1-score', 'Accuracy'],
    )
    return res
```
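Since I want to use it and test it on new data, I am wondering how to save it and load it again later. I think I should do something like this

How can I extend this to my project?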
Answer:
As the scikit-learn documentation states:

It is possible to save a model in scikit-learn by using Python's built-in persistence model, namely pickle.

Example:
```python
from sklearn import svm
from sklearn import datasets

clf = svm.SVC()
X, y = datasets.load_iris(return_X_y=True)
clf.fit(X, y)

import pickle
s = pickle.dumps(clf)
clf2 = pickle.loads(s)
clf2.predict(X[0:1])
```
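The example above serializes the model to an in-memory bytes object. To keep the model between sessions, which is what the question is after, you would write it to a file instead. A minimal sketch, assuming an arbitrary file name `svm_model.pkl` in a temporary directory:

```python
import os
import pickle
import tempfile

from sklearn import datasets, svm

X, y = datasets.load_iris(return_X_y=True)
clf = svm.SVC()
clf.fit(X, y)

# Hypothetical location; use a real path in your project
model_path = os.path.join(tempfile.mkdtemp(), "svm_model.pkl")

# Save: pickle.dump writes the object to an open binary file
with open(model_path, "wb") as f_out:
    pickle.dump(clf, f_out)

# Later (e.g. in another script): pickle.load reads it back from the file
with open(model_path, "rb") as f_in:
    clf_loaded = pickle.load(f_in)

print(clf_loaded.predict(X[0:1]))
```

Only unpickle files from sources you trust, since loading a pickle can execute arbitrary code, and load the model with the same scikit-learn version that saved it.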
You can then include this in your code for each model, for example by creating a function such as:
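```python
import pickle


def predict_svm(to_predict):
    # 'your_svm_model' is a placeholder for the path of the saved model file
    with open('your_svm_model', 'rb') as f_input:
        clf = pickle.load(f_input)  # pickle.load reads from a file object; pickle.loads expects bytes
    # Consider a singleton/cache so that repeated predictions don't reload the model
    return clf.predict(to_predict)
```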
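The comment in `predict_svm` suggests avoiding a reload on every call. One simple way to do that, offered as a sketch rather than the only option, is to cache the loader with `functools.lru_cache` so the file is read once per path. The helper `load_model`, the extra `path` parameter and the file name `svm_model.pkl` are hypothetical:

```python
import functools
import os
import pickle
import tempfile

from sklearn import datasets, svm


@functools.lru_cache(maxsize=None)
def load_model(path):
    # Runs only on the first call for a given path; later calls return the cached object
    with open(path, "rb") as f_input:
        return pickle.load(f_input)


def predict_svm(to_predict, path):
    clf = load_model(path)
    return clf.predict(to_predict)


# Train and save a model once so that the example is self-contained
X, y = datasets.load_iris(return_X_y=True)
model_path = os.path.join(tempfile.mkdtemp(), "svm_model.pkl")
with open(model_path, "wb") as f_out:
    pickle.dump(svm.SVC().fit(X, y), f_out)

print(predict_svm(X[0:3], model_path))
print(predict_svm(X[50:53], model_path))  # served from the cache, file not reread
```

If the model file is replaced on disk while the program is running, call `load_model.cache_clear()` to force a reload.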
However, scikit-learn recommends joblib:
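In the specific case of scikit-learn, it may be better to use joblib's replacement of pickle (dump & load), which is more efficient on objects that carry large numpy arrays internally, as is often the case for fitted scikit-learn estimators, but can only pickle to the disk and not to a string: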
```python
from joblib import dump, load

dump(clf, 'filename.joblib')
clf = load('filename.joblib')
```
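In the question's setup the models are trained on TF-IDF features, so new data has to be transformed by the same fitted vectorizer before calling `predict`. One way to keep the two together, sketched here under the assumption that `tfidf(df, ngrams=1)` wraps scikit-learn's `TfidfVectorizer` (the question doesn't show it), is to bundle vectorizer and classifier in a `Pipeline` per algorithm and save each pipeline with joblib. The toy texts, labels, dictionary keys and file names are made up for illustration:

```python
import os
import tempfile

from joblib import dump, load
from sklearn import svm
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

# Hypothetical toy data standing in for the question's df
texts = ["good movie", "great film", "loved it", "bad movie", "awful film", "hated it"]
labels = [1, 1, 1, 0, 0, 0]

# One pipeline per algorithm: each one fits its own vectorizer together with its model
models = {
    "naive": make_pipeline(TfidfVectorizer(ngram_range=(1, 1)), MultinomialNB()),
    "logreg": make_pipeline(TfidfVectorizer(ngram_range=(1, 1)), LogisticRegression()),
    "svm": make_pipeline(TfidfVectorizer(ngram_range=(1, 1)), svm.SVC(kernel="linear")),
}

save_dir = tempfile.mkdtemp()  # replace with a real directory in your project
for name, pipe in models.items():
    pipe.fit(texts, labels)
    dump(pipe, os.path.join(save_dir, f"{name}.joblib"))

# Later: load one pipeline and predict directly on raw, unseen text
loaded_svm = load(os.path.join(save_dir, "svm.joblib"))
print(loaded_svm.predict(["great movie", "awful movie"]))
```

Under-sampling (the question's `under_sample.fit_resample`) only changes which rows are used for training, so it does not need to be part of the saved object.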