TensorFlow dynamic_rnn 状态

我的问题是关于 TensorFlow 的方法 tf.nn.dynamic_rnn。它返回每个时间步的输出和最终状态。

我想知道返回的最终状态是单元在最大序列长度时的状态,还是由 sequence_length 参数单独决定的。

为了更好地理解,这里有一个例子:我有三个序列,长度为 [10,20,30],并得到最终状态 [3,512](如果单元的隐藏状态长度为 512)。

这三个序列的三个返回的隐藏状态是单元在第30个时间步的状态,还是我得到了时间步 [10,20,30] 的状态?


回答:

tf.nn.dynamic_rnn 返回两个张量:outputsstates

outputs 包含批次中所有序列的所有单元的输出。因此,如果某个序列较短并用零填充,最后几个单元的 outputs 将为零。

states 包含最后一个单元的状态,或者等同于每个序列的最后一个非零输出(如果您使用的是 BasicRNNCell)。

这里有一个例子:

import numpy as npimport tensorflow as tfn_steps = 2n_inputs = 3n_neurons = 5X = tf.placeholder(dtype=tf.float32, shape=[None, n_steps, n_inputs])seq_length = tf.placeholder(tf.int32, [None])basic_cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)outputs, states = tf.nn.dynamic_rnn(basic_cell, X, sequence_length=seq_length, dtype=tf.float32)X_batch = np.array([  # t = 0      t = 1  [[0, 1, 2], [9, 8, 7]], # instance 0  [[3, 4, 5], [0, 0, 0]], # instance 1])seq_length_batch = np.array([2, 1])with tf.Session() as sess:  sess.run(tf.global_variables_initializer())  outputs_val, states_val = sess.run([outputs, states],                                      feed_dict={X: X_batch, seq_length: seq_length_batch})  print('outputs:')  print(outputs_val)  print('\nstates:')  print(states_val)

这将打印类似以下内容:

outputs:[[[-0.85381496 -0.19517037  0.36011398 -0.18617202  0.39162001]  [-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367]] [[-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]  [ 0.          0.          0.          0.          0.        ]]]  # 因为长度为1states:[[-0.99998015 -0.99461144 -0.82241321  0.93778896  0.90737367] [-0.99849552 -0.88643843  0.20635395  0.157896    0.76042926]]

请注意,states 包含与 output 中相同的向量,它们是每个批次实例的最后一个非零输出。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注