Tensorflow将数据集拆分为训练和测试集导致瓶颈/速度慢

我有一个数据集,当我使用ds = ds.map(process_path, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)对其进行预处理时,这行代码执行得非常快。然而,当我尝试访问数据集中的一个元素时:

for image, label in ds.take(1):  print(image.shape)  image = tf.squeeze(image)  plt.imshow(image, cmap='gray')

加载需要一两秒的时间;这是我的第一个问题:

预处理是否只有在访问数据集中的元素时才运行,而不是在我调用ds.map(process_path,...)时立即运行?

然而,我主要的问题是,当我将数据集ds拆分为训练和测试两部分,然后再次尝试访问其中一个元素时,速度明显变慢… 慢了大约20倍。我是用以下方式将其拆分为两部分的:

test_ds_size = int(image_count * 0.2)train_ds = ds.skip(test_ds_size)test_ds = ds.take(test_ds_size)

然后我尝试以与上面相同的方式访问它,只是将ds替换为train_ds;我的第二个问题是:

为什么仅仅因为拆分为两部分,速度就明显变慢了?

还是我做错了什么…

任何帮助都非常感激。


回答:

dataset.map通过应用映射函数创建一个新数据集。
即使在循环中,当你执行dataset.take()时,它也会在非常短的时间内从指定的数量中创建一个新数据集。
在数据集加载后,你执行的其他操作与tf.data的性能无关。
你可以从下面的例子中检查。

import tensorflow as tffrom time import timedataset = tf.data.Dataset.range(1, 100)t1 = time()dataset = dataset.map(lambda x: x + 1)t2 = time()print("Time taken for map : ", t2-t1)t3 = time()ds = dataset.take(50)t4 = time()list(ds.as_numpy_iterator())print("Time taken for take() : ",t4-t3) Time taken for map :  0.013489961624145508Time taken for take() :  0.0005645751953125

现在,让我们看看在一些操作之后take()所花费的时间。

dataset = tf.data.Dataset.range(1, 100)t1 = time()dataset = dataset.map(lambda x: x + 1)t2 = time()print("Time taken for map : ", t2-t1)t3 = time()ds = dataset.take(50)list(ds.as_numpy_iterator())t4 = time()print("Time taken for take() after some operation : ",t4-t3)Time taken for map :  0.00974416732788086Time taken for take() after some operation :  0.017722606658935547 

关于从现有数据集中拆分训练和测试数据,你指定的方式是可以的,但这需要时间,因为它会遍历所有元素。

创建tf.data.Dataset用于训练和测试的理想方式是像这里显示的那样分别创建。确保在创建之前打乱数据,以便在训练和测试数据中正确分配你的数据集分布。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注