如何在TensorFlow中设置RNN状态,当state_is_tuple=True时?

我编写了一个使用TensorFlow的RNN语言模型。该模型被实现为一个RNN类。图结构在构造函数中构建,而RNN.trainRNN.test方法则运行它。

我想在训练集中移动到新文档时,或者在训练过程中运行验证集时,能够重置RNN状态。我通过在训练循环中管理状态,并通过一个feed字典将其传递到图中来实现这一点。

在构造函数中,我这样定义RNN

    cell = tf.nn.rnn_cell.LSTMCell(hidden_units)    rnn_layers = tf.nn.rnn_cell.MultiRNNCell([cell] * layers)    self.reset_state = rnn_layers.zero_state(batch_size, dtype=tf.float32)    self.state = tf.placeholder(tf.float32, self.reset_state.get_shape(), "state")    self.outputs, self.next_state = tf.nn.dynamic_rnn(rnn_layers, self.embedded_input, time_major=True,                                                  initial_state=self.state)

训练循环看起来像这样

 for document in document:     state = session.run(self.reset_state)     for x, y in document:          _, state = session.run([self.train_step, self.next_state],                                  feed_dict={self.x:x, self.y:y, self.state:state})

xy是文档中的训练数据批次。想法是在每个批次后传递最新的状态,除了当我开始一个新文档时,我通过运行self.reset_state来将状态清零。

这一切都有效。现在我想将我的RNN更改为使用推荐的state_is_tuple=True。然而,我不知道如何通过feed字典传递更复杂的LSTM状态对象。我也不知道在构造函数中的self.state = tf.placeholder(...)行应该传递什么参数。

这里正确的策略是什么?目前关于dynamic_rnn的示例代码或文档仍然不多。


TensorFlow的问题26952838似乎相关。

WILDML上的一篇博客文章讨论了这些问题,但没有直接给出答案。

另见TensorFlow: Remember LSTM state for next batch (stateful LSTM)


回答:

TensorFlow占位符的一个问题是,你只能用Python列表或Numpy数组来馈送它(我认为)。所以你不能在LSTMStateTuple的元组之间保存状态。

我通过像这样在张量中保存状态来解决这个问题

initial_state = np.zeros((num_layers, 2, batch_size, state_size))

LSTM层中有两个组件,细胞状态隐藏状态,这就是“2”的来源。(这篇文章很好:https://arxiv.org/pdf/1506.00019.pdf

在构建图时,你像这样解包并创建元组状态:

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])l = tf.unpack(state_placeholder, axis=0)rnn_tuple_state = tuple(         [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])          for idx in range(num_layers)])

然后你以通常的方式获取新状态

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)

不应该这样…也许他们正在寻找解决方案。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注