在Keras中”Flatten”的作用是什么?

我正在尝试理解Keras中Flatten函数的作用。下面是我的代码,这是一个简单的两层网络。它接受形状为(3, 2)的二维数据,并输出形状为(1, 4)的一维数据:

model = Sequential()model.add(Dense(16, input_shape=(3, 2)))model.add(Activation('relu'))model.add(Flatten())model.add(Dense(4))model.compile(loss='mean_squared_error', optimizer='SGD')x = np.array([[[1, 2], [3, 4], [5, 6]]])y = model.predict(x)print y.shape

这会打印出y的形状为(1, 4)。然而,如果我删除Flatten这一行,那么它会打印出y的形状为(1, 3, 4)。

我不明白这是为什么。根据我对神经网络的理解,model.add(Dense(16, input_shape=(3, 2)))函数创建了一个完全连接的隐藏层,包含16个节点。这些节点中的每一个都连接到3×2输入元素中的每一个。因此,这个第一层的16个输出节点已经是“扁平”的。所以,第一层的输出形状应该是(1, 16)。然后,第二层以此为输入,输出形状为(1, 4)的数据。

那么,如果第一层的输出已经是“扁平”的,并且形状为(1, 16),为什么我还需要进一步扁平化它呢?


回答:

如果你阅读Keras文档中关于Dense的条目,你会看到这样的调用:

Dense(16, input_shape=(5,3))

会生成一个Dense网络,具有3个输入和16个输出,这些输出将独立应用于5个步骤。因此,如果D(x)将3维向量转换为16维向量,你从这一层得到的输出将是一系列向量:[D(x[0,:]), D(x[1,:]),..., D(x[4,:])],形状为(5, 16)。为了达到你指定的行为,你可以先将输入Flatten成一个15维向量,然后应用Dense

model = Sequential()model.add(Flatten(input_shape=(3, 2)))model.add(Dense(16))model.add(Activation('relu'))model.add(Dense(4))model.compile(loss='mean_squared_error', optimizer='SGD')

编辑:由于一些人难以理解,这里有一张解释性的图片:

enter image description here

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注