在Keras中训练多输出模型时仅显示总损失

我正在通过Keras的函数式API实现一个自编码器模型。我的模型是多输出的,这导致每个输出都评估了一个损失函数。在训练过程中,这些损失的加权和会被最小化:

losses = [jsd for j in range(m)]  # JSD损失函数用于每个输出
autoencoder = Model(inputs, decodes)
sgd = optimizers.SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
autoencoder.compile(optimizer=sgd, loss=losses, loss_weights=[1 for k in range(m)]) # 每个输出具有相同的优先级

然后我将模型拟合到训练数据并在测试数据上进行评估:

history = autoencoder.fit(train_corr, train_attr_corr, epochs=50, batch_size=10, shuffle=True, verbose=2,                          validation_data=(test_corr, test_attr_GT))

由于verbose=2,在每个epoch结束时,训练和验证损失会在控制台显示。然而,由于模型是多输出的,所有的“子损失”都会被显示。例如:

Epoch 1/50 - 3s - loss: 0.3356 - dense_4_loss: 0.0647 - dense_5_loss: 0.0436 - dense_6_loss: 0.0391 - dense_7_loss: 0.0378 - dense_8_loss: 0.0250 - dense_9_loss: 0.0362 - val_loss: 0.1067 - val_dense_4_loss: 0.0101 - val_dense_5_loss: 0.0042 - val_dense_6_loss: 0.0031 - val_dense_7_loss: 0.0036 - val_dense_8_loss: 0.0041 - val_dense_9_loss: 0.0066

问题:是否可以只显示每个epoch的总训练损失(loss)和总验证损失val_loss

编辑:在上面的例子中,我只想显示loss: 0.3356val_loss: 0.1067


回答:

使用Keras模型的fit函数的默认详细程度选项是不可能实现的。然而,你可以通过自定义回调来实现这一点。在fit函数中使用verbosity=0禁用详细程度。定义以下回调函数,该函数覆盖了默认回调,并在epoch开始和结束时修改了输出结果。

class PrinterCallback(tf.keras.callbacks.Callback):    # def on_train_batch_begin(self, batch, logs=None):    #     # Do something on begin of training batch    def on_epoch_end(self, epoch, logs=None):        print('EPOCH: {}, Train Loss: {}, Val Loss: {}'.format(epoch,                                                               logs['loss'],                                                               logs['val_loss']))    def on_epoch_begin(self, epoch, logs=None):        print('-'*50)        print('STARTING EPOCH: {}'.format(epoch))    # def on_train_batch_end(self, batch, logs=None):    #     # Do something on end of training batch    #

在调用model.fit时,使用此回调作为callback=[PrinterCallback()]。这里还可以操作其他函数。例如,你可以在训练开始时做些什么(代码中显示了一些示例)。你可以自由修改如何打印所需的值,例如控制小数点位数。

关于Keras回调的详细信息可在此处获取,你还可以查看其他回调的源代码以实现你自己的回调。

希望这对你有帮助!

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注