我想从TFRecord中读取数据

我将图像数据保存到了tfrecord中,但是我无法使用tensorflow dataset api解析它。

我的环境

  • Ubuntu 18.04
  • Python 3.6.8
  • Jupyter Notebook
  • Tensorflow 1.12.0

我使用以下代码保存了图像数据,

writer = tf.python_io.TFRecordWriter('training.tfrecord')# X_train: 图像路径, y_train: 标签 (0 或 1)for image_path, label in zip(X_train, y_train):    image = cv2.imread(image_path)    image = cv2.resize(image, (150, 150)) / 255.0    ex = tf.train.Example(        features = tf.train.Features(            feature={                'image' : tf.train.Feature(float_list = tf.train.FloatList(value=image.ravel())),                'label' : tf.train.Feature(int64_list = tf.train.Int64List(value=[label]))            }        )    )    writer.write(ex.SerializeToString())writer.close()

我尝试从tfrecord文件中获取图像,方法如下。

for record in tf.python_io.tf_record_iterator('test.tfrecord'):    example = tf.train.Example()    example.ParseFromString(record)    img = example.features.feature['image'].float_list.value    label = example.features.feature['label'].int64_list.value[0]

这种方法是可行的。

enter image description here

但是,当我使用Dataset API为我的机器学习模型获取图像时,它就不行了。

def _parse_function(example_proto):    features = {        'label' : tf.FixedLenFeature((), tf.int64),        'image' : tf.FixedLenFeature((), tf.float32)    }    parsed_features = tf.parse_single_example(example_proto, features)    return parsed_features['image'], parsed_features['label']def read_image(images, labels):    label = tf.cast(labels, tf.int32)    images = tf.cast(images, tf.float32)    image = tf.reshape(images, [150, 150, 3])# 读取数据dataset = tf.data.TFRecordDataset('training.tfrecord')dataset = dataset.map(_parse_function)dataset = dataset.map(read_image) # <- ERROR!

错误信息是

ValueError: 无法将具有1个元素的张量重塑为形状[150,150,3](67500个元素)用于'Reshape'(操作:'Reshape')的输入形状:[],[3],并且输入张量计算为部分形状:input[1] = [150,150,3]。

我认为这个错误的原因是数组的形状不对,所以我确认了“dataset”的元素

<MapDataset shapes: ((), ()), types: (tf.float32, tf.int64)>

“dataset”变量没有数据。我不知道为什么会这样。

后记

我尝试了Sharky的解决方案,结果是,

def parse(example_proto):    features = {        'label' : tf.FixedLenFeature((), tf.string, ''),        'image' : tf.FixedLenFeature((), tf.string, '')    }    parsed_features = tf.parse_single_example(example_proto, features)    img_shape = tf.stack([150, 150, 3])    image = tf.decode_raw(parsed_features['image'], tf.float32)    image = tf.reshape(image, img_shape)    label = tf.decode_raw(parsed_features['label'], tf.int32)    label = tf.reshape(label, tf.stack([1]))    return image, label

似乎是可行的。但是我无法从这个MapDataset类型对象中获取数组。如何做到这一点?


回答:

尝试使用单个解析函数

def parse(example_proto):    features = {        'label' : tf.FixedLenFeature((), tf.int64),        'image' : tf.FixedLenFeature((), tf.string)    }    parsed_features = tf.parse_single_example(example_proto, features)    img_shape = tf.stack([height, width, channel])    image = tf.decode_raw(parsed_features['image'], tf.float32)    image = tf.reshape(image, img_shape)    label = tf.cast(parsed_features['label'], tf.int32)    return image, label

好的,看起来parse_single_example期望的是字符串类型而不是浮点数。我建议像这样编码

def int64_feature(value):    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))def bytes_feature(value):    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))writer = tf.python_io.TFRecordWriter('training.tfrecord')for image_path, label in zip(X_train, y_train):    image = cv2.imread(image_path)    image = cv2.resize(image, (150, 150)) / 255.0    img_raw = image.tostring()    ex = tf.train.Example(features=tf.train.Features(feature={                                                                                             'image': bytes_feature(img_raw),                        'label': int64_feature(label)                         }))    writer.write(ex.SerializeToString())writer.close()

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注