我在尝试在Keras中实现一个模型时遇到了以下错误:
你必须为占位符张量提供值
这是我的模型:
def create_base_network(input_shape, out_dims): model = Sequential() model.add(Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape)) model.add(Conv2D(64, (3, 3), activation='relu')) model.add(MaxPooling2D(pool_size=(2, 2))) model.add(Dropout(0.25)) model.add(Flatten()) model.add(Dense(128, activation='relu')) model.add(Dropout(0.5)) model.add(Dense(out_dims, activation='linear')) return modelinput_shape=(28,28,3)anchor_in = Input(shape=input_shape)pos_in = Input(shape=input_shape)neg_in = Input(shape=input_shape)base_network = create_base_network(input_shape, 128)anchor_out = base_network(anchor_in)pos_out = base_network(pos_in)neg_out = base_network(neg_in)merged = concatenate([anchor_out, pos_out, neg_out], axis=-1)model = Model(inputs=[anchor_in, pos_in, neg_in], outputs=merged)
然后我尝试使用以下代码从顺序模型中获取输出:
seq_fun = K.function([model.layers[0].input, model.layers[1].input, model.layers[2].input], [model.layers[3].get_output_at(0)])seq_output = seq_fun([a, p, n])[0]
这些输入来自一个生成器,形式为具有所需形状的numpy数组。接着错误信息是:
InvalidArgumentError: 你必须为占位符张量 'conv2d_1_input' 提供值,其数据类型为 float,形状为 [?,28,28,3] [[{{node conv2d_1_input}} = Placeholder[dtype=DT_FLOAT, shape=[?,28,28,3], _device="/job:localhost/replica:0/task:0/device:GPU:0"]()]] [[{{node dense_2/BiasAdd/_175}} = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/device:CPU:0", send_device="/job:localhost/replica:0/task:0/device:GPU:0", send_device_incarnation=1, tensor_name="edge_102_dense_2/BiasAdd", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/device:CPU:0"]()]]
我不知道发生了什么。
回答:
你创建的顺序模型有四个输出节点。索引为零的那个,即get_output_at(0)
,是直接输入的结果,而其他三个则是通过你定义的输入层之一输入时的输出。显然,第一个输出节点没有连接到你定义的输入层,因此你会得到错误:
你必须为占位符张量提供值…
所以你需要指定另外三个输出节点(索引为1、2或3)作为你自定义函数的输出:
seq_fun = K.function([model.layers[0].input, model.layers[1].input, model.layers[2].input], [model.layers[3].get_output_at(i)]) # i 必须是 1, 2 或 3
作为补充说明,你可以使用模型的inputs
属性更简洁地定义自定义函数:
seq_fun = K.function(model.inputs, [model.layers[3].get_output_at(i)])