TF slice_input_producer 无法保持张量同步

我在我的 TF 网络中读取图像,但我也需要与之关联的标签。

所以我尝试按照这个回答来做,但输出的标签实际上与我在每个批次中获取的图像不匹配。

我的图像名称格式为dir/3.jpg,所以我只是从图像文件名中提取标签。

truth_filenames_np = ...truth_filenames_tf = tf.convert_to_tensor(truth_filenames_np)# 获取标签labels = [f.rsplit("/", 1)[1] for f in truth_filenames_np]labels_tf = tf.convert_to_tensor(labels)# *** 这一行应该确保两个输入张量同步(根据我的有限理解)# 我的列表也已经打乱,所以我设置了shuffle=Falsetruth_image_name, truth_label = tf.train.slice_input_producer([truth_filenames_tf, labels_tf], shuffle=False)truth_image_value = tf.read_file(truth_image_name)truth_image = tf.image.decode_jpeg(truth_image_value)truth_image.set_shape([IMAGE_DIM, IMAGE_DIM, 3])truth_image = tf.cast(truth_image, tf.float32)truth_image = truth_image/255.0# 另一个关键步骤,我将它们批处理在一起truth_images_batch, truth_label_batch = tf.train.batch([truth_image, truth_label], batch_size=mb_size)with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    coord = tf.train.Coordinator()    threads = tf.train.start_queue_runners(coord=coord)    for i in range(epochs):        print "Epoch ", i        X_truth_batch = truth_images_batch.eval()        X_label_batch = truth_label_batch.eval()        # 在这里我显示这个批次中的所有图像,然后我检查它们实际上是哪些文件号。        # 但是,显示的图像与X_label_batch打印的内容不对应!        print X_label_batch        plot_batch(X_truth_batch)    coord.request_stop()    coord.join(threads)

我做错了什么,还是slice_input_producer实际上并不能确保其输入张量同步?

补充说明:

我还注意到,当我从tf.train.batch获取一个批次时,批次中的元素在原始列表中是相邻的,但批次的顺序并不是原始顺序。例如:如果我的数据是[“dir/1.jpg”, “dir/2.jpg”, “dir/3.jpg”, “dir/4.jpg”, “dir/5.jpg, “dir/6.jpg”],那么我可能会得到批次(batch_size=2)[“dir/3.jpg”, “dir/4.jpg”],然后是批次[“dir/1.jpg”, “dir/2.jpg”],最后是最后一个批次。所以这使得即使只是使用FIFO队列来处理标签也变得困难,因为顺序与批次顺序不匹配。


回答:

这是一个可以重现问题的完整可运行示例:

import tensorflow as tftruth_filenames_np = ['dir/%d.jpg' % j for j in range(66)]truth_filenames_tf = tf.convert_to_tensor(truth_filenames_np)# 获取标签labels = [f.rsplit("/", 1)[1] for f in truth_filenames_np]labels_tf = tf.convert_to_tensor(labels)# 我的列表也已经打乱,所以我设置了shuffle=Falsetruth_image_name, truth_label = tf.train.slice_input_producer(    [truth_filenames_tf, labels_tf], shuffle=False)# # 另一个关键步骤,我将它们批处理在一起# truth_images_batch, truth_label_batch = tf.train.batch(#     [truth_image_name, truth_label], batch_size=11)epochs = 7with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    coord = tf.train.Coordinator()    threads = tf.train.start_queue_runners(coord=coord)    for i in range(epochs):        print("Epoch ", i)        X_truth_batch = truth_image_name.eval()        X_label_batch = truth_label.eval()        # 在这里我显示这个批次中的所有图像,然后我检查        # 它们实际上是哪些文件号。        # 但是,显示的图像与X_label_batch打印的内容不对应!        print(X_truth_batch)        print(X_label_batch)    coord.request_stop()    coord.join(threads)

这打印的内容是:

Epoch  0b'dir/0.jpg'b'1.jpg'Epoch  1b'dir/2.jpg'b'3.jpg'Epoch  2b'dir/4.jpg'b'5.jpg'Epoch  3b'dir/6.jpg'b'7.jpg'Epoch  4b'dir/8.jpg'b'9.jpg'Epoch  5b'dir/10.jpg'b'11.jpg'Epoch  6b'dir/12.jpg'b'13.jpg'

所以基本上每次eval调用都会再次运行操作!添加批处理对这一点没有影响 – 只是打印批次(前11个文件名后面跟着下11个标签,依此类推)

我看到的解决方法是:

for i in range(epochs):    print("Epoch ", i)    pair = tf.convert_to_tensor([truth_image_name, truth_label]).eval()    print(pair[0])    print(pair[1])

这正确地打印了:

Epoch  0b'dir/0.jpg'b'0.jpg'Epoch  1b'dir/1.jpg'b'1.jpg'# ...

但对违反最少惊讶原则没有任何帮助。

编辑:还有另一种方法:

import tensorflow as tftruth_filenames_np = ['dir/%d.jpg' % j for j in range(66)]truth_filenames_tf = tf.convert_to_tensor(truth_filenames_np)labels = [f.rsplit("/", 1)[1] for f in truth_filenames_np]labels_tf = tf.convert_to_tensor(labels)truth_image_name, truth_label = tf.train.slice_input_producer(    [truth_filenames_tf, labels_tf], shuffle=False)epochs = 7with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    tf.train.start_queue_runners(sess=sess)    for i in range(epochs):        print("Epoch ", i)        X_truth_batch, X_label_batch = sess.run(            [truth_image_name, truth_label])        print(X_truth_batch)        print(X_label_batch)

这是一个更好的方法,因为tf.convert_to_tensor和相关函数只接受相同类型/形状的张量等。

请注意,为了简化,我删除了协调器,但这会导致一个警告:

W c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\kernels\queue_base.cc:294] _0_input_producer/input_producer/fraction_of_32_full/fraction_of_32_full: Skipping cancelled enqueue attempt with queue not closed

参见这里

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注