RL论文阅读7 - MAML2017
Tittle
标签
- meta-learning
- framework
总结
meta-learning的目标就是训练一个模型,使这个模型能够从很少的新任务的数据中快速学习一个新的任务。这个模型的训练需要大量的不同任务作为数据。
提出了一种meta-learning的框架,能够用于使用梯度下降的算法,使其在应用于新的任务时,只需要很少步骤的训练就能够达到较好的效果。这个框架能够用于分类任务(如图像)和使用梯度下降来训练策略的强化学习的任务。
其实简单来说,就是训练了适应一些列某类的任务的模型网络,当有该类新任务时,只需要在这个模型上进行参数微调。
特点:
- 能够从较少的examples中快速学习
- 随着数据量的增多,能够继续增加算法的适应性
原理概述
一些标记:
- 模型 :
- 任务 :
- 损失函数
- 初始状态分布
- 状态转换概率分布
- H: episode长度(多少步)
模型训练
希望让模型的参数处于对任务改变的敏感点,这样任务微小的改变,都能引起很大的loss function改变,然后使用这个方向对特定任务进行更新。如下图:
适应参数训练
模型的参数为。当这个模型去适应一个新的任务KaTeX parse error: Undefined control sequence: \T at position 1: \̲T̲_i,那么通过若*梯度下降,就能够得到针对这个任务的适应参数。使用下面这个更新公式计算(以一步gradient为例,多步同理):
就是继续利用的损失函数继续优化。
是学习率
模型参数训练
采样一些任务tasks,这些任务服从分布
然后先计算每个任务的适应参数
和它的损失,然后最小化采样任务的所有损失和来更新模型参数
注意这里计算的某个任务的损失,使用的是已经进行适应该任务的模型,而不是通用模型
使用随机梯度下降(SGD),那么的更新就表示为:
是另一个学习率
算法描述
应用到回归和分类问题
算法描述
注意事项:
- 定义模型的H=1,丢弃了时间步,因此模型是一个输入对应一个输出,而不是序列输出输出
- 任务认为独立同分布
- 回归问题损失函数使用MSE
- 分类为题使用交叉熵损失函数:
应用到RL问题
算法描述
注意事项:
- RL的对于任务的损失函数如下:
- 定义R为非负, Loss之所以有负号是在RL中我们希望奖励值最大,由于使用的是梯度下降算法,加一个负号相当于梯度上升了,向着最大的饿方向。
- 对于step8,由于策略梯度算法是on-policy算法,所以需要使用当前的适应过的策略来采样新的数据。