默认情况下,函数dynamic_rnn
仅输出每个时间点的隐藏状态(称为m
),可以通过以下方式获取:
cell = tf.contrib.rnn.LSTMCell(100)rnn_outputs, _ = tf.nn.dynamic_rnn(cell, inputs=inputs, sequence_length=sequence_lengths, dtype=tf.float32)
有没有办法额外获取中间(而不是最终)的单元状态(c
)?
一位tensorflow
贡献者提到可以通过单元包装器来实现:
class Wrapper(tf.nn.rnn_cell.RNNCell): def __init__(self, inner_cell): super(Wrapper, self).__init__() self._inner_cell = inner_cell @property def state_size(self): return self._inner_cell.state_size @property def output_size(self): return (self._inner_cell.state_size, self._inner_cell.output_size) def call(self, input, state) output, next_state = self._inner_cell(input, state) emit_output = (next_state, output) return emit_output, next_state
然而,这似乎不起作用。有没有其他建议?
回答:
提出的解决方案对我来说是有效的,但Layer.call
方法的规范更通用,因此以下Wrapper
应该对API变化更具适应性。试试这个:
class Wrapper(tf.nn.rnn_cell.RNNCell): def __init__(self, inner_cell): super(Wrapper, self).__init__() self._inner_cell = inner_cell @property def state_size(self): return self._inner_cell.state_size @property def output_size(self): return (self._inner_cell.state_size, self._inner_cell.output_size) def call(self, input, *args, **kwargs): output, next_state = self._inner_cell(input, *args, **kwargs) emit_output = (next_state, output) return emit_output, next_state
这里是测试代码:
n_steps = 2n_inputs = 3n_neurons = 5X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])basic_cell = Wrapper(tf.nn.rnn_cell.LSTMCell(num_units=n_neurons, state_is_tuple=False))outputs, states = tf.nn.dynamic_rnn(basic_cell, X, dtype=tf.float32)print(outputs, states)X_batch = np.array([ # t = 0 t = 1 [[0, 1, 2], [9, 8, 7]], # instance 0 [[3, 4, 5], [0, 0, 0]], # instance 1 [[6, 7, 8], [6, 5, 4]], # instance 2 [[9, 0, 1], [3, 2, 1]], # instance 3])with tf.Session() as sess: sess.run(tf.global_variables_initializer()) outputs_val = outputs[0].eval(feed_dict={X: X_batch}) print(outputs_val)
返回的outputs
是一个包含(?, 2, 10)
和(?, 2, 5)
张量的元组,这些是所有LSTM的状态和输出。请注意,我使用的是LSTMCell
的“毕业”版本,来自tf.nn.rnn_cell
包,而不是tf.contrib.rnn
。另外,请注意设置state_is_tuple=True
以避免处理LSTMStateTuple
。