如何在scikit-learn训练过程中显示损失值?

我想在训练过程中检查我的损失值,以便观察每一次迭代的损失。到目前为止,我还没有找到一种简单的方法让scikit-learn提供损失值的历史记录,也没有找到scikit-learn内置的功能来为我绘制损失图。

如果没有办法绘制这个图表,那么在分类器.fit结束时能够简单地获取最终的损失值就很好了。

注意:我知道有些解决方案是闭合形式的。我使用了几种没有解析解的分类器,例如逻辑回归和支持向量机(SVM)。

有谁有任何建议吗?


回答:

所以我找不到直接获取每一次迭代损失值的很好的文档,但我希望这对未来某人有所帮助:

old_stdout = sys.stdout
sys.stdout = mystdout = StringIO()
clf = SGDClassifier(**kwargs, verbose=1)
clf.fit(X_tr, y_tr)
sys.stdout = old_stdout
loss_history = mystdout.getvalue()
loss_list = []
for line in loss_history.split('\n'):
    if(len(line.split("loss: ")) == 1):
        continue
    loss_list.append(float(line.split("loss: ")[-1]))
plt.figure()
plt.plot(np.arange(len(loss_list)), loss_list)
plt.savefig("warmstart_plots/pure_SGD:"+str(kwargs)+".png")
plt.xlabel("时间(轮次)")
plt.ylabel("损失")
plt.close()

这段代码会处理一个普通的SGDClassifier(几乎任何线性分类器),并拦截verbose=1标志,然后拆分以从详细输出中获取损失值。显然这样做会更慢,但会给我们损失值并打印出来。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注