Tensorflow自定义预处理使用tf.py_function丢失形状

我正在编写一个模型并进行预处理部分:我有一个方法,通过调用来预处理我的Tensorflow数据集:

ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)

我按照Tensorflow的文档,得到了以下process_path的代码:

def process_path(filename):  label = get_label(filename)  image = tf.io.read_file(filename)  image = tf.image.decode_jpeg(image, channels=3)  image = tf.image.rgb_to_grayscale(image)  image = tf.image.convert_image_dtype(image, tf.float32)  image = tf.image.resize(image, [224, 224])    return image, label

然后我想添加我自己的预处理,例如旋转图像,所以我创建了一个用py_function包装的rotate方法,正如文档所建议的:

def rotate_image(image):  return tfa.image.rotate(image, random.randrange(-5, 5)/1.0)def tf_rotate_image(image, label):  [image,] = tf.py_function(rotate_image, [image], [tf.float32])  return image, label

然而,当我将此添加到我的process_path中时,模型似乎崩溃并冻结了…我在每次调整后添加了print语句以显示image.shape,发现旋转方法后图像形状变为<unknown>,所以我认为这是错误的原因:

def process_path(filename):  label = get_label(filename)  image = tf.io.read_file(filename)  print(image.shape)  image = tf.image.decode_jpeg(image, channels=3)  print(image.shape)  image = tf.image.rgb_to_grayscale(image)  print(image.shape)  image = tf.image.convert_image_dtype(image, tf.float32)  print(image.shape)  image = tf.image.resize(image, [224, 224])  print(image.shape)  image, label = tf_rotate_image(image, label)  print(image.shape)    return image, label

输出:

()(None, None, 3)(None, None, 1)(None, None, 1)(224, 224, 1)<unknown>

任何帮助将不胜感激。


回答:

您之所以丢失形状,是因为调用了tf.py_function

使用基于tensorflow操作的函数

在这种情况下,因为您使用的是tfa,您可以直接调用该函数,而无需用tf.py_function包装,因为您使用的是tf操作:

def process_path(filename):  label = 1  image = tf.io.read_file(filename)  print(image.shape)  image = tf.image.decode_jpeg(image, channels=3)  print(image.shape)  image = tf.image.rgb_to_grayscale(image)  print(image.shape)  image = tf.image.convert_image_dtype(image, tf.float32)  print(image.shape)  image = tf.image.resize(image, [224, 224])  print(image.shape)  image, label = (lambda x,y : (rotate_image(x), y))(image, label)  print(image.shape)

将会得到以下结果:

>>> ds = ds.map(process_path)()(None, None, 3)(None, None, 1)(None, None, 1)(224, 224, 1)(224, 224, 1)

使用不包含tensorflow操作的函数

如果您想使用不包含tensorflow操作的函数,那么您可以使用tf.py_function,并明确设置形状。这是在指南tf.data: Build TensorFlow input pipelines中所做的。以下是该指南中的示例:

def tf_random_rotate_image(image, label):  im_shape = image.shape  [image,] = tf.py_function(random_rotate_image, [image], [tf.float32])  # 形状是明确设置的,因为tensorflow无法确保  # 在函数执行期间形状不会被修改  image.set_shape(im_shape)  return image, label

然而,进行此操作时,tensorflow做的一个假设是您设置的形状实际上是正确的!以下示例将崩溃,因为函数lambda x:1不保留输入的形状。

def not_shape_preserving(image, label):    im_shape = image.shape    # 此函数不保留形状    [image,] = tf.py_function(lambda x: 1., [image], [tf.float32])    image.set_shape(im_shape)    return image, label

创建数据集会成功,因为tensorflow相信您。然而,当尝试使用它时,您将遇到类似于以下错误:

Incompatible shapes at component 0: expected [224,224,1] but got [].

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注