背景
我正在阅读《Hands on ML》第二部分,希望能清楚地了解在RNN的损失计算中何时使用”outputs”和何时使用”state”。
书中(有书的朋友可以翻到第396页),作者提到,“请注意,全连接层连接到states
张量,该张量仅包含RNN的最终状态”,这里指的是一个在28步中展开的序列分类器。由于states
变量的长度为len(states) == <隐藏层数量>
,在构建深层RNN时,我一直使用states[-1]
仅连接到最后一层的最终状态。例如:
# hidden_layer_architecture = 定义每层神经元数量的整数列表
# 示例:hidden_layer_architecture = [100 for _ in range(5)]
layers = []
for layer_id, n_neurons in enumerate(hidden_layer_architecture):
hidden_layer = tf.contrib.rnn.BasicRNNCell(n_neurons,
activation=tf.nn.tanh,
name=f'hidden_layer_{layer_id}')
layers.append(hidden_layer)
recurrent_hidden_layers = tf.contrib.rnn.MultiRNNCell(layers)
outputs, states = tf.nn.dynamic_rnn(recurrent_hidden_layers,
X_, dtype=tf.float32)
logits = tf.layers.dense(states[-1], n_outputs, name='outputs')
鉴于作者之前的陈述,这种方法运作正常。然而,我不明白何时会使用tf.nn.dynamic_rnn()
的第一个输出outputs
变量。
我查看了这个问题,它很好地回答了细节问题,并提到,“如果你只对单元的最后一个输出感兴趣,你可以只切片时间维度以选择最后一个元素(例如outputs[:, -1, :]
)。”我理解这意味着states[-1] == outputs[:, -1, :]
,但测试后发现这是错误的。为什么会这样呢?如果输出是每个时间步的单元输出,为什么会这样呢?总的来说…
问题
在损失函数中何时使用tf.nn.dynamic_rnn()
的outputs
变量,何时使用states
变量?这如何改变网络的抽象架构?
任何澄清都将不胜感激。
回答:
这基本上可以分解为以下内容:
outputs
:RNN顶层输出的完整序列。这意味着,如果你使用MultiRNNCell
,这只会是顶层单元的输出;这里不包含下层单元的任何内容。
一般来说,使用自定义RNNCell
实现时,这可能几乎是任何东西,但几乎所有标准单元都会在这里返回状态序列,不过你也可以自己编写一个自定义单元,在返回作为outputs之前对状态序列进行一些处理(例如线性变换)。
state
(请注意,这是文档中使用的名称,不是states
)是最后一个时间步的完整状态。一个重要的区别是,在MultiRNNCell
的情况下,这将包含序列中所有单元的最终状态,而不仅仅是顶层单元!此外,这种输出的确切格式/类型根据使用的RNNCell
而变化很大(例如,它可以是一个张量,或者是一组张量…)。
因此,如果你只关心MultiRNNCell
中最后一个时间步的顶层状态,你实际上有两个选项,这两个选项应该是相同的,归结为个人偏好/”清晰度”:
outputs[:, -1, :]
(假设是批量主格式)仅从顶层状态序列中提取最后一个时间步。state[-1]
仅从所有层的最终状态元组中提取顶层状态。
在其他场景中,你可能没有这种选择:
- 如果你确实需要完整的序列输出,你需要使用
outputs
。 - 如果你需要
MultiRNNCell
中较低层的最终状态,你需要使用state
。
至于为什么等式检查失败:如果你实际上使用了==
,我认为这是在检查张量对象的相等性,这些显然是不同的。你可以尝试在一些简单的玩具场景中(小的状态大小/序列长度)检查这两个对象的值——它们应该是相同的。