为了创建模型，我正在按照这个教程操作：https://www.geeksforgeeks.org/ml-training-image-classifier-using-tensorflow-object-detection-api/ 。但是，检查点 3 中的 generate_tfrecord 代码是基于 TensorFlow 1.x 编写的，而我使用的是 2.2.0-rc0 版本。代码中用到的一些 API（例如 tf.app）在 TensorFlow 2.2.0 中已被移除，我希望按照新版本对代码进行修改。这段代码用于为训练数据集和测试数据集生成 TFRecord 文件。代码如下，我只想知道：为了使其与 TensorFlow 2.2.0 兼容，我应该对这段代码做哪些修改？
`
from __future__ import divisionfrom __future__ import print_functionfrom __future__ import absolute_importimport osimport ioimport pandas as pdimport tensorflow as tffrom PIL import Imagefrom object_detection.utils import dataset_utilfrom collections import namedtuple, OrderedDictflags = tf.app.flagsflags.DEFINE_string('csv_input', '', 'Path to the CSV input')flags.DEFINE_string('output_path', '', 'Path to output TFRecord')flags.DEFINE_string('image_dir', '', 'Path to images')FLAGS = flags.FLAGS# TO-DO replace this with label mapdef class_text_to_int(row_label): if row_label == 'Button': return 1 if row_label == 'Text Box': return 2 if row_label == 'Check Box': return 3 if row_label == 'Link': return 4 if row_label == 'Hyperlink': return 5 if row_label == 'Icon': return 6 if row_label == 'Text': return 7 if row_label == 'Image': return 8 else: Nonedef split(df, group): data = namedtuple('data', ['filename', 'object']) gb = df.groupby(group) return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]def create_tf_example(group, path): with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid: encoded_jpg = fid.read() encoded_jpg_io = io.BytesIO(encoded_jpg) image = Image.open(encoded_jpg_io) width, height = image.size filename = group.filename.encode('utf8') image_format = b'jpg' xmins = [] xmaxs = [] ymins = [] ymaxs = [] classes_text = [] classes = [] for index, row in group.object.iterrows(): xmins.append(row['xmin'] / width) xmaxs.append(row['xmax'] / width) ymins.append(row['ymin'] / height) ymaxs.append(row['ymax'] / height) classes_text.append(row['class'].encode('utf8')) classes.append(class_text_to_int(row['class'])) tf_example = tf.train.Example(features=tf.train.Features(feature={ 'image/height': dataset_util.int64_feature(height), 'image/width': dataset_util.int64_feature(width), 'image/filename': dataset_util.bytes_feature(filename), 'image/source_id': dataset_util.bytes_feature(filename), 
'image/encoded': dataset_util.bytes_feature(encoded_jpg), 'image/format': dataset_util.bytes_feature(image_format), 'image/object/bbox/xmin': dataset_util.float_list_feature(xmins), 'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs), 'image/object/bbox/ymin': dataset_util.float_list_feature(ymins), 'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs), 'image/object/class/text': dataset_util.bytes_list_feature(classes_text), 'image/object/class/label': dataset_util.int64_list_feature(classes), })) return tf_exampledef main(_): writer = tf.python_io.TFRecordWriter(FLAGS.output_path) path = os.path.join(FLAGS.image_dir) examples = pd.read_csv(FLAGS.csv_input) grouped = split(examples, 'filename') for group in grouped: tf_example = create_tf_example(group, path) writer.write(tf_example.SerializeToString()) writer.close() output_path = os.path.join(os.getcwd(), FLAGS.output_path) print('Successfully created the TFRecords: {}'.format(output_path))if __name__ == '__main__': tf.app.run()
`
回答:
请参考以下文档以迁移或升级您的代码,使其在TensorFlow 2.x rc中工作。
这里是迁移代码的指南。
这里是使用升级脚本命令将代码从 1.x 版本自动升级到 2.x 版本的文档。
我是在 Google Colab（谷歌协作平台）中运行升级脚本完成升级的。
这是您升级后的代码:
from __future__ import divisionfrom __future__ import print_functionfrom __future__ import absolute_importimport osimport ioimport pandas as pdimport tensorflow as tffrom PIL import Imagefrom object_detection.utils import dataset_utilfrom collections import namedtuple, OrderedDictflags = tf.compat.v1.app.flagsflags.DEFINE_string('csv_input', '', 'Path to the CSV input')flags.DEFINE_string('output_path', '', 'Path to output TFRecord')flags.DEFINE_string('image_dir', '', 'Path to images')FLAGS = flags.FLAGS# TO-DO replace this with label mapdef class_text_to_int(row_label): if row_label == 'Button': return 1 if row_label == 'Text Box': return 2 if row_label == 'Check Box': return 3 if row_label == 'Link': return 4 if row_label == 'Hyperlink': return 5 if row_label == 'Icon': return 6 if row_label == 'Text': return 7 if row_label == 'Image': return 8 else: Nonedef split(df, group): data = namedtuple('data', ['filename', 'object']) gb = df.groupby(group) return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]def create_tf_example(group, path): with tf.io.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid: encoded_jpg = fid.read() encoded_jpg_io = io.BytesIO(encoded_jpg) image = Image.open(encoded_jpg_io) width, height = image.size filename = group.filename.encode('utf8') image_format = b'jpg' xmins = [] xmaxs = [] ymins = [] ymaxs = [] classes_text = [] classes = [] for index, row in group.object.iterrows(): xmins.append(row['xmin'] / width) xmaxs.append(row['xmax'] / width) ymins.append(row['ymin'] / height) ymaxs.append(row['ymax'] / height) classes_text.append(row['class'].encode('utf8')) classes.append(class_text_to_int(row['class'])) tf_example = tf.train.Example(features=tf.train.Features(feature={ 'image/height': dataset_util.int64_feature(height), 'image/width': dataset_util.int64_feature(width), 'image/filename': dataset_util.bytes_feature(filename), 'image/source_id': 
dataset_util.bytes_feature(filename), 'image/encoded': dataset_util.bytes_feature(encoded_jpg), 'image/format': dataset_util.bytes_feature(image_format), 'image/object/bbox/xmin': dataset_util.float_list_feature(xmins), 'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs), 'image/object/bbox/ymin': dataset_util.float_list_feature(ymins), 'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs), 'image/object/class/text': dataset_util.bytes_list_feature(classes_text), 'image/object/class/label': dataset_util.int64_list_feature(classes), })) return tf_exampledef main(_): writer = tf.io.TFRecordWriter(FLAGS.output_path) path = os.path.join(FLAGS.image_dir) examples = pd.read_csv(FLAGS.csv_input) grouped = split(examples, 'filename') for group in grouped: tf_example = create_tf_example(group, path) writer.write(tf_example.SerializeToString()) writer.close() output_path = os.path.join(os.getcwd(), FLAGS.output_path) print('Successfully created the TFRecords: {}'.format(output_path))if __name__ == '__main__': tf.compat.v1.app.run()