OCR EAST: An Efficient and Accurate Scene Text Detector 自然场景下的文字识别算法详解

最近研究OCR,有篇比较好的算法文章,《EAST: An Efficient and Accurate Scene Text Detector》,该文发表在2017年CVPR上。代码地址:https://github.com/argman/EAST ,这是原作者参与的一份tensorflow版本代码,网上还有其他的实现。

下面根据原文的结构和上述提供的代码详细的解读一下该算法

一、网络架构

文中使用了PVANet和VGG16,下图1是原文的网络结构图(PVANet)
网络输入一张图片,经过四个阶段的卷积层可以得到四张feature map, 分别为f4,f3,f2,f1f_{4},f_{3},f_{2},f_{1},它们相对于输入图片分别缩小14,18,116,132\frac{1}{4},\frac{1}{8},\frac{1}{16},\frac{1}{32},之后使用上采样、concat(串联)、卷积操作依次得到h1,h2,h3,h4h_{1},h_{2},h_{3},h_{4},在得到h4h_{4}这个融合的feature map后,使用大小为3×33\times3通道数为32的卷积核卷积得到最终的feature map。

文中对文本框的定义有两种,一种是旋转矩形(RBOX),另一种是四边形(QUAD)。因为代码只实现了RBOX,所以下面也只对RBOX框进行分析

得到最终的feature map后,使用一个大小为1×11\times1通道数为1的卷积核得到一张score map用FsF_{s}表示。在feature map上使用一个大小为1×11\times1通道数为4的卷积核得到text boxes,使用一个大小为1×11\times1通道数为1的卷积核得到text rotation angle,这里text boxes和text rotation angle合起来称为geometry map用FgF_{g}表示。

关于上述的Fs,FgF_{s},F_{g}要说明几点(如下图2所示):

  • FsF_{s}大小为原图的14\frac{1}{4}通道数为1,每个像素表示对应于原图中像素为文字的概率值,所以值在[0,1]范围内。
  • FgF_{g}大小也为原图的14\frac{1}{4}通道数为5,即4+1(text boxes + text rotation angle)。
  • text boxes通道数为4,其中text boxes每个像素如果对应原图中该像素为文字,四个通道分别表示该像素点到文本框的四条边的距离,范围定义为输入图像大小,如果输入图像为512,那范围就是[0,512]。下图2d表示
  • text rotation angle通道数为1,其中text rotation angle每个像素如果对应原图中该像素为文字,该像素所在框的倾斜角度,角度范围定义为[-45,45]度。下图2e表示

OCR EAST: An Efficient and Accurate Scene Text Detector 自然场景下的文字识别算法详解
OCR EAST: An Efficient and Accurate Scene Text Detector 自然场景下的文字识别算法详解

二、关于训练标签的生成

如上可知,训练标签由两个部分组成,一个是score map的标签,一个是geometry map标签。
注意:程序要求输入的四边形标定点是以顺时针方向标定的,这点很重要

1. score map标签的生成方法

  • 首先生成一个与图片大小一样的矩阵,值都为0
  • 根据标定好的四边形框对该四边形框进行缩小,缩小方法下面会详细说明,得到最终结果如上图2a中的绿框
  • 将绿框中的像素赋值1表示正样本的score,其他为负样本的score
  • 最后按照每隔4个像素采样,得到图片14\frac{1}{4}大小的score map

上述缩小四边形的方法:

  • 首先定义四个顶点Q={piiϵ{1,2,3,4}}Q = \{p_{i}|i \epsilon \{1, 2, 3, 4\}\},这四个顶点按照顺时针方向排列
  • 计算缩小的参考大小如下式所示,下式表示的是选取与顶点相连的两条边中最小的边的大小记为rir_{i}
    ri=min(D(pi,p(imod  4)+1),D(pi,p((i+2)mod  4)+1))r_{i}=min(D(p_{i},p_{(i\mod4)+1}),D(p_{i},p_{((i+2)\mod4)+1}))
    其中D(pi,pj)D(p_{i},p_{j})表示点pip_{i}pjp_{j}之间的距离
  • 对于边$ p_{i}, p_{(i mod 4)+1}$,缩小0.3ri0.3r_{i}0.3r(imod  4)+10.3r_{(i\mod4)+1}的和的像素大小

2. geometry map标签的生成方法

  • 首先生成一个与图片大小一样的5通道矩阵用来制作text boxes 与 text rotation angle
  • 根据标定的四变形生成一个面积最小的平行四边形,进而得到平行四边形的外界旋转矩形
  • 根据旋转矩形的四个点坐标,可以选择出y值最大的坐标顶点和该顶点逆时针方向的顶点(也可以称该顶点右边的顶点),根据这两个点的连线可以求出连线与x轴的夹角,这个夹角取值在(0,90)度之间,称这个夹角为angle
  • 当angle<45度时,定义y值最大的点为p3p_{3}点,其它点按顺时针方向依次类推。当angle>45度时,定义y值最大的点为p2p_{2}点,此时angle角变换为(π/2angle)-(\pi/2 - angle),这样就保证了angle角度[-45,45]度
  • 上述还有一种特殊情况要考虑,当y值最大的点有两个时,说明矩形与x轴平行,angle定义为0度,这时候将x与y坐标相加最小的点定义为p0p_{0}点,其它点依次类推
  • 根据得到的旋转矩形和angle值将geometry map的五个通道赋值,赋值方法为,对于text boxes的四个通道,每个通道表示图像中的像素点坐标到旋转矩形的四个边的距离顺序为,0通道表示点到p0p_{0}p1p_{1}边的距离,1通道表示点到p1p_{1}p1p_{1}边的距离,按照顺时针依次赋值四个通道,也分别称为到top、right、bottom、left边的距离,对于text rotation angle这一个通道,将旋转矩形中所有像素都赋值上述计算出的angle大小
  • 最后得到的五个通道按照每隔4个像素采样,这样就可以得到图片14\frac{1}{4}大小的geometry map了

三、损失函数的定义

损失函数定义如下
L=Ls+λgLgL = L_{s} + \lambda_{g}L_{g}
其中LsL_{s}LgL_{g}分别表示score map和geometry map的损失, λg\lambda_{g}表示两个损失的权重,文章设为1

1. score map的损失计算
这里要说明的是文章采用的是交叉熵计算该损失,但是程序实现没有采用,程序采用的是dice loss

Ls=12yspsys+psL_{s}=1-\frac{2y_{s}p_{s}}{y_{s}+p_{s}}
其中ysy_{s}代表位置敏感图像分割(position-sensitive segmentation)的label,psp_{s}代表预测的分割值

2. geometry map的损失计算
采用IoU loss,计算方法如下
Lg=LAABB+λθLθL_{g} = L_{AABB} + \lambda_{\theta}L_{\theta}
其中λθ=10\lambda_{\theta}=10

  • LAABB=logIoU(R^,R)=logR^RR^RL_{AABB}=-logIoU(\hat{R},R)=-log\frac{|\hat{R}\bigcap R^{*}|}{|{\hat{R}\bigcup R^{*}}|}
    其中,R^\hat{R}表示预测, RR^{*}表示真实值
    R^R=wihi|\hat{R}\bigcap R^{*}|=w_{i}*h_{i}计算可以通过下述方法
    wi=min(d2^,d2)+min(d4^,d4)w_{i}=min(\hat{d_{2}}, d^{*}_{2})+min(\hat{d_{4}}, d^{*}_{4})
    hi=min(d1^,d1)+min(d3^,d3)h_{i}=min(\hat{d_{1}}, d^{*}_{1})+min(\hat{d_{3}}, d^{*}_{3})
    其中d1,d2,d3,d4d_{1},d_{2},d_{3},d_{4}表示点到top、right、bottom、left边的距离。
    R^R=R^+RR^R|{\hat{R}\bigcup R^{*}}|=|\hat{R}|+|R^{*}|-|\hat{R}\bigcap R^{*}|

  • Lθ(θ^,θ)=1cos(θ^θ)L_{\theta}(\hat{\theta}, \theta^{*})=1-cos(\hat{\theta}-\theta^{*}),其中θ\theta^{*}表示预测值,θ^\hat{\theta}表示真实值

最后文章还提出了Locality-Aware NMS,感觉就是先合并一次窗口,然后采用标准的NMS去抑制窗口,详细可以看代码实现,采用的是c++实现的