Prioritized Replay DQN与Dueling DQN

在前面的文章Nature DQN与Double DQN中介绍了对DQN算法的两种改进,Nature DQN加快了收敛速度,Double DQN有效的解决了过拟合的现象,本篇文章将继续介绍一些对于DQN算法的改进。

Prioritized Replay DQN

Prioritized Replay DQN对于DQN的改进主要是针对训练过程中样本的选择。在以往的训练中,训练样本被存放在记忆池中,在更新参数的时候随机取出一组来更新。由于是随机取出的,可能会导致一些对训练有促进作用的样本学习的不够充分(例如TD误差较大的样本),而某些无意义的样本却多次被学习(例如TD误差较小的样本),影响收敛速度。因此,Prioritized Replay DQN提出了一种对训练样本进行有选择训练的方法。

Prioritized Replay DQN会对每一个样本的TD误差计算绝对值δ(t)|\delta(t)|,并根据δ(t)|\delta(t)|的大小设置优先级,使得δ(t)|\delta(t)|越高的样本更容易得到训练。这样以来,记忆池中需要保存的样本属性就增加了一个优先级了,为了更方便的处理样本,记忆池需要采用其他的数据结构进行处理。很容易想到的就是数据结构优先队列了,但是优先队列存在着一个问题,那就是不好更新样本的优先级。

由于网络是在不断的学习的,随着网络的能力越来越强,原本比较难以学习的样本会变得不再困难,所以需要不断的更新样本的优先级。为了解决这个问题,Prioritized Replay DQN使用SumTree树结构来存储记忆,如下图所示:

SumTree只有叶子节点保存样本以及样本的优先级(上图中节点上的数字),非叶子节点存储的是字节点优先级之和。也就是说,只要叶子节点优先级发生更新,相应的父节点只要冒泡更新即可。

将样本存到SumTree之后,可以按照如下步骤来取出样本,以上图为例:

  1. 从0到42随机选择一个数字,例如26。
  2. 从根节点开始搜索到叶子节点,找到26所属的区间对应的叶子节点。即25-29对应的节点4。
  3. 返回样本

注意到,每个叶子节点所对应的区间是不同的,且区间大小正比于优先级。这样以来,在进行样本选择时,区间大的叶子节点被选择的概率就会大一些,区间小的叶子节点被选择的概率就会小一些。因此,在采样次数足够的情况下,每个样本都可以得到学习,同时重要的样本也会有所偏好。

除此之外,在目标函数上,Prioritized Replay DQN也作了调整,如下所示:

Jθ=(yjwjQ(ϕj,aj;θ))2J_\theta=(y_j-w_jQ(\phi_j,a_j;\theta))^2

其中wjw_j表示相应的优先级。其他地方与Double DQN没什么区别,其算法流程如下:

Dueling DQN

Prioritized Replay DQN从样本选择的角度优化了DQN算法,Dueling DQN则从网络结构优化了DQN。之前的DQN网络都是之间输出的QQ值,而Dueling DQN则不然。它将网络的输出作为两个分支,一个分支输出仅仅与状态相关的价值函数V(s;θ,β)V(s;\theta,\beta),另一个分支输出与状态和动作都相关的优势函数A(s,a;θ,α)A(s,a;\theta,\alpha)。然后将两者相加,得到最终的输出,如下式所示:

Q(s,a;θ,α,β)=V(s;θ,β)+A(s,a;θ,α)Q(s,a;\theta,\alpha,\beta)=V(s;\theta,\beta)+A(s,a;\theta,\alpha)

其中θ\theta表示两个分支公共的参数,α\alphaβ\beta表示相应的不同的参数,其网络结构如下:

传统DQN(上),Dueling DQN(下)

然而,上面的公式在计算QQ值时会出现一个unidentifiable问题:给定一个QQ,是无法得到唯一的VVAA的。比如,VVAA分别加上和减去一个值能够得到同样的QQ,但反过来显然无法由QQ得到唯一的VVAA。进而导致无法对动作作出选择。

为了解决这个问题,可以强制令所选择贪婪动作的优势函数为0:

Q(s,a;θ,α,β)=V(s;θ,β)+(A(s,a;θ,α)maxaϵA^A(s,a;θ,α))Q(s,a;\theta,\alpha,\beta)=V(s;\theta,\beta)+(A(s,a;\theta,\alpha)-\max_{a'\epsilon|\hat{A}|}A(s,a';\theta,\alpha))

这样以来就可以通过下式得到唯一的值函数:

a=argmaxaϵA^Q(s,a;θ,α,β)=argmaxaϵA^A(s,a;θ,α)Q(s,a;θ,α,β)=V(s;θ,β)\begin{gathered}a^*=\arg\max_{a'\epsilon\hat{A}}Q(s,a';\theta,\alpha,\beta)=\arg\max_{a'\epsilon\hat{A}}A(s,a';\theta,\alpha) \\Q(s,a*;\theta,\alpha,\beta)=V(s;\theta,\beta)\end{gathered}

此外还可以采取使用优势函数的平均值替代maxmax部分,即:

Q(s,a;θ,α,β)=V(s;θ,β)+(A(s,a;θ,α)1A^aA(s,a;θ,α))Q(s,a;\theta,\alpha,\beta)=V(s;\theta,\beta)+(A(s,a;\theta,\alpha)-\frac{1}{|\hat{A}|}\sum_{a'}A(s,a';\theta,\alpha))

虽然这种方法,使得VV值和AA值的语义不是那么的清晰,但是比较稳定。

Dueling DQN主要是修改了网络结构,其他部分的流程是一样的,在此不再赘述。

坚持原创技术分享,您的支持将鼓励我继续创作!
  • 本文作者:bdqfork
  • 本文链接:/articles/50
  • 版权声明:本博客所有文章除特别声明外,均采用BY-NC-SA 许可协议。转载请注明出处!
表情 |预览
快来做第一个评论的人吧~