基于策略的学习无法收敛

我在尝试实现近端策略优化时,遇到了一个非常奇怪的问题。

以下是问题的简化演示:

import numpy as npimport tensorflow as tfraw_probs = tf.get_variable("raw_probs",[4])probs = tf.nn.softmax(raw_probs)actions = tf.placeholder(dtype=tf.int32, shape=[None], name='actions')rewards = tf.placeholder(dtype=tf.float32, shape=[None], name='rewards')old_probs = tf.placeholder(dtype=tf.float32, shape=[None], name='old_probs')new_probs = tf.reduce_sum(probs * tf.one_hot(indices=actions, depth=4))ratios = new_probs / old_probsclipped_ratios = tf.clip_by_value(ratios, clip_value_min=0.8, clip_value_max=1.2)loss_clip = -tf.reduce_mean(tf.minimum(tf.multiply(rewards, ratios), tf.multiply(rewards, clipped_ratios)))optimizer = tf.train.AdamOptimizer()train_pol = optimizer.minimize(loss_clip)with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    for i in range(1000):        input_actions = []        input_rewards = []        input_old_probs = []        for j in range(20):            tmp_probs = sess.run(probs)            if j == 0:                print(tmp_probs)            act = np.random.choice(4,p=tmp_probs)            input_actions.append(act)            if act == 0:                input_rewards.append(1)            else:                input_rewards.append(-1)            input_old_probs.append(tmp_probs[act])        sess.run(train_pol,feed_dict={actions: input_actions,rewards: input_rewards,old_probs: input_old_probs})

程序根据概率分布抽取数字。如果抽到0,则获得1的奖励;如果抽到其他数字,则获得-1的奖励。接着,程序会根据结果调整概率。

理论上,选择0的概率应该始终增加,最终收敛到1。但在实际操作中,这个概率却在下降。

我在这里做错了什么?


回答:

我解决了!我对reduce_sum的效果理解得不够透彻。

只需将

new_probs = tf.reduce_sum(probs * tf.one_hot(indices=actions, depth=4))

改为

new_probs = tf.reduce_sum(probs * tf.one_hot(indices=actions, depth=4),1)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注