如何将新数据传递给已保存的MultinomialNB分类器?

我已经成功创建了一个MultinomialNB分类器,并将其保存为一个pickle文件以便后续使用(感谢YouTube视频:https://www.youtube.com/watch?v=0kPRaYSgblM&t=927s和其他几个视频)。以下是我的代码:

import sklearn.datasets as skdfrom sklearn.feature_extraction.text import CountVectorizerfrom sklearn.feature_extraction.text import TfidfVectorizerfrom sklearn.naive_bayes import MultinomialNBimport picklecategories = ['alt.atheism', 'soc.religion.christian', 'comp.graphics', 'sci.med']train_data = skd.load_files('E:/Python/Datasets/train', categories=categories, encoding='ISO-8859-1')test_data = skd.load_files('E:/Python/Datasets/test', categories=categories, encoding='ISO-8859-1')tf_vect = TfidfVectorizer()tfidf_train = tf_vect.fit_transform(train_data.data)clf = MultinomialNB().fit(tfidf_train, train_data.target)with open('classifier', 'wb') as picklefile:    pickle.dump(clf, picklefile)

现在,在一个单独的代码文件中,我可以将它读取到一个新的变量’new_clf’中,以便使用这个分类器处理新的文本数据:

import picklewith open('E:\Python\Text Classification\classifier', 'rb') as tm:    new_clf = pickle.load(tm)

如果我直接运行前面的文件/代码,并且已经有了包含训练数据的tf_vect变量,我可以简单地在其上转换新的文本集,并将其传递给new_clf分类器以获取预测结果。

但在我的情况下,一旦模型训练完成,我希望将其发送给另一个用户,该用户将有一个单独的代码文件,该文件需要读取分类器,然后将新文本传递给它进行预测。

我在下面的代码中遇到的问题是(以ValueError: dimension mismatch结束):-

new_text = ['God is Love', 'OpenGL is fast on GPU']new_clf.predict(new_text)

我明白我没有根据训练数据的特征转换new_text。但我无法弄清楚如何解决这个问题。

我应该创建另一个包含tf_vect的pickle文件并与用户共享吗?还是它已经包含在分类器文件中,而我错过了从分类器中获取它的过程?


回答:

你确实可以保存两个pickle文件,一个用于向量化器,另一个用于分类器。然而,最方便和推荐的解决方案是将向量化器和分类器组合成一个Pipeline对象,然后你可以对其进行pickle操作。

from sklearn.pipeline import Pipelinetf_vect = TfidfVectorizer()clf = MultinomialNB()pipe = Pipeline([("vectorizer", tf_vect), ("classifier", clf)])pipe.fit(train_data.data, train_data.target)with open('classifier', 'wb') as picklefile:    pickle.dump(pipe, picklefile)

然后,当你加载那个pickle文件时,你可以像这样使用它来处理新文本:

with open('/.../classifier', 'rb') as tm:    new_pipe = pickle.load(tm)new_pipe.predict(new_text)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注