论文笔记2:Deep Attention Recurrent Q-Network

参考文献:[1512.01693] Deep Attention Recurrent Q-Network (本篇DARQN)

[1507.06527v3] Deep Recurrent Q-Learning for Partially Observable MDPs(DRQN,可参见我上一篇笔记)

目前网上我搜到的论文笔记参考:论文笔记之:Deep Attention Recurrent Q-Network


创新点:将DQN(其实是更进一步的DRQN)与attention mechanism(注意力机制)结合

改进:基于DRQN,在CNN与LSTM之间加入了attention network(注意力网络)(作者也说这里其实可以看做LSTM额外增加了一个过滤门)

改进原因

1、DQN在需要4帧以上的图像时效果不好,并且DRQN没有太大的系统上的性能提升。

2、DQN训练时间长,参数太多

带来益处

1、可以通过高亮可视化agent正在关注的图像区域。

2、虽然没有在所有游戏上性能表现好,但是参数变少,实现加快训练速度。(emmm个人感觉还是没有在系统上效果提升)


Abstract

DRQN引入attention机制提出DARQN,建立的内置attention机制可以通过高亮显示agent正在关注的游戏屏幕区域,实现在线监测训练过程。

Introduction

提出改进原因:(前面写了,为完整性copy一下)

1、DQN在需要4帧以上的图像时效果不好,并且DRQN没有太大的系统上的性能提升。

2、DQN训练时间长,参数太多

虽然在训练时间问题上,前人提出了一种并行算法来提升训练速度,但作者认为并不是最有效的,而近年来visual attention model在标题生成,对象跟踪,机器翻译等领域取得进展,引发作者想要将这个attention machinism加入到DRQN中,主要的优点:注意到agent关注的输入图像中的相关的较小的信息区域,帮助减少整个结构的参数。

对比于DRQN其不同在于LSTM层不仅将数据用于为下个动作做出决策,也用于选出下一个注意的区域

DARQN

论文笔记2:Deep Attention Recurrent Q-Network

结构解释:

1、 论文笔记2:Deep Attention Recurrent Q-Network 状态下CNN接收视觉图像(visual frame),产生D feature maps,且每一个m*m维

2、attention network(g)将这个maps转换成向量 论文笔记2:Deep Attention Recurrent Q-Network ,这个向量中的每一个元素D维,一共有m*m个元素。输出为他们的线性组合形式 论文笔记2:Deep Attention Recurrent Q-Network ,称为context vector

3、LSTM接收这个context vector 论文笔记2:Deep Attention Recurrent Q-Network ,之前的隐藏状态 论文笔记2:Deep Attention Recurrent Q-Network 和记忆库中选取的状态 论文笔记2:Deep Attention Recurrent Q-Network ,产生 论文笔记2:Deep Attention Recurrent Q-Network 用于计算Q值(图中①),和用于产生下一状态 论文笔记2:Deep Attention Recurrent Q-Network 向量(图中②)

提出两种用于计算context vector的方法:soft attention和hard attention:两种方式的训练算法以及参数更新方式不同(没太细研读。。。)

1、soft attention:2个fc层+softmax activation

2、hard attention:略,可参加上面链接的其他人写的论文笔记

Experiments

用五个游戏做的可以统计reward,100个回合做平均,每一个回合进行50000步。

attention network通过4个步骤来训练(没看soft和hard具体实现,不太懂unroll step)。DRQN的权重参数每次都更新

DQN和DARQN每四步更新一次参数。

论文笔记2:Deep Attention Recurrent Q-Network其实发现DARQN效果并不太好,很少的游戏提升大

作者特别列出两个游戏:

论文笔记2:Deep Attention Recurrent Q-Network

作者还做了个实验,将DARQN的unroll step提升,发现提升后效果会变好:

论文笔记2:Deep Attention Recurrent Q-Network

Results

作者进行了实验分析:

1、在Seaquest这个游戏中hard比soft性能差的原因是采用policy gradient进行参数更新容易导致局部最优解,与这个游戏玩法设置有关系

2、在Breakout这个游戏中DARQN性能差的原因是因为设置的unroll step数过小,设置到10 以后(参见上面最后一个图),性能提升。

3、最后实验可以可视化,发现agent注意的高亮点随着agent在游戏中所关注的不同,亮点会变化:

论文笔记2:Deep Attention Recurrent Q-Network

Conclution

提出未来研究方向:

1、test different techniques for reducing stochastic gradient variability

2、to apply different approaches to training stochastic attention networks



由于本人资历尚浅,只为帮助更多跟我一样的小白,所以有错误请您指出,谢谢