Tensorflow: 如何使用dynamic_rnn从LSTMCell中获取中间单元状态(c)?

默认情况下,函数dynamic_rnn仅输出每个时间点的隐藏状态(称为m),可以通过以下方式获取:

cell = tf.contrib.rnn.LSTMCell(100)rnn_outputs, _ = tf.nn.dynamic_rnn(cell,                                   inputs=inputs,                                   sequence_length=sequence_lengths,                                   dtype=tf.float32)

有没有办法额外获取中间(而不是最终)的单元状态(c)?

一位tensorflow贡献者提到可以通过单元包装器来实现:

class Wrapper(tf.nn.rnn_cell.RNNCell):  def __init__(self, inner_cell):     super(Wrapper, self).__init__()     self._inner_cell = inner_cell  @property  def state_size(self):     return self._inner_cell.state_size  @property  def output_size(self):    return (self._inner_cell.state_size, self._inner_cell.output_size)  def call(self, input, state)    output, next_state = self._inner_cell(input, state)    emit_output = (next_state, output)    return emit_output, next_state

然而,这似乎不起作用。有没有其他建议?


回答:

提出的解决方案对我来说是有效的,但Layer.call方法的规范更通用,因此以下Wrapper应该对API变化更具适应性。试试这个:

class Wrapper(tf.nn.rnn_cell.RNNCell):  def __init__(self, inner_cell):     super(Wrapper, self).__init__()     self._inner_cell = inner_cell  @property  def state_size(self):     return self._inner_cell.state_size  @property  def output_size(self):    return (self._inner_cell.state_size, self._inner_cell.output_size)  def call(self, input, *args, **kwargs):    output, next_state = self._inner_cell(input, *args, **kwargs)    emit_output = (next_state, output)    return emit_output, next_state

这里是测试代码:

n_steps = 2n_inputs = 3n_neurons = 5X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])basic_cell = Wrapper(tf.nn.rnn_cell.LSTMCell(num_units=n_neurons, state_is_tuple=False))outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)print(outputs, states)X_batch = np.array([  # t = 0      t = 1  [[0, 1, 2], [9, 8, 7]], # instance 0  [[3, 4, 5], [0, 0, 0]], # instance 1  [[6, 7, 8], [6, 5, 4]], # instance 2  [[9, 0, 1], [3, 2, 1]], # instance 3])with tf.Session() as sess:  sess.run(tf.global_variables_initializer())  outputs_val = outputs[0].eval(feed_dict={X: X_batch})  print(outputs_val)

返回的outputs是一个包含(?, 2, 10)(?, 2, 5)张量的元组,这些是所有LSTM的状态和输出。请注意,我使用的是LSTMCell的“毕业”版本,来自tf.nn.rnn_cell包,而不是tf.contrib.rnn。另外,请注意设置state_is_tuple=True以避免处理LSTMStateTuple

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注