How do I save a DDPG model?

I tried saving the model with a Saver (the save is done by a save() method inside the DDPG class), but after restoring it the results are nowhere near what I had saved (I saved the model at the point where the episode reward reached zero; the restore call in the code is commented out). My full code with all functions is below. I am using Python 3.7, gym 0.16.0 and TensorFlow 1.13.1.

import tensorflow as tf
import numpy as np
import gym

episode_steps = 500
# learning rate of the actor
lr_a = 0.001
# learning rate of the critic
lr_c = 0.002
gamma = 0.9
alpha = 0.01
memory = 10000
batch_size = 32
render = True


class DDPG(object):
    def __init__(self, no_of_actions, no_of_states, a_bound):
        self.memory = np.zeros((memory, no_of_states * 2 + no_of_actions + 1), dtype=np.float32)

        # pointer into our experience buffer
        self.pointer = 0

        self.sess = tf.Session()

        # initial variance of the exploration noise
        self.noise_variance = 3.0

        self.no_of_actions, self.no_of_states, self.a_bound = no_of_actions, no_of_states, a_bound

        self.state = tf.placeholder(tf.float32, [None, no_of_states], 's')
        self.next_state = tf.placeholder(tf.float32, [None, no_of_states], 's_')
        self.reward = tf.placeholder(tf.float32, [None, 1], 'r')

        with tf.variable_scope('Actor'):
            self.a = self.build_actor_network(self.state, scope='eval', trainable=True)
            a_ = self.build_actor_network(self.next_state, scope='target', trainable=False)

        with tf.variable_scope('Critic'):
            q = self.build_critic_network(self.state, self.a, scope='eval', trainable=True)
            q_ = self.build_critic_network(self.next_state, a_, scope='target', trainable=False)

        self.ae_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/eval')
        self.at_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Actor/target')
        self.ce_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/eval')
        self.ct_params = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='Critic/target')

        # soft update of the target network parameters
        self.soft_replace = [
            [tf.assign(at, (1 - alpha) * at + alpha * ae), tf.assign(ct, (1 - alpha) * ct + alpha * ce)]
            for at, ae, ct, ce in zip(self.at_params, self.ae_params, self.ct_params, self.ce_params)]

        q_target = self.reward + gamma * q_

        # TD error, i.e. the difference between the target value and the predicted value
        td_error = tf.losses.mean_squared_error(labels=q_target, predictions=q)

        # train the critic network with the Adam optimizer
        self.ctrain = tf.train.AdamOptimizer(lr_c).minimize(td_error, name="adam-ink", var_list=self.ce_params)

        a_loss = - tf.reduce_mean(q)

        # train the actor network with the Adam optimizer to minimize the loss
        self.atrain = tf.train.AdamOptimizer(lr_a).minimize(a_loss, var_list=self.ae_params)

        tf.summary.FileWriter("logs2", self.sess.graph)

        # initialize all variables
        self.sess.run(tf.global_variables_initializer())

        # saver
        self.saver = tf.train.Saver()
        # self.saver.restore(self.sess, "Pendulum/nn.ckpt")

    def choose_action(self, s):
        a = self.sess.run(self.a, {self.state: s[np.newaxis, :]})[0]
        a = np.clip(np.random.normal(a, self.noise_variance), -2, 2)
        return a

    def learn(self):
        # soft target replacement
        self.sess.run(self.soft_replace)

        indices = np.random.choice(memory, size=batch_size)
        batch_transition = self.memory[indices, :]
        batch_states = batch_transition[:, :self.no_of_states]
        batch_actions = batch_transition[:, self.no_of_states: self.no_of_states + self.no_of_actions]
        batch_rewards = batch_transition[:, -self.no_of_states - 1: -self.no_of_states]
        batch_next_state = batch_transition[:, -self.no_of_states:]

        self.sess.run(self.atrain, {self.state: batch_states})
        self.sess.run(self.ctrain, {self.state: batch_states, self.a: batch_actions,
                                    self.reward: batch_rewards, self.next_state: batch_next_state})

    def store_transition(self, s, a, r, s_):
        trans = np.hstack((s, a, [r], s_))
        index = self.pointer % memory
        self.memory[index, :] = trans
        self.pointer += 1

        # once the buffer is full, start decaying the noise and training
        if self.pointer > memory:
            self.noise_variance *= 0.99995
            self.learn()

    def build_actor_network(self, s, scope, trainable):
        # actor (deterministic policy gradient)
        with tf.variable_scope(scope):
            l1 = tf.layers.dense(s, 30, activation=tf.nn.tanh, name='l1', trainable=trainable)
            a = tf.layers.dense(l1, self.no_of_actions, activation=tf.nn.tanh, name='a', trainable=trainable)
            return tf.multiply(a, self.a_bound, name="scaled_a")

    def build_critic_network(self, s, a, scope, trainable):
        with tf.variable_scope(scope):
            n_l1 = 30
            w1_s = tf.get_variable('w1_s', [self.no_of_states, n_l1], trainable=trainable)
            w1_a = tf.get_variable('w1_a', [self.no_of_actions, n_l1], trainable=trainable)
            b1 = tf.get_variable('b1', [1, n_l1], trainable=trainable)
            net = tf.nn.tanh(tf.matmul(s, w1_s) + tf.matmul(a, w1_a) + b1)
            q = tf.layers.dense(net, 1, trainable=trainable)
            return q

    def save(self):
        # write the current weights to a checkpoint
        self.saver.save(self.sess, "Pendulum/nn.ckpt")


env = gym.make("Pendulum-v0")
env = env.unwrapped
env.seed(1)

no_of_states = env.observation_space.shape[0]
no_of_actions = env.action_space.shape[0]
a_bound = env.action_space.high

ddpg = DDPG(no_of_actions, no_of_states, a_bound)

total_reward = []

# set the number of episodes
no_of_episodes = 300

for i in range(no_of_episodes):
    # initialize the environment
    s = env.reset()
    ep_reward = 0

    for j in range(episode_steps):
        env.render()

        # select an action by adding exploration noise
        a = ddpg.choose_action(s)

        # perform the action and move to the next state s_
        s_, r, done, info = env.step(a)

        # store the transition in our experience buffer;
        # this also samples a mini-batch and trains the networks once the buffer is full
        ddpg.store_transition(s, a, r, s_)

        # update the current state to the next state
        s = s_

        # accumulate the episode reward
        ep_reward += r

        if int(ep_reward) == 0 and i > 150:
            ddpg.save()
            print("save")
            quit()

        if j == episode_steps - 1:
            total_reward.append(ep_reward)
            print('Episode:', i, ' Reward: %i' % int(ep_reward))
            break

Answer:

I solved this completely by rewriting the code and running the learning function in a separate session.
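For reference, here is a minimal sketch of what the restore side can look like when run as a separate evaluation script. It assumes the DDPG class from the question is defined (or importable) and that the checkpoint was written to "Pendulum/nn.ckpt" by the save() method above; it is not the exact rewrite, just one way to restore and check the saved weights. Note that it queries the actor directly instead of calling choose_action(), because choose_action() still adds Gaussian noise with noise_variance = 3.0 in a fresh object, which is one likely reason a restored run looks nothing like the saved one.

import numpy as np
import gym

# rebuild the same environment and the same graph so variable names match the checkpoint
env = gym.make("Pendulum-v0").unwrapped
env.seed(1)

no_of_states = env.observation_space.shape[0]
no_of_actions = env.action_space.shape[0]
a_bound = env.action_space.high

ddpg = DDPG(no_of_actions, no_of_states, a_bound)

# restoring overwrites the freshly initialized variables with the saved weights
ddpg.saver.restore(ddpg.sess, "Pendulum/nn.ckpt")

# evaluate greedily: run the actor without exploration noise
s = env.reset()
ep_reward = 0
for _ in range(200):
    env.render()
    a = ddpg.sess.run(ddpg.a, {ddpg.state: s[np.newaxis, :]})[0]
    s, r, done, info = env.step(a)
    ep_reward += r
print('Evaluation reward:', ep_reward)

If the evaluation reward is still far off, it is worth checking that the model was actually saved at a good policy (int(ep_reward) == 0 can trigger very early in an episode) rather than assuming the checkpoint itself is broken.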
