我正在编写一个模型并进行预处理部分:我有一个方法,通过调用来预处理我的Tensorflow数据集:
ds = ds.map(process_path, num_parallel_calls=AUTOTUNE)
我按照Tensorflow的文档,得到了以下process_path的代码:
def process_path(filename): label = get_label(filename) image = tf.io.read_file(filename) image = tf.image.decode_jpeg(image, channels=3) image = tf.image.rgb_to_grayscale(image) image = tf.image.convert_image_dtype(image, tf.float32) image = tf.image.resize(image, [224, 224]) return image, label
然后我想添加我自己的预处理,例如旋转图像,所以我创建了一个用py_function包装的rotate方法,正如文档所建议的:
def rotate_image(image): return tfa.image.rotate(image, random.randrange(-5, 5)/1.0)def tf_rotate_image(image, label): [image,] = tf.py_function(rotate_image, [image], [tf.float32]) return image, label
然而,当我将此添加到我的process_path中时,模型似乎崩溃并冻结了…我在每次调整后添加了print语句以显示image.shape,发现旋转方法后图像形状变为<unknown>
,所以我认为这是错误的原因:
def process_path(filename): label = get_label(filename) image = tf.io.read_file(filename) print(image.shape) image = tf.image.decode_jpeg(image, channels=3) print(image.shape) image = tf.image.rgb_to_grayscale(image) print(image.shape) image = tf.image.convert_image_dtype(image, tf.float32) print(image.shape) image = tf.image.resize(image, [224, 224]) print(image.shape) image, label = tf_rotate_image(image, label) print(image.shape) return image, label
输出:
()(None, None, 3)(None, None, 1)(None, None, 1)(224, 224, 1)<unknown>
任何帮助将不胜感激。
回答:
您之所以丢失形状,是因为调用了tf.py_function
。
使用基于tensorflow操作的函数
在这种情况下,因为您使用的是tfa
,您可以直接调用该函数,而无需用tf.py_function
包装,因为您使用的是tf操作:
def process_path(filename): label = 1 image = tf.io.read_file(filename) print(image.shape) image = tf.image.decode_jpeg(image, channels=3) print(image.shape) image = tf.image.rgb_to_grayscale(image) print(image.shape) image = tf.image.convert_image_dtype(image, tf.float32) print(image.shape) image = tf.image.resize(image, [224, 224]) print(image.shape) image, label = (lambda x,y : (rotate_image(x), y))(image, label) print(image.shape)
将会得到以下结果:
>>> ds = ds.map(process_path)()(None, None, 3)(None, None, 1)(None, None, 1)(224, 224, 1)(224, 224, 1)
使用不包含tensorflow操作的函数
如果您想使用不包含tensorflow操作的函数,那么您可以使用tf.py_function
,并明确设置形状。这是在指南tf.data: Build TensorFlow input pipelines中所做的。以下是该指南中的示例:
def tf_random_rotate_image(image, label): im_shape = image.shape [image,] = tf.py_function(random_rotate_image, [image], [tf.float32]) # 形状是明确设置的,因为tensorflow无法确保 # 在函数执行期间形状不会被修改 image.set_shape(im_shape) return image, label
然而,进行此操作时,tensorflow做的一个假设是您设置的形状实际上是正确的!以下示例将崩溃,因为函数lambda x:1
不保留输入的形状。
def not_shape_preserving(image, label): im_shape = image.shape # 此函数不保留形状 [image,] = tf.py_function(lambda x: 1., [image], [tf.float32]) image.set_shape(im_shape) return image, label
创建数据集会成功,因为tensorflow相信您。然而,当尝试使用它时,您将遇到类似于以下错误:
Incompatible shapes at component 0: expected [224,224,1] but got [].