I'm training a machine learning model to detect various user-interface elements in images, such as text boxes, images, and buttons.

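To build the model I'm following this tutorial: https://www.geeksforgeeks.org/ml-training-image-classifier-using-tensorflow-object-detection-api/. However, the generate_tfrecord code in checkpoint 3 is written for TensorFlow 1, and I'm using version 2.2.0-rc0. Some of the functions it calls, such as `tf.app`, have been removed in TensorFlow 2.2.0, so I want to update the code for the new version. The script generates the TFRecords for the training and test datasets. Here is the code; I just want to know which changes I need to make so that it is compatible with TensorFlow 2.2.0: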

```python
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import os
import io
import pandas as pd
import tensorflow as tf

from PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict

flags = tf.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('image_dir', '', 'Path to images')
FLAGS = flags.FLAGS


# TO-DO replace this with label map
def class_text_to_int(row_label):
    if row_label == 'Button':
        return 1
    if row_label == 'Text Box':
        return 2
    if row_label == 'Check Box':
        return 3
    if row_label == 'Link':
        return 4
    if row_label == 'Hyperlink':
        return 5
    if row_label == 'Icon':
        return 6
    if row_label == 'Text':
        return 7
    if row_label == 'Image':
        return 8
    else:
        None


def split(df, group):
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]


def create_tf_example(group, path):
    with tf.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    width, height = image.size

    filename = group.filename.encode('utf8')
    image_format = b'jpg'
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    for index, row in group.object.iterrows():
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class']))

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example


def main(_):
    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
    path = os.path.join(FLAGS.image_dir)
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')
    for group in grouped:
        tf_example = create_tf_example(group, path)
        writer.write(tf_example.SerializeToString())

    writer.close()
    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))


if __name__ == '__main__':
    tf.app.run()
```


Answer:

Refer to the following documentation to migrate or upgrade your code so that it works with the TensorFlow 2.x release candidate.

Here is the guide to migrating your code.

Here is the documentation on upgrading code from 1.x to 2.x with a command-line script.
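From a terminal (or a Colab cell prefixed with `!`) the upgrade script is run as `tf_upgrade_v2 --infile <old.py> --outfile <new.py>`. The sketch below wraps that call in Python, assuming `tf_upgrade_v2` is on your PATH (it ships with the TensorFlow pip package). The helper names `build_upgrade_command` and `upgrade_file` are hypothetical, not part of TensorFlow.

```python
import subprocess


def build_upgrade_command(infile, outfile, reportfile="report.txt"):
    """Assemble the tf_upgrade_v2 invocation for one file (hypothetical helper)."""
    return [
        "tf_upgrade_v2",
        "--infile", infile,          # the TF 1.x script to convert
        "--outfile", outfile,        # where the converted copy is written
        "--reportfile", reportfile,  # text log of the changes and warnings
    ]


def upgrade_file(infile, outfile, reportfile="report.txt"):
    """Run tf_upgrade_v2; raises CalledProcessError if the tool exits non-zero."""
    subprocess.run(build_upgrade_command(infile, outfile, reportfile), check=True)
```

For example, `upgrade_file("generate_tfrecord.py", "generate_tfrecord_v2.py")` writes the converted script to a new file and leaves the original untouched, so you can diff the two and read `report.txt` before switching over.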

I ran the upgrade script on Google Colaboratory.
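For this particular script, the script's edits come down to a few symbol renames, which you can see by comparing the question's listing with the upgraded code. The sketch below reproduces just those renames with regular expressions to make the diff explicit. It is an illustration, not a replacement for `tf_upgrade_v2`, which handles many more cases; `TF2_RENAMES` and `apply_tf2_renames` are hypothetical names.

```python
import re

# Renames visible when diffing the TF 1.x and upgraded listings in this thread.
# Illustrative subset only; tf_upgrade_v2 applies many more rules.
TF2_RENAMES = [
    (r"\btf\.app\.", "tf.compat.v1.app."),                           # flags and run()
    (r"\btf\.gfile\.GFile\b", "tf.io.gfile.GFile"),                  # file reading
    (r"\btf\.python_io\.TFRecordWriter\b", "tf.io.TFRecordWriter"),  # TFRecord output
]


def apply_tf2_renames(source):
    """Return source with each TF 1.x symbol above replaced by its TF 2.x name."""
    for pattern, replacement in TF2_RENAMES:
        source = re.sub(pattern, replacement, source)
    return source
```

For example, `apply_tf2_renames("writer = tf.python_io.TFRecordWriter(FLAGS.output_path)")` returns `"writer = tf.io.TFRecordWriter(FLAGS.output_path)"`. Note that `tf.compat.v1.app` is a compatibility layer over absl-py; a longer-term option is to import the same API directly with `from absl import app, flags`.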

Here is your upgraded code:

```python
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import

import os
import io
import pandas as pd
import tensorflow as tf

from PIL import Image
from object_detection.utils import dataset_util
from collections import namedtuple, OrderedDict

# TF 2.x: tf.app was removed; the TF 1.x flags API lives under tf.compat.v1
flags = tf.compat.v1.app.flags
flags.DEFINE_string('csv_input', '', 'Path to the CSV input')
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('image_dir', '', 'Path to images')
FLAGS = flags.FLAGS


# TO-DO replace this with label map
def class_text_to_int(row_label):
    if row_label == 'Button':
        return 1
    if row_label == 'Text Box':
        return 2
    if row_label == 'Check Box':
        return 3
    if row_label == 'Link':
        return 4
    if row_label == 'Hyperlink':
        return 5
    if row_label == 'Icon':
        return 6
    if row_label == 'Text':
        return 7
    if row_label == 'Image':
        return 8
    else:
        return None  # unknown label


def split(df, group):
    # One (filename, rows) pair per image, since an image can contain several boxes
    data = namedtuple('data', ['filename', 'object'])
    gb = df.groupby(group)
    return [data(filename, gb.get_group(x)) for filename, x in zip(gb.groups.keys(), gb.groups)]


def create_tf_example(group, path):
    # TF 2.x: tf.gfile.GFile -> tf.io.gfile.GFile
    with tf.io.gfile.GFile(os.path.join(path, '{}'.format(group.filename)), 'rb') as fid:
        encoded_jpg = fid.read()
    encoded_jpg_io = io.BytesIO(encoded_jpg)
    image = Image.open(encoded_jpg_io)
    width, height = image.size

    filename = group.filename.encode('utf8')
    image_format = b'jpg'
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    classes_text = []
    classes = []

    for index, row in group.object.iterrows():
        # Box corners are stored normalized by the image width/height
        xmins.append(row['xmin'] / width)
        xmaxs.append(row['xmax'] / width)
        ymins.append(row['ymin'] / height)
        ymaxs.append(row['ymax'] / height)
        classes_text.append(row['class'].encode('utf8'))
        classes.append(class_text_to_int(row['class']))

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_jpg),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example


def main(_):
    # TF 2.x: tf.python_io.TFRecordWriter -> tf.io.TFRecordWriter
    writer = tf.io.TFRecordWriter(FLAGS.output_path)
    path = os.path.join(FLAGS.image_dir)
    examples = pd.read_csv(FLAGS.csv_input)
    grouped = split(examples, 'filename')
    for group in grouped:
        tf_example = create_tf_example(group, path)
        writer.write(tf_example.SerializeToString())

    writer.close()
    output_path = os.path.join(os.getcwd(), FLAGS.output_path)
    print('Successfully created the TFRecords: {}'.format(output_path))


if __name__ == '__main__':
    # TF 2.x: tf.app.run -> tf.compat.v1.app.run
    tf.compat.v1.app.run()
```
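The `# TO-DO replace this with label map` comment in both listings points at an easy cleanup: the if-chain in `class_text_to_int` can become a dictionary lookup. A minimal sketch, keeping the same IDs as the original function (they start at 1 because the Object Detection API reserves 0); `LABEL_MAP` is a hypothetical name. Unlike the original, it raises on an unknown label instead of returning `None`, so a typo in the CSV fails with a clear message rather than later, when the class list is turned into an `int64_list_feature`.

```python
# Hypothetical label map; IDs copied from the original class_text_to_int.
LABEL_MAP = {
    "Button": 1,
    "Text Box": 2,
    "Check Box": 3,
    "Link": 4,
    "Hyperlink": 5,
    "Icon": 6,
    "Text": 7,
    "Image": 8,
}


def class_text_to_int(row_label):
    """Map a CSV class name to its integer ID, failing loudly on unknown names."""
    try:
        return LABEL_MAP[row_label]
    except KeyError:
        raise ValueError("Unknown class label: {!r}".format(row_label)) from None
```

This function is a drop-in replacement for the one in the upgraded script, and the same dictionary can be kept in sync with the label map file used for training.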
