如何调整TFRecordDataset以符合模型API要求?

我正在基于这个用于噪声抑制的代码构建模型。原版实现的问题在于它一次性加载所有数据,这在训练数据非常大时并不是一个好主意;我在链接代码中标记为training.h5的输入文件超过30 GB。

我决定改用tf.data接口,这应该能让我处理大型数据集;我的问题在于我不知道如何正确调整TFRecordDataset以满足模型API的要求。

如果你查看model.fit(x_train, [y_train, vad_train],它本质上需要以下内容:

  • x_train,形状为[nb_sequences, window, 42]
  • y_train,形状为[nb_sequences, window, 22]
  • vad_train,形状为[nb_sequences, window, 1]

window通常是固定的(在代码中为2000),所以唯一变量nb_sequences取决于你的数据集大小。然而,使用tf.data时,我们不提供xy,而只提供x(参见模型API文档)。

将tfrecord保存到文件

为了使代码可重现,我使用以下代码创建了输入文件:

writer = tf.io.TFRecordWriter(path='example.tfrecord')for record in data:    feature = {}    feature['X'] = tf.train.Feature(float_list=tf.train.FloatList(value=record[:42]))    feature['y'] = tf.train.Feature(float_list=tf.train.FloatList(value=record[42:64]))    feature['vad'] = tf.train.Feature(float_list=tf.train.FloatList(value=[record[64]]))    example = tf.train.Example(features=tf.train.Features(feature=feature))    serialized = example.SerializeToString()    writer.write(serialized)writer.close()

data是我们的训练数据,形状为[10000, 65]。我的example.tfrecord文件可以在这里找到。它大小为3 MB,实际上它会超过30 GB。

你可能会注意到,在链接的代码中,numpy数组的形状是[x, 87],而我的则是[x, 65]。这没问题——剩余部分在任何地方都没有使用。

使用tf.data.TFRecordDataset加载数据集

我想使用tf.data按需加载数据,并进行一些预取,没有必要将所有数据都保存在内存中。我的尝试如下:

我的数据集现在具有以下形状:

<MapDataset shapes: ((42,), ((22,), (1,))), types: (tf.float32, (tf.float32, tf.float32))>

我以为这是模型API所期望的(剧透:并不是)。

model.fit(dataset.batch(10))

会产生以下错误:

ValueError: Error when checking input: expected main_input to have 3 dimensions, but got array with shape (None, 42)

这是有道理的,我这里没有window参数。同时,看起来它并没有得到Model(inputs=main_input, outputs=[denoise_output, vad_output])所期望的正确形状。

如何修改load_dataset以符合tf.data所需的模型API格式?


回答:

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注