在TensorFlow中将队列附加到numpy数组以替代文件读取数据?

我已经阅读了TensorFlow上的CNN教程,并试图在我的项目中使用相同的模型。现在的问题在于数据读取。我有大约25000张用于训练的图片,以及大约5000张用于测试和验证的图片。这些文件是png格式,我可以读取它们并转换为numpy.ndarray。

教程中的CNN示例使用队列从提供的文件列表中获取记录。我尝试通过将我的图片重塑为1-D数组,并在前面附加一个标签值来创建我自己的这种二进制文件。所以我的数据看起来像这样

[[1,12,34,24,53,...,105,234,102], [12,112,43,24,52,...,115,244,98],....]

上述数组的单行长度为22501,其中第一个元素是标签。

我使用pickle将文件转储,并尝试使用tf.FixedLengthRecordReader从文件中读取,如示例中演示的那样

我正在做与cifar10_input.py中给出的相同的事情来读取二进制文件并将它们放入记录对象中。

现在,当我从文件中读取时,标签和图像值是不同的。我能理解原因是pickle在二进制文件中也转储了括号和方括号的额外信息,这改变了固定长度记录的大小。

上述示例使用文件名并将其传递给队列以获取文件,然后队列从文件中读取单个记录。

我想知道是否可以将上述定义的numpy数组而不是文件名传递给某个读取器,它可以从该数组中一次获取一个记录,而不是从文件中获取。


回答:

让你的数据与CNN示例代码一起工作的最简单方法可能是创建read_cifar10()的修改版本并使用它来代替:

  1. 写出一个包含你的numpy数组内容的二进制文件。

    import numpy as npimages_and_labels_array = np.array([[...], ...],  # [[1,12,34,24,53,...,102],                                                  #  [12,112,43,24,52,...,98],                                                  #  ...]                                   dtype=np.uint8)images_and_labels_array.tofile("/tmp/images.bin")

    这个文件类似于CIFAR10数据文件使用的格式。你可能想要生成多个文件以获得读取并行性。请注意,ndarray.tofile()以行主序写入二进制数据,没有其他元数据;对数组进行pickle处理会添加TensorFlow的解析例程无法理解的Python特定元数据。

  2. 编写一个处理你的记录格式的read_cifar10()的修改版本。

    def read_my_data(filename_queue):  class ImageRecord(object):    pass  result = ImageRecord()  # 数据集中图像的维度。  label_bytes = 1  # 根据需要设置以下常量。  result.height = IMAGE_HEIGHT  result.width = IMAGE_WIDTH  result.depth = IMAGE_DEPTH  image_bytes = result.height * result.width * result.depth  # 每个记录由标签和图像组成,每个都有固定数量的字节。  record_bytes = label_bytes + image_bytes  assert record_bytes == 22501  # 基于你的问题。  # 从filename_queue中获取文件名读取记录。  二进制文件中没有头部或尾部,因此我们将header_bytes和footer_bytes保持在默认值0。  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)  result.key, value = reader.read(filename_queue)  # 将字符串转换为长度为record_bytes的uint8向量。  record_bytes = tf.decode_raw(value, tf.uint8)  # 前几个字节代表标签,我们将其从uint8转换为int32。  result.label = tf.cast(      tf.slice(record_bytes, [0], [label_bytes]), tf.int32)  # 标签后的剩余字节代表图像,我们将其从[depth * height * width]重塑为[depth, height, width]。  depth_major = tf.reshape(tf.slice(record_bytes, [label_bytes], [image_bytes]),                           [result.depth, result.height, result.width])  # 从[depth, height, width]转换为[height, width, depth]。  result.uint8image = tf.transpose(depth_major, [1, 2, 0])  return result
  3. 修改distorted_inputs()以使用你的新数据集:

    def distorted_inputs(data_dir, batch_size):  """[...]"""  filenames = ["/tmp/images.bin"]  # 或者,如果你在第一步中生成了多个文件,则为文件名列表。  for f in filenames:    if not gfile.Exists(f):      raise ValueError('未找到文件: ' + f)  # 创建一个生成要读取的文件名的队列。  filename_queue = tf.train.string_input_producer(filenames)  # 从filename_queue中的文件读取示例。  read_input = read_my_data(filename_queue)  reshaped_image = tf.cast(read_input.uint8image, tf.float32)  # [...] (根据你的问题,可能需要在这里修改其他参数。)

鉴于你的起点,这是一组最小的步骤。使用TensorFlow操作进行PNG解码可能会更有效,但这将是一个更大的更改。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注