[深度Q网络]如何在Tensorflow的自动微分中排除操作

我正在尝试使用Tensorflow创建一个类似于Deepmind DQN3.0的深度Q网络(DQN),但我遇到了一些困难。我认为问题出在Tensorflow的自动微分方法上。

请看这张图。这是DQN3.0的架构。

DQN架构

在监督学习中,为了使网络的输出接近标签,我们通过损失函数计算差异,然后进行反向传播,并使用优化器更新参数。

在DQN中,AI过去经历的状态被存储在内存中,并再次输入到两个神经网络TergetNetwork和Network中,两个网络之间的差异反映在Network中。

每个网络的输出不是总和为1的概率,而是期望值。TergetNetwork的输出将包括折扣率(gamma)和当时获得的奖励。

此外,查看DQN 3.0的实现(lua + torch),它将当前网络的输出与当时选择的动作进行比较,并通过backward方法直接反向传播差异。

function nql:getQUpdate(args)    local s, a, r, s2, term, delta    local q, q2, q2_max    s = args.s    a = args.a    r = args.r    s2 = args.s2    term = args.term    -- 调用forward的顺序有点奇怪,为了避免不必要的调用(我们只需要2个)。    -- delta = r + (1-terminal) * gamma * max_a Q(s2, a) - Q(s, a)    term = term:clone():float():mul(-1):add(1)    local target_q_net    if self.target_q then        target_q_net = self.target_network    else        target_q_net = self.network    end    -- 计算max_a Q(s_2, a)。    q2_max = target_q_net:forward(s2):float():max(2)    -- 计算q2 = (1-terminal) * gamma * max_a Q(s2, a)    q2 = q2_max:clone():mul(self.discount):cmul(term)    delta = r:clone():float()    if self.rescale_r then        delta:div(self.r_max)    end    delta:add(q2)    -- q = Q(s,a)    local q_all = self.network:forward(s):float()    q = torch.FloatTensor(q_all:size(1))    for i=1,q_all:size(1) do        q[i] = q_all[i][a[i]]    end    delta:add(-1, q)    if self.clip_delta then        delta[delta:ge(self.clip_delta)] = self.clip_delta        delta[delta:le(-self.clip_delta)] = -self.clip_delta    end    local targets = torch.zeros(self.minibatch_size, self.n_actions):float()    for i=1,math.min(self.minibatch_size,a:size(1)) do        targets[i][a[i]] = delta[i]    end    if self.gpu >= 0 then targets = targets:cuda() end    return targets, delta, q2_maxendfunction nql:qLearnMinibatch()    -- 执行一个小批量Q学习更新:    -- w += alpha * (r + gamma max Q(s2,a2) - Q(s,a)) * dQ(s,a)/dw    assert(self.transitions:size() > self.minibatch_size)    local s, a, r, s2, term = self.transitions:sample(self.minibatch_size)    local targets, delta, q2_max = self:getQUpdate{s=s, a=a, r=r, s2=s2,    term=term, update_qmax=true}    -- 清零参数的梯度    self.dw:zero()    -- 获取新的梯度    self.network:backward(s, targets)

因此,如果不介意上图中计算块的速度,可以使用Numpy等在CPU上计算,而不是使用Tensorflow,这样我就可以将其从自动微分中排除,我正在考虑这样做。

在DQN3.0中,反向传播仅从Network的输出层(蓝色部分)计算。然而,在我的Tensorflow模型中,它是从最终操作mul开始的。

我想使用Tensorflow从与DQN3.0相同的输出层开始反向传播。

我知道可以使用compute_gradients()优化器方法获取grads_and_vars,并运行从头开始创建的手动微分过程。但我认为实现卷积层的这种微分对我来说非常困难。

我能否使用Tensorflow函数或其他方法在自动微分中排除计算块操作?或者有其他解决方法吗?


回答:

谢谢大家。

我暂时解决了排除问题。

我创建了一个修改后的tf.gradients函数的原始函数,如下所示。

def gradients(ys,          xs,          grad_start, #***** ←添加新的梯度开始操作的参数 ******          grad_ys=None,          name="gradients",          colocate_gradients_with_ops=False,          gate_gradients=False,          aggregation_method=None):              .              .              .          # 'from_ops'集合。          stop_ops = _StopOps(from_ops, pending_count)          while queue:            # 为操作生成梯度子图。            op = queue.popleft()            with _maybe_colocate_with(op, colocate_gradients_with_ops):              if loop_state:                loop_state.EnterGradWhileContext(op, before=True)              out_grads = _AggregatedGrads(grads, op, loop_state, aggregation_method)              if loop_state:                loop_state.ExitGradWhileContext(op, before=True)              #*************************************************************              # 添加2行,将'out_grads'操作替换为新的梯度开始操作              if grad_start is not None and op == grad_start.op:                  out_grads = ys              .              .              .

我在Tensorboard上确认过,看起来符合预期。https://i.sstatic.net/vG1e0.png

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注