我在训练一个DQN来玩OpenAI的Atari环境,但我的网络的Q值很快就飙升到远超现实的水平。
以下是相关代码部分:
for state, action, reward, next_state, done in minibatch: if not done: # 为了节省内存,next_state仅为一帧 # 因此我们需要将其添加到当前状态以获取网络的实际输入 next_4_states = np.array(state) next_4_states = np.roll(next_4_states, 1, axis=3) next_4_states[:, :, :, 0] = next_state target = reward + self.gamma * \ np.amax(self.target_model.predict(next_4_states)) else: target = reward target_f = self.target_model.predict(state) target_f[0][action] = target self.target_model.fit(state, target_f, epochs=1, verbose=0)
折扣因子为0.99(使用0.9的折扣因子不会发生这种情况,但也不会收敛,因为它无法提前足够远地思考)。
通过逐步调试代码,我发现问题的原因是所有不应该更新的Q值(即我们未采取的动作的Q值)都会略有增加。我的理解是,在训练过程中将网络自己的输出传递给网络应该保持输出不变,而不是增加或减少它。我的模型有什么问题吗?有什么方法可以屏蔽更新,使其只更新相关的Q值吗?
编辑:我的模型创建代码如下:
def create_model(self, input_shape, num_actions, learning_rate): model = Sequential() model.add(Convolution2D(32, 8, strides=(4, 4), activation='relu', input_shape=(input_shape))) model.add(Convolution2D(64, 4, strides=(2, 2), activation='relu')) model.add(Convolution2D(64, 3, strides=(1, 1), activation='relu')) model.add(Flatten()) model.add(Dense(512, activation='relu')) model.add(Dense(num_actions)) model.compile(loss='mse', optimizer=Adam(lr=learning_rate)) return model
我创建了两个这样的模型。一个用于在线网络,一个用于目标网络。
回答:
哪些预测会被更新?
通过逐步调试代码,我发现问题的原因是所有不应该更新的Q值(即我们未采取的动作的Q值)都会略有增加。我的理解是,在训练过程中将网络自己的输出传递给网络应该保持输出不变,而不是增加或减少它。
下面我画了一个非常简单的具有3个输入节点、3个隐藏节点和3个输出节点的神经网络。假设你只为第一个动作设置了一个新的目标,并且简单地再次使用现有的预测作为其他动作的目标。这导致只有第一个动作/输出的错误非零(为简单起见,我假设大于零)(在图像中用delta
表示),而其他动作的错误为0
。
我用粗体标出了从输出层到隐藏层传播此错误的连接。请注意,每个隐藏层中的节点仍然会收到错误。当这些节点将它们的错误传播回输入层时,它们将通过输入层和隐藏层之间的所有连接来完成,因此所有这些权重都可能被修改。
所以,假设所有这些权重都被更新了,现在假设使用原始输入进行新的前向传递。你期望输出节点2和3的输出与之前完全相同吗?不,可能不会;从隐藏节点到最后两个输出的连接可能仍然具有相同的权重,但所有三个隐藏节点的激活水平都会有所不同。因此,其他输出并不能保证保持不变。
有什么方法可以屏蔽更新,使其只更新相关的Q值吗?
不容易做到,如果能做到的话也是如此。问题在于,除了最后一对连接之外,层与层之间的连接并不是特定于动作的,我认为你也不希望它们是特定于动作的。
目标网络
我的模型有什么问题吗?
我注意到的一件事是你似乎在更新用于生成目标的同一个网络:
target_f = self.target_model.predict(state)
和
self.target_model.fit(state, target_f, epochs=1, verbose=0)
都使用了self.target_model
。你应该为这两行使用网络的独立副本,并且仅在较长时间后将更新后的网络的权重复制到用于计算目标的网络中。关于这一点的更多信息,请参见这篇文章中的附加内容3。
双重DQN
除此之外,众所周知,DQN仍然可能倾向于高估Q值(尽管通常不会完全爆炸)。这可以通过使用双重DQN来解决(注意:这是后来在DQN之上添加的改进)。