Keras + Tensorflow: ‘ConvLSTM2D’ 对象没有属性 ‘outbound_nodes’

我在尝试将 ConvLSTM 作为我现有 TensorFlow 网络的一部分,因为在使用 TensorFlow 的 ConvLSTM 实现时遇到了一些问题,所以我选择使用 Keras 的 ConvLSTM2D 层来代替。

为了在我的 TensorFlow 会话中使用 Keras,我采用了博客文章中的建议(我使用的是 TensorFlow 后端):https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html

以下是我代码的一部分(导致问题的部分):

# state 的形状为 [1, 75, 32, 32],批次大小为1state = tf.concat([screen, screen2, non_spatial], axis=1)# 重塑 state 以使时间维度为1,以便适应 ConvLSTM 的形状state_reshaped = tf.reshape(state, [1, 1, 75, 32, 32])# Keras ConvLSTM2D 层# 我尝试过省略输入形状中的批次大小,但这对错误没有影响,似乎也没问题lstm_layer = ConvLSTM2D(filters=5, kernel_size=(3, 3), input_shape=(1, 1, 75, 32, 32), data_format='channels_first', stateful=True)(state_reshaped)fc1 = layers.fully_connected(inputs=layers.flatten(lstm_layer), num_outputs=256, activation_fn=tf.nn.relu)

这导致了以下错误:AttributeError: 'ConvLSTM2D' object has no attribute 'outbound_nodes’”

我不知道这是什么意思。我认为可能是混合使用 Keras 的 ConvLSTM 和 TensorFlow 的 flatten 函数造成的。所以我尝试使用 Keras 的 Flatten() 来代替,像这样:

# lstm_layer 的形状为 (5, 5, 30, 30)lstm_layer = Flatten(data_format='channels_first')(lstm_layer)fc1 = layers.fully_connected(inputs=lstm_layer, num_outputs=256, activation_fn=tf.nn.relu)

然后得到了以下错误:ValueError: The last dimension of the inputs to 'Dense' should be defined. Found 'None'.这个错误是由 Flatten() 引起的,出于某种原因,它的输出形状为 (?, ?),而全连接层需要最后一个维度有定义的形状,但我不知道为什么它会是未定义的。之前它是有定义的。使用 Reshape((4500,))(lstm_layer) 代替会得到同样的 no attribute 'outbound_nodes' 错误。

我在谷歌上搜索了这个问题,似乎我不是唯一遇到这个问题的人,但我找不到解决方案。

我该如何解决这个问题?Flatten() 的未知输出形状是 bug 还是预期行为,如果是,为什么?


回答:

我遇到了同样的问题,并对 TensorFlow 代码进行了一些深入研究。问题在于 Keras 2.2.0 进行了一些重构,而 tf.keras 尚未更新到这个新 API。

‘outbound_nodes’ 属性在 Keras 2.2.0 中被重命名为 ‘_outbound_nodes’。修复这个问题很简单,你需要更新 base.py 中的两个引用:

/site-packages/tensorflow/python/layers/base.py

更新后,我的代码运行正常了。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注