使用 tf extract_image_patches 作为 CNN 的输入?

我想从我的原始图像中提取补丁作为 CNN 的输入。经过一些研究,我找到了使用tensorflow.compat.v1.extract_image_patches提取补丁的方法。

由于这些补丁需要重塑为“图像格式”,我实现了一个名为 reshape_image_patches 的方法来重塑它们,并将重塑后的补丁存储在一个数组中。

image_patches2 = []def reshape_image_patches(image_patches, sess, ksize_rows, ksize_cols):    a = sess.run(tf.shape(image_patches))    nr, nc = a[1], a[2]    for i in range(nr):      for j in range(nc):        patch = tf.reshape(image_patches[0,i,j,], [ksize_rows, ksize_cols, 3])        image_patches2.append(patch)    return image_patches2

如何将此方法与 Keras 生成器结合使用,使这些补丁成为我的 CNN 的输入?

编辑 1:

我尝试了加载 TensorFlow 图像并创建补丁中的方法

import tensorflow as tfimport matplotlib.pyplot as pltimport numpy as npdataset = tf.keras.preprocessing.image_dataset_from_directory(    <directory>,    label_mode=None,    seed=1,    subset='training',    validation_split=0.1,    image_size=(900, 900))get_patches = lambda x: (tf.reshape(    tf.image.extract_patches(        x,        sizes=[1, 16, 16, 1],        strides=[1, 8, 8, 1],        rates=[1, 1, 1, 1],        padding='VALID'),  (111*111, 16, 16, 3)))dataset = dataset.map(get_patches)fig = plt.figure()plt.subplots_adjust(wspace=.1, hspace=.2)images = next(iter(dataset))for index, image in enumerate(images):    ax = plt.subplot(2, 2, index + 1)    ax.set_xticks([])    ax.set_yticks([])    ax.imshow(image)plt.show()

在这一行: images = next(iter(dataset)) 我遇到了错误: InvalidArgumentError: 输入到 reshape 的是一个包含 302800896 个值的张量,但请求的形状只有 9462528[[{{node Reshape}}]]

有人知道如何修复这个问题吗?


回答:

tf.reshape不会改变张量中元素的顺序或元素的总数。如错误所述,你试图将总元素数从 302800896 减少到 9462528。你在lambda函数中使用了tf.reshape

在下面的示例中,我重新创建了你的场景,我为tf.reshapeshape参数设置为2,这不能容纳原始张量的所有元素,因此抛出了错误 –

代码 –

%tensorflow_version 2.ximport tensorflow as tft1 = tf.Variable([1,2,2,4,5,6])t2 = tf.reshape(t1, 2)

输出 –

---------------------------------------------------------------------------InvalidArgumentError                      Traceback (most recent call last)<ipython-input-3-0ff1d701ff22> in <module>()      3 t1 = tf.Variable([1,2,2,4,5,6])      4 ----> 5 t2 = tf.reshape(t1, 2)3 frames/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)InvalidArgumentError: Input to reshape is a tensor with 6 values, but the requested shape has 2 [Op:Reshape]

tf.reshape应该以这样的方式进行,使元素的排列可以改变,但总元素数必须保持不变。所以修复方法是将形状更改为[2,3]

代码 –

%tensorflow_version 2.ximport tensorflow as tft1 = tf.Variable([1,2,2,4,5,6])t2 = tf.reshape(t1, [2,3])print(t2)

输出 –

tf.Tensor([[1 2 2] [4 5 6]], shape=(2, 3), dtype=int32)

要解决你的问题,可以提取(tf.image.extract_patches)你试图tf.reshape的大小的补丁,或者将tf.reshape更改为提取补丁的大小。

我还建议你查看其他tf.image功能,如tf.image.central_croptf.image.crop_and_resize

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注