I am learning the deep reinforcement learning framework Chainer.
I followed a tutorial and ended up with the following code:
import copy
import time

import numpy as np
import chainer
import chainer.functions as F
import chainer.links as L


def train_dddqn(env):

    class Q_Network(chainer.Chain):

        def __init__(self, input_size, hidden_size, output_size):
            super(Q_Network, self).__init__(
                fc1=L.Linear(input_size, hidden_size),
                fc2=L.Linear(hidden_size, hidden_size),
                fc3=L.Linear(hidden_size, hidden_size // 2),
                fc4=L.Linear(hidden_size, hidden_size // 2),
                state_value=L.Linear(hidden_size // 2, 1),
                advantage_value=L.Linear(hidden_size // 2, output_size)
            )
            self.input_size = input_size
            self.hidden_size = hidden_size
            self.output_size = output_size

        def __call__(self, x):
            h = F.relu(self.fc1(x))
            h = F.relu(self.fc2(h))
            hs = F.relu(self.fc3(h))
            ha = F.relu(self.fc4(h))
            state_value = self.state_value(hs)
            advantage_value = self.advantage_value(ha)
            advantage_mean = (F.sum(advantage_value, axis=1) / float(self.output_size)).reshape(-1, 1)
            q_value = F.concat([state_value for _ in range(self.output_size)], axis=1) + (
                advantage_value - F.concat([advantage_mean for _ in range(self.output_size)], axis=1))
            return q_value

        def reset(self):
            self.cleargrads()

    Q = Q_Network(input_size=env.history_t + 1, hidden_size=100, output_size=3)
    Q_ast = copy.deepcopy(Q)
    optimizer = chainer.optimizers.Adam()
    optimizer.setup(Q)

    epoch_num = 50
    step_max = len(env.data) - 1
    memory_size = 200
    batch_size = 50
    epsilon = 1.0
    epsilon_decrease = 1e-3
    epsilon_min = 0.1
    start_reduce_epsilon = 200
    train_freq = 10
    update_q_freq = 20
    gamma = 0.97
    show_log_freq = 5

    memory = []
    total_step = 0
    total_rewards = []
    total_losses = []

    start = time.time()
    for epoch in range(epoch_num):

        pobs = env.reset()
        step = 0
        done = False
        total_reward = 0
        total_loss = 0

        while not done and step < step_max:

            # select act (epsilon-greedy)
            pact = np.random.randint(3)
            if np.random.rand() > epsilon:
                pact = Q(np.array(pobs, dtype=np.float32).reshape(1, -1))
                pact = np.argmax(pact.data)

            # act
            obs, reward, done = env.step(pact)

            # add memory
            memory.append((pobs, pact, reward, obs, done))
            if len(memory) > memory_size:
                memory.pop(0)

            # train or update q
            if len(memory) == memory_size:
                if total_step % train_freq == 0:
                    shuffled_memory = np.random.permutation(memory)
                    memory_idx = range(len(shuffled_memory))
                    for i in memory_idx[::batch_size]:
                        batch = np.array(shuffled_memory[i:i + batch_size])
                        b_pobs = np.array(batch[:, 0].tolist(), dtype=np.float32).reshape(batch_size, -1)
                        b_pact = np.array(batch[:, 1].tolist(), dtype=np.int32)
                        b_reward = np.array(batch[:, 2].tolist(), dtype=np.int32)
                        b_obs = np.array(batch[:, 3].tolist(), dtype=np.float32).reshape(batch_size, -1)
                        b_done = np.array(batch[:, 4].tolist(), dtype=np.bool)

                        q = Q(b_pobs)
                        # Double DQN target: choose the action with Q, evaluate it with Q_ast
                        indices = np.argmax(q.data, axis=1)
                        maxqs = Q_ast(b_obs).data
                        target = copy.deepcopy(q.data)
                        for j in range(batch_size):
                            target[j, b_pact[j]] = b_reward[j] + gamma * maxqs[j, indices[j]] * (not b_done[j])
                        Q.reset()
                        loss = F.mean_squared_error(q, target)
                        total_loss += loss.data
                        loss.backward()
                        optimizer.update()

                if total_step % update_q_freq == 0:
                    Q_ast = copy.deepcopy(Q)

            # epsilon
            if epsilon > epsilon_min and total_step > start_reduce_epsilon:
                epsilon -= epsilon_decrease

            # next step
            total_reward += reward
            pobs = obs
            step += 1
            total_step += 1

        total_rewards.append(total_reward)
        total_losses.append(total_loss)

        if (epoch + 1) % show_log_freq == 0:
            log_reward = sum(total_rewards[((epoch + 1) - show_log_freq):]) / show_log_freq
            log_loss = sum(total_losses[((epoch + 1) - show_log_freq):]) / show_log_freq
            elapsed_time = time.time() - start
            print('\t'.join(map(str, [epoch + 1, epsilon, total_step, log_reward, log_loss, elapsed_time])))
            start = time.time()

    return Q, total_losses, total_rewards


Q, total_losses, total_rewards = train_dddqn(Environment1(train))
My question is: how do I save and load this model once it has been trained well? I know Keras has functions such as model.save and load_model for this.
What is the equivalent code I need for this Chainer code?
Answer:
You can use the serializers module to save and load the parameters of a Chainer model (a Chain object).
from chainer import serializers

Q = Q_Network(input_size=env.history_t + 1, hidden_size=100, output_size=3)
Q_ast = Q_Network(input_size=env.history_t + 1, hidden_size=100, output_size=3)

# --- train Q here... ---

# Copy Q's parameters to Q_ast by saving Q's parameters and loading them into Q_ast
serializers.save_npz('my.model', Q)
serializers.load_npz('my.model', Q_ast)
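If you also want to resume training later, the same serializers can persist the optimizer state in addition to the network weights. Below is a minimal sketch, not the tutorial's code: it assumes Q_Network is defined at module level (in the question it is nested inside train_dddqn, so you would move it out or redefine it), that Q and optimizer are in scope where you save (e.g. at the end of train_dddqn), and the file names my.model / my.optimizer are arbitrary.

import numpy as np
import chainer
from chainer import serializers

# Save both the trained network and the Adam state
serializers.save_npz('my.model', Q)
serializers.save_npz('my.optimizer', optimizer)

# Later, in a fresh process: rebuild the same architecture, then load the weights
Q_loaded = Q_Network(input_size=env.history_t + 1, hidden_size=100, output_size=3)
serializers.load_npz('my.model', Q_loaded)

optimizer = chainer.optimizers.Adam()
optimizer.setup(Q_loaded)
serializers.load_npz('my.optimizer', optimizer)

# Use the restored network for greedy inference
obs = env.reset()
q_values = Q_loaded(np.array(obs, dtype=np.float32).reshape(1, -1))
action = int(np.argmax(q_values.data))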
For more details, see the official Chainer documentation on serializers.
You can also look at chainerrl, a reinforcement learning library built on Chainer. chainerrl provides a utility function copy_param that copies parameters from a source_link network to a target_link network.
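As a rough sketch, that utility could be used for the target-network sync in the question's training loop, in place of Q_ast = copy.deepcopy(Q). The import path chainerrl.misc.copy_param and the target_link/source_link keyword names are taken from chainerrl's utilities, so please verify them against the version you install:

from chainerrl.misc.copy_param import copy_param

# Overwrite the target network's parameters with the online network's
copy_param(target_link=Q_ast, source_link=Q)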