我正在尝试按照TensorFlow文档中提到的方式对二分类图像分类问题应用数据增强:https://www.tensorflow.org/tutorials/images/classification#data_augmentation
我的模型如下:
Sequential([ data_augmentation, layers.experimental.preprocessing.Rescaling(1./255), layers.Conv2D(16, 3, padding='same', activation='relu'), layers.MaxPooling2D(), layers.Dropout(0.2), layers.Conv2D(32, 3, padding='same', activation='relu'), layers.MaxPooling2D(), layers.Dropout(0.2), layers.Conv2D(64, 3, padding='same', activation='relu'), layers.MaxPooling2D(), layers.Flatten(), layers.Dense(128, activation='relu'), layers.Dropout(0.5), layers.Dense(1, activation='sigmoid')])
当我的数据增强层如下设置时,模型可以无错误地编译:
data_augmentation = keras.Sequential( [ layers.experimental.preprocessing.RandomFlip("horizontal", input_shape=(150, 150, 3)), layers.experimental.preprocessing.RandomRotation(0.2), layers.experimental.preprocessing.RandomZoom(0.2) ])
如果我在增强层中尝试引入RandomHeight()
和/或RandomWidth()
,在创建模型时会收到以下错误:
ValueError: The last dimension of the inputs to `Dense` should be defined. Found `None`.
您知道这是为什么吗?如何解决这个问题?
回答:
您可以检查RandomWidth-Height
输出的形状。RandomWidth类的源代码:
return tensor_shape.TensorShape( [input_shape[0], None, input_shape[2], input_shape[3]])
假设我将RandomHeight
作为第一层使用,并且input_shape
为150 x 150的RGB图像。我们可以通过以下方式确认输出形状:
data_augmentation.summary()Layer (type) Output Shape Param # =================================================================random_height_2 (RandomHeigh (None, None, 150, 3) 0 _________________________________________________________________random_flip_2 (RandomFlip) (None, None, 150, 3) 0 _________________________________________________________________random_rotation_2 (RandomRot (None, None, 150, 3) 0 _________________________________________________________________random_zoom_2 (RandomZoom) (None, None, 150, 3) 0
当您这样使用它,并且如果您在没有密集层的情况下编译模型,您将在模型摘要中看到:
dropout_6 (Dropout) (None, None, 18, 64) 0 _________________________________________________________________flatten_6 (Flatten) (None, None) 0
(None,None)
在这里引起了错误。您可以通过使用tf.keras.layers.GlobalMaxPooling2D()
代替Flatten()
来解决这个问题
虽然这解决了Flatten()
层引起的维度问题,但GlobalMaxPooling2D
的行为略有不同。
您可以查看这个问题来了解差异。