在 TensorFlow.Keras 中使用单一损失函数处理多输出模型

我使用了 的 Dataset,其中 y 是一个包含 6 个张量的字典,这些张量都用于一个单一的损失函数,函数如下所示:

def CustomLoss():    def custom_loss(y_true, y_pred):        a = tf.keras.losses.binary_crossentropy(y_true['a_0'], y_pred[0]) * y_true['a_1']        b = tf.square(y_true['b_0'] - y_pred[1]) * y_true['b_1']        c = tf.abs(y_true['c_0'] - y_pred[2]) * y_true['c_1']        return a + b + c    return custom_loss

我的模型有 3 个不同形状的输出。当我编译模型并调用 fit 方法时,我得到了 Value Error

model.compile(optimizer=optimizer, loss=CustomLoss())model.fit(dataset, epochs=10)
ValueError: Found unexpected keys that do not correspond to any Model output: dict_keys(['a_0', 'a_1', 'b_0', 'b_1', 'c_0', 'c_1']). Expected: ['output_0', 'output_1', 'output_2']

其中 output_0, 'output_1', 'output_2' 是输出层的名称。

我认为通过将输出层的名称设置为数据集中的键可以解决这个问题,但问题是我在数据集中有 6 个张量,而只有 3 个输出。我知道我可以为每个输出分配一个损失函数,并使用单个数据集的真实值张量,但同样,我需要至少传递两个张量作为真实值。

到目前为止,我使用了自定义训练循环,但我更希望使用 fit 方法。我使用的是 2.3.1

编辑:

示例模型:

inputs = x = tf.keras.layers.Input((256, 256, 3))x = tf.keras.applications.ResNet50(include_top=False, weights=None)(x)x1 = tf.keras.layers.Flatten()(x)x1 = tf.keras.layers.Dense(2, name='output_1')(x1)x2 = tf.keras.layers.Conv2D(256, 1, name='output_2')(x)x3 = tf.keras.layers.Flatten()(x)x3 = tf.keras.layers.Dense(64, name='output_3')(x3)model = tf.keras.Model(inputs=inputs, outputs=[x1, x2, x3])

自定义训练循环:

avg_loss = tf.keras.metrics.Mean('loss', dtype=tf.float32)for epoch in range(1, epochs+1):    for batch, (images, labels) in enumerate(train_dataset):        with tf.GradientTape() as tape:            outputs = model(images, training=False)            reg_loss = tf.reduce_sum(model.losses)            pred_loss = loss(labels, outputs)            total_loss = tf.reduce_sum(pred_loss) + reg_loss        grads = tape.gradient(total_loss, model.trainable_variables)        optimizer.apply_gradients(zip(grads, model.trainable_variables))        avg_loss.update_state(total_loss)    print(f'Epoch {epoch}/{epochs} - Loss: {avg_loss.result().numpy()}')    avg_loss.reset_states()

最小的可复现代码:

import tensorflow as tfdef CustomLoss():    def custom_loss(y_true, y_pred):        a = tf.keras.losses.binary_crossentropy(y_true['a_0'], y_pred[0]) * y_true['a_1']        b = tf.square(y_true['b_0'] - y_pred[1]) * y_true['b_1']        b = tf.reduce_sum(b, axis=(1, 2, 3))        c = tf.abs(y_true['c_0'] - y_pred[2]) * y_true['c_1']        c = tf.reduce_sum(c, axis=1)        return a + b + c    return custom_lossdataset = tf.data.Dataset.from_tensors((    tf.random.uniform((256, 256, 3)),    {'a_0': [0., 1.], 'a_1': [1.], 'b_0': tf.random.uniform((8, 8, 256)), 'b_1': [1.], 'c_0': tf.random.uniform((64,)), 'c_1': [1.]}))dataset = dataset.batch(1)inputs = x = tf.keras.layers.Input((256, 256, 3))x = tf.keras.applications.ResNet50(include_top=False, weights=None)(x)x1 = tf.keras.layers.Flatten()(x)x1 = tf.keras.layers.Dense(2, name='output_1')(x1)x2 = tf.keras.layers.Conv2D(256, 1, name='output_2')(x)x3 = tf.keras.layers.Flatten()(x)x3 = tf.keras.layers.Dense(64, name='output_3')(x3)model = tf.keras.Model(inputs=inputs, outputs=[x1, x2, x3])optimizer = tf.keras.optimizers.Adam(1e-3)model.compile(optimizer=optimizer, loss=CustomLoss())model.fit(dataset, epochs=1)

回答:

这里是针对您情况的一种方法。我们将继续使用自定义训练循环,但也利用便捷的 .fit 方法,通过自定义此方法来实现。请查看文档以了解更多关于此方法的详细信息:Customizing what happens in fit()


这里是一个简单的演示,扩展了您的可复现代码。

import tensorflow as tf# 数据集 dataset = tf.data.Dataset.from_tensors((    tf.random.uniform((256, 256, 3)),    {'a_0': [0., 1.], 'a_1': [1.], 'b_0': tf.random.uniform((8, 8, 256)),     'b_1': [1.], 'c_0': tf.random.uniform((64,)), 'c_1': [1.]}))dataset = dataset.batch(1)# 自定义损失函数 def loss(y_true, y_pred):        a = tf.keras.losses.binary_crossentropy(y_true['a_0'], y_pred[0]) * y_true['a_1']        b = tf.square(y_true['b_0'] - y_pred[1]) * y_true['b_1']        b = tf.reduce_sum(b, axis=(1, 2, 3))        c = tf.abs(y_true['c_0'] - y_pred[2]) * y_true['c_1']        c = tf.reduce_sum(c, axis=1)        return a + b + c

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注