我想从我的原始图像中提取补丁作为 CNN 的输入。经过一些研究,我找到了使用tensorflow.compat.v1.extract_image_patches
提取补丁的方法。
由于这些补丁需要重塑为“图像格式”,我实现了一个名为 reshape_image_patches 的方法来重塑它们,并将重塑后的补丁存储在一个数组中。
image_patches2 = []def reshape_image_patches(image_patches, sess, ksize_rows, ksize_cols): a = sess.run(tf.shape(image_patches)) nr, nc = a[1], a[2] for i in range(nr): for j in range(nc): patch = tf.reshape(image_patches[0,i,j,], [ksize_rows, ksize_cols, 3]) image_patches2.append(patch) return image_patches2
如何将此方法与 Keras 生成器结合使用,使这些补丁成为我的 CNN 的输入?
编辑 1:
我尝试了加载 TensorFlow 图像并创建补丁中的方法
import tensorflow as tfimport matplotlib.pyplot as pltimport numpy as npdataset = tf.keras.preprocessing.image_dataset_from_directory( <directory>, label_mode=None, seed=1, subset='training', validation_split=0.1, image_size=(900, 900))get_patches = lambda x: (tf.reshape( tf.image.extract_patches( x, sizes=[1, 16, 16, 1], strides=[1, 8, 8, 1], rates=[1, 1, 1, 1], padding='VALID'), (111*111, 16, 16, 3)))dataset = dataset.map(get_patches)fig = plt.figure()plt.subplots_adjust(wspace=.1, hspace=.2)images = next(iter(dataset))for index, image in enumerate(images): ax = plt.subplot(2, 2, index + 1) ax.set_xticks([]) ax.set_yticks([]) ax.imshow(image)plt.show()
在这一行: images = next(iter(dataset)) 我遇到了错误: InvalidArgumentError: 输入到 reshape 的是一个包含 302800896 个值的张量,但请求的形状只有 9462528[[{{node Reshape}}]]
有人知道如何修复这个问题吗?
回答:
tf.reshape
不会改变张量中元素的顺序或元素的总数。如错误所述,你试图将总元素数从 302800896 减少到 9462528。你在lambda
函数中使用了tf.reshape
。
在下面的示例中,我重新创建了你的场景,我为tf.reshape
的shape
参数设置为2
,这不能容纳原始张量的所有元素,因此抛出了错误 –
代码 –
%tensorflow_version 2.ximport tensorflow as tft1 = tf.Variable([1,2,2,4,5,6])t2 = tf.reshape(t1, 2)
输出 –
---------------------------------------------------------------------------InvalidArgumentError Traceback (most recent call last)<ipython-input-3-0ff1d701ff22> in <module>() 3 t1 = tf.Variable([1,2,2,4,5,6]) 4 ----> 5 t2 = tf.reshape(t1, 2)3 frames/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)InvalidArgumentError: Input to reshape is a tensor with 6 values, but the requested shape has 2 [Op:Reshape]
tf.reshape
应该以这样的方式进行,使元素的排列可以改变,但总元素数必须保持不变。所以修复方法是将形状更改为[2,3]
–
代码 –
%tensorflow_version 2.ximport tensorflow as tft1 = tf.Variable([1,2,2,4,5,6])t2 = tf.reshape(t1, [2,3])print(t2)
输出 –
tf.Tensor([[1 2 2] [4 5 6]], shape=(2, 3), dtype=int32)
要解决你的问题,可以提取(tf.image.extract_patches
)你试图tf.reshape
的大小的补丁,或者将tf.reshape
更改为提取补丁的大小。
我还建议你查看其他tf.image
功能,如tf.image.central_crop和tf.image.crop_and_resize。