我使用了一个类似于这个的脚本将我的数据集转换为分片的tfrecords。但是当我尝试使用下面的脚本读取它时,tensorflow会冻结,我不得不使用kill命令终止进程。(注意:目前我正在CPU模式下工作)
def parse_example_proto(example_serialized): feature_map = { 'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value=''), 'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64, default_value=-1), 'image/class/text': tf.FixedLenFeature([], dtype=tf.string, default_value=''), } features = tf.parse_single_example(example_serialized, feature_map) init_image = tf.image.decode_jpeg(features['image/encoded'], channels = 3) init_image.set_shape([800,480,3]) image = tf.reshape(init_image,tf.pack([800, 480, 3])) float_image = tf.image.convert_image_dtype(image, dtype=tf.float32) label = tf.cast(features['image/class/label'], dtype=tf.int32) return float_image , label, features['image/class/text']def batch_inputs(batch_size, train,sess, num_preprocess_threads=4, num_readers=1): with tf.name_scope('batch_processing'): tf_record_pattern = os.path.join('/home/raarora/', '%s-*' % 'train') data_files = tf.gfile.Glob(tf_record_pattern) if data_files is None: raise ValueError('No data files found for this dataset')# print data_files # Create filename_queue if train: filename_queue = tf.train.string_input_producer(data_files, shuffle=True, capacity=8) else: filename_queue = tf.train.string_input_producer(data_files, shuffle=False, capacity=1) reader =tf.TFRecordReader() _, example_serialized = reader.read(filename_queue) image, label, _ = parse_example_proto(example_serialized) examples_per_shard = 201 min_queue_examples = examples_per_shard * 2 images, labels = tf.train.shuffle_batch( [image, label], batch_size=batch_size, num_threads=4, capacity=min_queue_examples + 3 * batch_size, min_after_dequeue=min_queue_examples) print images.eval(session=sess) return s,images,labelsif __name__ == '__main__': sess = tf.Session() coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess=sess, coord=coord) s,_,_ = batch_inputs(2,1,sess)
回答:
我已经解决了这个问题。我原本以为TFRecord是一种类似字典的数据结构,你只需要提供所需的键,但当我提供完整的特征映射并对图像处理稍作修改后,它就工作了。
我犯的另一个错误是,queue_runner应该在调用tf.train.shuffle_batch()之后启动。我不知道这是个bug还是我理解上的差距
以下是读取数据的有效代码
def getImage(filename): # 将文件名转换为输入管道的队列。 filenameQ = tf.train.string_input_producer([filename],num_epochs=None) # 读取记录的对象 recordReader = tf.TFRecordReader() # 读取单个示例的完整特征集 key, fullExample = recordReader.read(filenameQ) # 将完整示例解析为其组件特征。 features = tf.parse_single_example( fullExample, features={ 'image/height': tf.FixedLenFeature([], tf.int64), 'image/width': tf.FixedLenFeature([], tf.int64), 'image/colorspace': tf.FixedLenFeature([], dtype=tf.string,default_value=''), 'image/channels': tf.FixedLenFeature([], tf.int64), 'image/class/label': tf.FixedLenFeature([],tf.int64), 'image/class/text': tf.FixedLenFeature([], dtype=tf.string,default_value=''), 'image/format': tf.FixedLenFeature([], dtype=tf.string,default_value=''), 'image/filename': tf.FixedLenFeature([], dtype=tf.string,default_value=''), 'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value='') }) # 现在我们将操作标签和图像特征 label = features['image/class/label'] image_buffer = features['image/encoded'] # 解码jpeg with tf.name_scope('decode_jpeg',[image_buffer], None): # 解码 image = tf.image.decode_jpeg(image_buffer, channels=3) # 并转换为单精度数据类型 image = tf.image.convert_image_dtype(image, dtype=tf.float32) # 将图像转换为单一数组,其中每个元素对应一个像素的灰度值。 # "1-.."部分将图像反转,使背景变为黑色。 # 重新定义标签为"one-hot"向量 # 这里它将是[0,1]或[1,0]。 # 这种方法可以轻松扩展到更多类别。 image=tf.reshape(image,[height,width,3]) label=tf.pack(tf.one_hot(label-1, nClass)) return label, imagelabel, image = getImage("train-00000-of-00001")imageBatch, labelBatch = tf.train.shuffle_batch( [image, label], batch_size=2, capacity=20, min_after_dequeue=10)sess = tf.InteractiveSession()sess.run(tf.initialize_all_variables())# 启动用于读取文件的线程coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess,coord=coord)batch_xs, batch_ys = sess.run([imageBatch, labelBatch])print batch_xscoord.request_stop()coord.join(threads)
注意:我对分片记录不是很清楚,所以我只使用了一个分片。
致谢给 https://agray3.github.io/2016/11/29/Demystifying-Data-Input-to-TensorFlow-for-Deep-Learning.html