使用生成器数据集在TensorFlow中进行多输出fit:无法正确定义形状?

我尝试将一个项目转换为使用生成器的单一网络,具有多个输出,但当使用生成器时,我无法弄清楚如何让多个输出正常工作。以下是一个最小可验证的代码片段:

import numpy as npimport tensorflow as tffrom tensorflow import kerasfrom tensorflow.keras import layers, modelsdef generate_sample():    x = list("123456789")    y = list("2345")    while 1:        yield np.array(x).astype(np.float32),[np.array(y).astype(np.float32),np.array(y).astype(np.float32)]dataset = tf.data.Dataset.from_generator(generate_sample,            output_signature=(                 tf.TensorSpec(shape=(9,), dtype=tf.float32),                 tf.TensorSpec(shape=(2,4), dtype=tf.float32)            ))dataset = dataset.batch(batch_size=32)inputs = keras.Input(shape=(next(generate_sample())[0].shape))x = layers.Dense(512, activation = "relu")(inputs)x_outputs = layers.Dense(4, activation="relu", name="output")(x)y_outputs = layers.Dense(4, activation="relu", name="output2")(x)model = keras.Model(inputs=inputs, outputs=[x_outputs,y_outputs])model.compile(loss="mse", optimizer = "adam", metrics=['accuracy'])history = model.fit(dataset, epochs=1, steps_per_epoch=10, validation_data=dataset, validation_steps=5)

这导致了一个非常长的错误,最后一部分是:

InvalidArgumentError: Incompatible shapes: [32,2,4] vs. [32,4]
[[node mean_squared_error/SquaredDifference (defined at:1) ]][Op:__inference_train_function_8957]

Function call stack: train_function

我尝试使用output_shapeoutput_signature等,以我能想到的所有方式重塑数据。但无论如何,我仍然遇到形状问题。

我在这里错过了什么明显的东西,还是在使用生成器作为数据集来源的fit中有什么问题?例如,当我从内存中加载数据时,我没有遇到这个问题。


回答:

模型的输出不是形状为(2,4)的一个张量,而是形状为(4)的两个张量。

您应该更改生成器函数以反映这一点:

def generate_sample():    x = list("123456789")    y = list("2345")    while 1:        yield np.array(x).astype(np.float32),(np.array(y).astype(np.float32),np.array(y).astype(np.float32))

以及您的输出签名:

dataset = tf.data.Dataset.from_generator(generate_sample,            output_signature=(                 tf.TensorSpec(shape=(9,), dtype=tf.float32),                 (tf.TensorSpec(shape=(4,), dtype=tf.float32),                 tf.TensorSpec(shape=(4,), dtype=tf.float32)),            ))

请注意,生成器的输出是一个嵌套的元组。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注