我的输入数据如下所示:
[[1, 2, 3][4, 5, 6][7, 8, 9]...]
形状为 (1, num_samples, num_features)
,标签如下所示:
[[0, 1][1, 0][1, 0]...]
形状为 (1, num_samples, 2)
。
然而,当我尝试运行以下Keras代码时,出现了这个错误:ValueError: Error when checking model target: expected dense_1 to have 2 dimensions, but got array with shape (1, 8038, 2)
。据我所知,这似乎是因为我的标签是二维的,而不是简单的整数。这是正确的吗?如果是的话,如何在Keras中使用独热编码标签?
以下是代码:
num_features = 463trX = np.random(8038, num_features)trY = # one-hot array of shape (8038, 2) as described abovedef keras_builder(): #generator to build the inputs while(1): x = np.reshape(trX, (1,) + np.shape(trX)) y = np.reshape(trY, (1,) + np.shape(trY)) print(np.shape(x)) # (1, 8038, 463) print(np.shape(y)) # (1, 8038, 2) yield x, ymodel = Sequential()model.add(LSTM(100, input_dim = num_features))model.add(Dense(1, activation='sigmoid'))model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])model.fit_generator(keras_builder(), samples_per_epoch = 1, nb_epoch=3, verbose = 2, nb_worker = 1)
这会立即抛出上述错误:
Traceback (most recent call last): File "file.py", line 35, in <module> model.fit_generator(keras_builder(), samples_per_epoch = 1, nb_epoch=3, verbose = 2, nb_worker = 1) ...ValueError: Error when checking model target: expected dense_1 to have 2 dimensions, but got array with shape (1, 8038, 2)
谢谢!
回答:
有很多事情似乎对不上号。
我假设你正在尝试解决一个序列分类任务,即你的数据形状为 (<batch size>, <sequence length>, <feature length>)
。
在你的批处理生成器中,你创建了一个批次,包含一个长度为8038的序列,每个序列元素有463个特征。你创建了一个匹配的Y批次进行比较,包含一个序列,包含8038个元素,每个元素大小为2。
你的问题是 Y
与最后一层的输出不匹配。你的 Y
是三维的,而你的模型输出只是二维的:Y.shape = (1, 8038, 2)
与 dense_1.shape = (1,1)
不匹配。这解释了你得到的错误消息。
解决方案是:你需要在LSTM层中启用 return_sequences=True
,以返回一个序列而不是仅返回最后一个元素(有效地移除时间维度)。这将在LSTM层产生一个输出形状为 (1, 8038, 100)
。由于 Dense
层无法处理序列数据,你需要将其应用于每个序列元素,这可以通过将其包装在 TimeDistributed
包装器中来实现。这样,你的模型将具有输出形状 (1, 8038, 1)
。
你的模型应该如下所示:
from keras.layers.wrappers import TimeDistributedmodel = Sequential()model.add(LSTM(100, input_dim=num_features, return_sequences=True))model.add(TimeDistributed(Dense(1, activation='sigmoid')))
检查模型摘要时可以轻松发现这一点:
print(model.summary())