我正在研究recurrent.py
文件中LSTMCell
类的代码(https://github.com/keras-team/keras/blob/master/keras/layers/recurrent.py)
这个类是否计算单个时间步的隐藏状态和携带状态?
我可以在哪里找到处理展开网络的代码,也就是从一个时间步到另一个时间步的代码?
我试图计算单个示例中每个时间步每个门的输出。到目前为止,我可以从训练好的网络中提取权重和偏置,并按照从第1828行到1858行的代码计算激活值。特别是:
i = self.recurrent_activation(x_i + K.dot(h_tm1_i, self.recurrent_kernel_i))f = self.recurrent_activation(x_f + K.dot(h_tm1_f, self.recurrent_kernel_f))c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c, self.recurrent_kernel_c))o = self.recurrent_activation(x_o + K.dot(h_tm1_o, self.recurrent_kernel_o))
我的输入形状为:input(seq_length, nb_dim)。那么,为了正确计算每个门的输出,我应该做类似以下的事情吗:
for step in range(seq_length): input_step = input[step, :] x_i = np.dot(input_step, kernel_i) + bias_i i = recurrent_activation(x_i + np.dot(h_tm1_i, recurrent_kernel_i) <<< 对其他门重复此操作 >>> <<< 计算单元的隐藏状态/携带状态 >>>
回答:
我可以在哪里找到处理展开网络的代码,也就是从一个时间步到另一个时间步的代码?
这种逻辑是由keras.backend.rnn
函数完成的(recurrent.py
):
last_output, outputs, states = K.rnn(step, inputs, initial_state, constants=constants, go_backwards=self.go_backwards, mask=mask, unroll=self.unroll, input_length=timesteps)
step
基本上是调用一个单元…
def step(inputs, states): return self.cell.call(inputs, states, **kwargs)
…在LSTM单元的情况下,它计算i
、f
、c
和o
门,正如您在问题中描述的那样,并从中评估输出和状态张量。
如果您使用的是tensorflow后端,可以在keras/backend/tensorflow_backend.py
中找到实际遍历输入序列的循环。