在TensorFlow中批处理中每个项目的LSTM初始状态

我在TensorFlow中使用tf.nn.dynamic_rnn运行LSTM。我有一个包含N个初始状态向量的张量和一个包含M = N * n个输入的张量。每组序列由n个输入项组成,我希望用第i个初始状态向量评估第i组输入向量,如下所示:

inputs[0:n], initial_states[0]inputs[n:2*n], initial_states[1]...

是否可以直接使用tf.nn.dynamic_rnn和上述张量进行一次调用,还是我必须对每个初始状态向量及其对应的输入使用循环(导致len(initial_states)次对tf.nn.dynamic_rnn的调用)?


回答:

(从问题的评论中添加一些细节)

这种批处理方式得到了很好的支持,并且通常是获得良好性能的必要条件。您的initial_state将有一个跨N的批次维度,RNN将在这些批次上运行n步。您只需将输入重塑为[N, n, …](在time_major=False的情况下,这是默认设置)。

当您有需要一起批处理的可变长度输入时,情况会变得更加复杂。像SequenceQueueingStateSaver这样的工具可以帮助解决这个问题。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注