我正在尝试通过对newsgroup20数据集进行实验来学习。我的训练模型运行得很好,问题出在预测部分。现在我想做的是在一个函数中保存训练模型(使用pickle),然后在另一个函数中对pickle数据进行预测。我找到的所有教程都告诉我如何保存和加载pickle文件,但没有告诉我如何提取X_train和y_train。如果有人能帮我解决这个问题,我将不胜感激。以下是我的代码
def classifier():
twenty_train = fetch_20newsgroups(subset='train', shuffle=True, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(twenty_train.data, twenty_train.target, test_size=0.4, random_state=0)
naive_clf = Pipeline([('vect', CountVectorizer()),
('tfidf', TfidfTransformer()),
('clf', MultinomialNB()),
])
naive_clf.fit(X_train, y_train)
filename = 'finalized_model.sav'
pickle.dump(naive_clf, open(filename, 'wb'))
def predictions():
# 需要帮助的前三行和最后的打印语句
loaded_model = pickle.load(open('finalized_model.sav', 'rb'))
result = loaded_model.score(X_test, y_test)
print(result)
# 将我的文件解析为字符串以进行预测(运行正常)
with open("/home/ubuntu/Desktop/text_classifier/dataset/predict/file,txt", "r") as myfile:
file=myfile.readlines()
file = ''.join(file)
print('根据朴素贝叶斯分类,属于类别 {} '.format(twenty_train.target_names[loaded_model.predict([file])[0]]))
回答:
当你使用pickle保存模型时,你只保存了模型本身,而不是用于训练的数据。因此,如果你想用pickle加载数据,你需要单独保存数据。例如:
data = {'train': X_train, 'target': y_train}
with open('data.pkl', 'wb') as f:
pickle.dump(data, f)
with open('data.pkl', 'rb') as f:
data = pickle.load(f)
X_train = data['train']
y_train = data['target']