Continual Learning 经典方法:Memory Aware Synapses (MAS)
1. 顾名思义
Synapses 是神经元的突触,在人脑中负责连接不同神经元结构。Hebb’s rule 表示在脑生理学中,突触连接常常满足 “Fire Together, Wire Together”,即同时被**或者同时失活。所以不同的任务对应潜在的不同突触——不同的记忆,因此选择**或者改变某些神经元突触即可称为 Memory Aware Synapses (MAS)。在基于深度模型的终身学习任务中来说,可以通过计算网络中神经元的重要性,来选择保持重要的神经元在终身学习过程中相对不变,而不重要的神经元可以有相对灵活的变化。如下图所示:
2. 核心问题:如何计算重要性 Importance
2.1 Recap Elastic Weight Consolidation (EWC) and Synaptic Intelligence (SI)
EWC 和 SI 是比较经典的计算网络参数不同重要性的方法,具体来说:
- EWC 通过估计 Fisher Information Matrix 的对角线值来计算参数重要性。
- SI 估计损失函数对于参数的敏感性来计算参数重要性。
然而上述方法都是需要基于金标准计算损失函数->反向传播,由此带来对重要性计算的 overestimated,而且个人认为最重要的是,这样基于损失函数容易陷入局部最小值,导致梯度消失的 complications。
2.2 基于输出敏感性的参数重要性估计
为了解决上述问题,MAS 采用输出函数的敏感性来估算参数重要性。
考虑小的扰动
δ
\delta
δ 对参数
θ
\theta
θ,导致了当前输出的改变
F
(
x
k
;
θ
+
δ
)
−
F
(
x
k
;
θ
)
≈
∑
i
,
j
g
i
j
(
x
k
)
δ
i
j
F\left(x_{k} ; \theta+\delta\right)-F\left(x_{k} ; \theta\right) \approx \sum_{i, j} g_{i j}\left(x_{k}\right) \delta_{i j}
F(xk;θ+δ)−F(xk;θ)≈∑i,jgij(xk)δij。其中
g
i
j
(
x
k
)
=
∂
(
F
(
x
k
;
θ
)
)
∂
θ
i
j
g_{i j}\left(x_{k}\right)=\frac{\partial\left(F\left(x_{k} ; \theta\right)\right)}{\partial \theta_{i j}}
gij(xk)=∂θij∂(F(xk;θ)) 是当前网络输出对于参数
θ
i
,
j
\theta_{i,j}
θi,j 的
δ
i
,
j
\delta_{i,j}
δi,j 扰动,在采样数据点
x
i
,
j
x_{i,j}
xi,j 处的导数。梯度
g
i
,
j
g_{i,j}
gi,j 用来计算对应参数
θ
i
,
j
\theta_{i,j}
θi,j 的重要性
Ω
i
j
\Omega_{i j}
Ωij—— 即很小的对参数的扰动能够造成模型输出改变。
Ω
i
j
=
1
N
∑
k
=
1
N
∥
g
i
j
(
x
k
)
∥
\Omega_{i j}=\frac{1}{N} \sum_{k=1}^{N}\left\|g_{i j}\left(x_{k}\right)\right\|
Ωij=N1∑k=1N∥gij(xk)∥
2.3 用重要性估计对模型正则化
L
(
θ
)
=
L
n
(
θ
)
+
λ
∑
i
,
j
Ω
i
j
(
θ
i
j
−
θ
i
j
∗
)
2
L(\theta)=L_{n}(\theta)+\lambda \sum_{i, j} \Omega_{i j}\left(\theta_{i j}-\theta_{i j}^{*}\right)^{2}
L(θ)=Ln(θ)+λ∑i,jΩij(θij−θij∗)2
当学习新的任务时,不仅需要最小化当前任务的损失函数
L
n
(
θ
)
L_{n}(\theta)
Ln(θ),还需要根据参数重要性控制特定参数的变化——> 进而达到用之前的数据正则化当前任务的目的。
3. 总结
文章的代码工程实现相对EWC和SI是更加简单的,直接把网络输出进行某种意义上的融合之后反向传播即可得到基于梯度的参数重要性估计。同时,论文从理论的角度分析了 MAS 和 Hebbian Learning 的联系,这也是很多终身学习论文的常用套路,用脑生理学知识用来解释网络的算法构架。