我正在尝试实现一个模型,该模型将灰度图像作为输入,并返回一个数值作为输出。我使用InceptionV3(从头开始训练)作为特征提取器,然后在最后阶段使用一些全连接层进行回归。
这是我的代码:
from keras.applications.inception_v3 import InceptionV3from keras.layers import Input, GlobalAveragePooling2D, Dense, Dropout, Flatten, BatchNormalizationfrom keras.models import Modelfrom keras.metrics import mean_absolute_errorfrom keras.utils import plot_modelinputs = Input(shape=(256, 256, 1))x = BatchNormalization()(inputs)x = InceptionV3(include_top = False, weights = None, input_shape=inputs.shape[1:])(x)x = BatchNormalization()(x)x = GlobalAveragePooling2D()(x)x = Dense(1000, activation = 'relu' )(x)x = Dense(1000, activation = 'relu' )(x)outputs = Dense(1, activation = 'linear' )(x)model = Model(inputs=inputs, outputs=outputs)model.compile(optimizer = 'adam', loss = 'mse', metrics = [mae])model.summary()
现在当我运行代码时,我得到了这个错误:
---------------------------------------------------------------------------TypeError Traceback (most recent call last)<ipython-input-36-50041eb640cc> in <module>() 7 inputs = Input(shape=(256, 256, 1)) 8 x = BatchNormalization()(inputs)----> 9 x = InceptionV3(include_top = False, weights = None, input_shape=inputs.shape[1:])(x) 10 x = BatchNormalization()(x) 11 x = GlobalAveragePooling2D()(x)3 frames/usr/local/lib/python3.6/dist-packages/keras_applications/imagenet_utils.py in _obtain_input_shape(input_shape, default_size, min_size, data_format, require_flatten, weights) 273 default_shape = (input_shape[0], default_size, default_size) 274 else:--> 275 if input_shape[-1] not in {1, 3}: 276 warnings.warn( 277 'This model usually expects 1 or 3 input channels. 'TypeError: unhashable type: 'Dimension'
我不明白是什么引起了这个错误,因为当我使用顺序模型时一切正常。但对于这个函数式模型就不行了。
回答:
inputs.shape
不是一个列表,因此会抛出错误。它会给你一个类型为 tensorflow.python.framework.tensor_shape.TensorShape
的形状,其中包含一个列表,每个维度的类型为 Dimension
print(inputs.shape)# output TensorShape([Dimension(None), Dimension(256), Dimension(256), Dimension(1)])
你可以使用 as_list()
来获取形状列表:
# inputs.shape.as_list()# output [None, 256, 256, 1]x = InceptionV3(include_top = False, weights = None, input_shape=inputs.shape.as_list()[1:])(x)