「Deep Learning」Note on Spatial Transformer Networks

Sina Weibo:小锋子Shawn
Tencent E-mail:[email protected]
http://blog.****.net/dgyuanshaofeng/article/details/79875310

论文笔记

空间变换器网络(Spatial Transformer Networks, STN)是谷歌DeepMind的杰作。

卷积网络(ConvNet)缺乏对输入数据的空间不变性,即图像发生仿射变换后,经过网络后将得到不一致的结果。我们经常需要进行数据扩充(Data augmentation),这种处理虽然针对不同视觉任务,会有不同的具体步骤,但是往往是盲目的。针对这个问题,Jaderberg等人提出可学习/可微分模块,空间变换器,使得卷积网络具有这些不变性:平移、尺度、旋转和翘曲(warping)。因此,有两种方法提供空间不变性:1、数据扩充;2、STN。

局部最大值池化层在一定程度上提供微小空间不变性,但是对于较大的仿射变换则不适合。因此,存在局限。

论文中提到的适用范围(不限于):图像分类,尤其是自然图像中的文本处理;互定位(co-localisation),尤其是医学图像中某器官被配准在某轴(长轴)上,便于后面的识别和分割;空间注意力。

空间变换器

如图1所示。空间变换器由定位网络(Localisation network)、网格生成器(Grid generator)和采样器(Sampler)三部分组成。U为输入特征图或者输入图像(彩色或者灰色),首先经过定位网络,输出变换参数θ,然后网格生成器根据θU的大小生成网格Tθ(G)(在pytorch中,由torch.nn.functional.affine_grid实现),G为regular网格,Tθ()为不regular网格,最后采样器利用网格Tθ()U输出变形特征图或者变形输入图像V(在pytorch中,由torch.nn.functional.grid_sample实现)。定位网络接收不同输入,输出不同变换参数θ,即变换参数以输入为条件。

「Deep Learning」Note on Spatial Transformer Networks

图 1

定位网络

没啥好说的,就是一个回归网络。

参数化采样网格

论文中,定义输出像素落入regular网格G={Gi}里,其中Gi=(xit,yit),表示输出特征图V。假设Tθ为仿射变换,则满足:

(xisyis)=Tθ(Gi)=Aθ(xityit1)=[θ11θ12θ13θ21θ22θ23](xityit1)

可微图像采样

没啥好说的(原文好复杂),就是一个图像插值过程。

实验结果

数据集:MNIST、SVHN(街景视角房屋号码)和CUB-200-2011(200类鸟)

Distorted MNIST

这部分说明,TPS变换最牛逼。

SVHN

这部分说明,自然图像的文本检测和识别可用STN。

细粒度分类

这部分说明,多个STN会成为不同部件(鸟头、鸟身)检测器,如图2所示。启发行人再识别研究。

「Deep Learning」Note on Spatial Transformer Networks

图 2