如何取消Tensorflow 2.0数据集的批处理

我有一个数据集,是通过以下代码使用tf.data.Dataset创建的:

dataset = Dataset.from_tensor_slices(corona_new)dataset = dataset.window(WINDOW_SIZE, 1, drop_remainder=True)dataset = dataset.flat_map(lambda x: x.batch(WINDOW_SIZE))dataset = dataset.map(lambda x: tf.transpose(x))for i in dataset:    print(i.numpy())    break

当我运行它时,得到的输出如下(这是一个批次的示例):

[[  0. 125. 111. 232. 164. 134. 235. 190.]  [  0.  14.  16.   7.   9.   7.   6.   8.] [  0. 132. 199. 158. 148. 141. 179. 174.] [  0.   0.   0.   2.   0.   2.   1.   2.] [  0.   0.   0.   0.   3.   5.   0.   0.]]

如何取消它们的批处理?


回答:

找到了我的解决方案。

在TensorFlow 2.0中,您可以通过调用.unbatch()函数来取消tf.data.Dataset的批处理。

示例:dataset.unbatch()

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注