[论文阅读] CTPN---Detecting Text in Natural Image with Connectionist Text Proposal Network

  • 这篇论文思路和Faster RCNN是差不多的。总体来说,就是先通过RPN(Region Proposal Network)来提取proposal,然后再对提取得到的proposal进行classification。
  • 文章对Faster RCNN有以下几点改进:
    • Faster RCNN中使用的3种size和3种长宽比组合的9种anchor,但是CTPN中,他固定了anchor为16px(vgg16, 因为有4个pooling 层),而只是设置了10种高的值。这样是结合了text detection的特点,一般都是细长的结构。
    • 再得到Feature map之后,我们通过一个BD-LSTM结构去提取每个pixel对应的Feature。这样做是为了利用global information。他将一行的pixel看成一个序列输入给BD-LSTM去提取Feature。得到BD-LSTM的输出以后,我们再去得到每个anchor的score以及对应的anchor的坐标值。
    • 还有一个contribution是他对水平坐标还做了一定的微调。具体的公式如下所示:
      o=(xsidecxa)/wa,o=(xsidecxa)/wa

      这里面o代表的predict,o代表的是GT。xside代表的是未修正的预测的anchor的坐标,xside代表的就是ground truth。cxa代表anchor的对心所对应的x坐标。wa代表anchor的宽,这里是固定值(16)。之所以除以宽相对于做了一定的归一化吧。

[论文阅读] CTPN---Detecting Text in Natural Image with Connectionist Text Proposal Network

  • 算法的流程:如图上所示:
    • 首先通过常规的特征提取模块(例如,VGG16)来得到feature map,假设大小为h*w*c
    • 通过一个卷积层,将其转化为h*w*256的shape
    • 我们将其转化为h*(w*256),其中,将w*256看成一个长度为w的输入序列,将其输入到BD-LSTM中。
    • 将得到feature 再转化成h*w*d 其中d代表的是BD-LSTM输出的维度
    • 然后再分别通过全连接层来对每个anchor预测score以及坐标。注意,这里是对feature map中的每个pixel进行预测的。也就是说我们fc的输出分别是h*w*(10*2)以及h*w*(10*4)
    • 最后使用上面步骤训练好的网络,得到类似与上图B的许多anchor,然后在使用连接算法,将其连接起来。连接算法的定义如下:
      • 首先挑选出所有score>0.7的anchor
      • 针对每个anchor, Bi定义他的邻居anchorBj, 他们要满足以下条件
        • 这两个anchor的最近的
        • anchor之间的距离小于50个pixel
      • 如果BiBj互为邻居,那么就将其合并,知道找不到互为邻居的anchor为止。
  • 上面讲了算法的流程,接下来我们看一下loss的定义,来了解具体我们怎么训练我们的网络

    L(si,vj,ok)=1NsiLscl(si,si)+λ1NvjLvre(vj,vj)+λ2NokLore(ok,ok)

    • 上面是训练的整体的loss,它由三部分组成,第一部分是分类的交叉熵,第二部分是对垂直坐标做regression的Smooth L1loss,第三部分是对水平坐标做regression的Smooth L1 loss。
    • si代表的是第i个anchor预测是text的概率,si是对应的ground truth{0,1}。
    • vj,vjjanchorgroundtruthjiprobability>0.7s_j^*=1$的anchor,也就是只计算正样本
    • Smooth L1 loss的定义如下:
      var=win(xiyi)i{x,y,w,h}smoothi={12vari2σ2|xσ|<1|vari|σ20.5otherwiseSmoothL1Loss=ismoothi
    • 这里使用Smooth L1 loss主要是因为:L1 loss会产生更稀疏的矩阵,L2 loss会产生更平滑的矩阵(原因)。但是有一个问题是L2如果我们var太大的话,还产生梯度爆炸,所以我们在这里使用了Smooth L1 Loss,他其实是一个分段的函数。如果var较小的话,我们使用L2 loss,否则,我们使用L1 loss。
  • 实现代码:CTPN

  • 有问题,欢迎探讨