Keras + TensorFlow 实时训练图表

我在 Jupyter 笔记本中运行以下代码:

# Visualize training historyfrom keras.models import Sequentialfrom keras.layers import Denseimport matplotlib.pyplot as pltimport numpy# fix random seed for reproducibilityseed = 7numpy.random.seed(seed)# load pima indians datasetdataset = numpy.loadtxt("pima-indians-diabetes.csv", delimiter=",")# split into input (X) and output (Y) variablesX = dataset[:,0:8]Y = dataset[:,8]# create modelmodel = Sequential()model.add(Dense(12, input_dim=8, kernel_initializer='uniform', activation='relu'))model.add(Dense(8, kernel_initializer='uniform', activation='relu'))model.add(Dense(1, kernel_initializer='uniform', activation='sigmoid'))# Compile modelmodel.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])# Fit the modelhistory = model.fit(X, Y, validation_split=0.33, epochs=150, batch_size=10, verbose=0)# list all data in historyprint(history.history.keys())# summarize history for accuracyplt.plot(history.history['acc'])plt.plot(history.history['val_acc'])plt.title('model accuracy')plt.ylabel('accuracy')plt.xlabel('epoch')plt.legend(['train', 'test'], loc='upper left')plt.show()# summarize history for lossplt.plot(history.history['loss'])plt.plot(history.history['val_loss'])plt.title('model loss')plt.ylabel('loss')plt.xlabel('epoch')plt.legend(['train', 'test'], loc='upper left')plt.show()

这段代码收集了训练过程中的纪录,然后显示了进度历史。


Q: 我如何在训练过程中实时更新图表,以便我能看到变化?


回答:

有一个名为 livelossplot 的 Python 包,可以在 Jupyter Notebook 中为 Keras 提供实时训练损失图表(声明:我是作者)。

from livelossplot import PlotLossesKerasmodel.fit(X_train, Y_train,          epochs=10,          validation_data=(X_test, Y_test),          callbacks=[PlotLossesKeras()],          verbose=0)

要了解它的工作原理,可以查看其源代码,特别是这个文件:https://github.com/stared/livelossplot/blob/master/livelossplot/outputs/matplotlib_plot.py (from IPython.display import clear_outputclear_output(wait=True))。

公平声明:它会干扰 Keras 的输出

enter image description here

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注