如何访问Tensorflow中循环单元的权重?

在深度Q学习任务中,提高稳定性的一种方法是保持一组用于计算Q值目标的网络目标权重,这些权重更新缓慢。在学习过程中,不同时间会使用两组不同的权重进行前向传递。对于普通的DQN,这并不难实现,因为权重是可以使用feed_dict设置的tensorflow变量,例如:

sess = tf.Session()input = tf.placeholder(tf.float32, shape=[None, 5])weights = tf.Variable(tf.random_normal(shape=[5,4], stddev=0.1)bias = tf.Variable(tf.constant(0.1, shape=[4])output = tf.matmul(input, weights) + biastarget = tf.placeholder(tf.float32, [None, 4])loss = ......#在这里,我们明确地将权重设置为缓慢更新的目标权重sess.run(output, feed_dict={input: states, weights: target_weights, bias: target_bias})#使用此输出计算学习过程的目标.....#现在我们运行学习过程,使用最新的权重,以及之前计算的目标sess.run(loss, feed_dict={input: states, target: targets})

我想在DQN的循环版本中使用这种目标网络技术,但我不知道如何访问和设置循环单元内部使用的权重。具体来说,我使用的是tf.nn.rnn_cell.BasicLSTMCell,但我想知道如何对任何类型的循环单元执行此操作。


回答:

BasicLSTMCell不会将其变量作为公共API的一部分公开。我建议您查找这些变量在图中的名称,并使用这些名称进行输入(这些名称不太可能改变,因为它们在检查点中,更改这些名称会破坏检查点兼容性)。

或者,您可以制作一个公开变量的BasicLSTMCell副本。我认为这是最干净的方法。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注