Two Decoding Methods for CTC

The CTC loss is applied during the training of image text recognition models.

At prediction time, given an input $x$, we want the label $l$ that maximizes $p(l \mid x)$. In sequence learning this problem is called decoding: finding, within a limited time, the sequence $l^{*}$ with the largest conditional probability:

$$l^{*} = \underset{l}{\arg\max}\; p(l \mid x)$$

Suppose the character list is $('-', 'A', 'B')$, where $'-'$ is the blank, the number of time steps is $T=3$, and let $P(c, t)$ denote the probability that character $c$ is emitted at time $t$. The table below shows these probabilities, with time steps along the horizontal axis and characters along the vertical axis; our goal is to search this two-dimensional space for the most probable label $l^{*}$.

| $P(c, t)$ | $t=1$ | $t=2$ | $t=3$ |
| :---: | :---: | :---: | :---: |
| $'-'$ | 0.5 | 0.4 | 0.6 |
| $'A'$ | 0.2 | 0.3 | 0.3 |
| $'B'$ | 0.3 | 0.3 | 0.1 |

Greedy Decode

The greedy strategy always takes the best option at each step: at every time step we pick the character with the largest probability, and finally string the $T$ characters together into a label. As shown in the figure below, at $T=1$, $P('-', 1)$ is the largest; at $T=2$, $P('-', 2)$ is the largest; at $T=3$, $P('-', 3)$ is the largest. The resulting output label is $"blank"$, with $p(l = "blank" \mid x) = 0.5 \times 0.4 \times 0.6 = 0.12$.

![greedy decoding](./greedy.png)
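Greedy decoding fits in a few lines. A minimal sketch, assuming the probability table above (the blank-row values 0.5, 0.4, 0.6 are stated in the text; the other rows are reconstructed from the later calculations):

```python
import numpy as np

# Rows: characters ('-', 'A', 'B'); columns: time steps t = 1..3.
# The '-' row matches the text; the 'A'/'B' rows are assumed values.
probs = np.array([
    [0.5, 0.4, 0.6],   # '-'  (blank)
    [0.2, 0.3, 0.3],   # 'A'
    [0.3, 0.3, 0.1],   # 'B'
])
chars = ['-', 'A', 'B']

def greedy_decode(probs, chars):
    """Pick the most probable character at each time step."""
    best = probs.argmax(axis=0)            # index of the best char per t
    path_prob = probs.max(axis=0).prod()   # product of the chosen probs
    path = ''.join(chars[i] for i in best)
    return path, path_prob

path, p = greedy_decode(probs, chars)
print(path, p)   # '---' (all blanks), probability ≈ 0.12
```

Note that the greedy path here is all blanks, exactly the $0.5 \times 0.4 \times 0.6 = 0.12$ case from the text.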

Greedy decoding considers only a single path. In the CTC algorithm, however, we defined a many-to-one mapping $B$ that merges all paths yielding the same output label. As shown in the figure below, the output label $"A"$ corresponds to three paths, $"A--"$, $"-A-"$ and $"--A"$, i.e. $B("A--") = B("-A-") = B("--A") = "A"$, so $p(l = "A" \mid x)$ should be the sum of the probabilities of these three paths:

$$p(l = "A" \mid x) = p("A--" \mid x) + p("-A-" \mid x) + p("--A" \mid x) = 0.198$$

Since $p(l = "A" \mid x) = 0.198$ is much larger than $p(l = "blank" \mid x) = 0.12$, $"A"$ is the better output label.
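The three-path sum can be checked numerically. A small sketch, assuming the per-timestep probabilities implied by the example (the $'A'$ values at $t = 2, 3$ are reconstructed from the beam-search arithmetic):

```python
# Probability of each character at each time step (t = 1, 2, 3).
# The '-' row matches the greedy walkthrough; the 'A' values at
# t = 2, 3 are assumed from the worked calculations.
P = {
    '-': [0.5, 0.4, 0.6],
    'A': [0.2, 0.3, 0.3],
}

def path_prob(path):
    """Product of the per-timestep character probabilities."""
    p = 1.0
    for t, c in enumerate(path):
        p *= P[c][t]
    return p

# Sum the three paths that B maps onto the label "A".
total = sum(path_prob(path) for path in ("A--", "-A-", "--A"))
print(round(total, 3))  # 0.198
```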

*(figure: the three paths $"A--"$, $"-A-"$ and $"--A"$ merging into the label $"A"$)*

Beam Search

Define $Pr(s, t)$ as the probability that the network output at time $t$ corresponds to label $s$. Let $Pr^{-}(s, t)$ be the probability that a blank is emitted at time $t$, and $Pr^{+}(s, t)$ the probability that a non-blank character is emitted, so that $Pr(s, t) = Pr^{-}(s, t) + Pr^{+}(s, t)$. At each step, beam search selects the $W$ nodes with the largest $Pr(s, t)$ and expands them; $W$ is called the beam width.

In the example below we take $W=2$. At $T=0$ the label is empty.

At $T=1$, the probabilities of the labels $"A"$, $"B"$ and $"blank"$ are:

$$Pr("A", 1) = 0.2,\qquad Pr("B", 1) = 0.3,\qquad Pr("blank", 1) = 0.5$$

Labels $"B"$ and $"blank"$ have the highest probabilities, so the next step expands from $"B"$ and $"blank"$.

At $T=2$, the probability of the label $"BB"$ is:

$$Pr^{-}("BB", 2) = 0,\qquad Pr^{+}("BB", 2) = Pr("B", 1) \times P('B', 2) = 0.09$$

$$Pr("BB", 2) = Pr^{-}("BB", 2) + Pr^{+}("BB", 2) = 0.09$$
The probabilities of the labels $"BA"$ and $"blank"$ are computed in the same way.

However, when computing the probability of the label $"A"$ at $T=2$, both $'-'$ and $'A'$ may be emitted at the last step, so $Pr^{-}("A", 2)$ is nonzero, and the probability of the label $"A"$ is:

$$Pr^{-}("A", 2) = Pr("A", 1) \times P('-', 2) = 0.08,\qquad Pr^{+}("A", 2) = Pr("blank", 1) \times P('A', 2) = 0.15$$

$$Pr("A", 2) = Pr^{+}("A", 2) + Pr^{-}("A", 2) = 0.23$$

In the same way, $Pr^{-}("B", 2) = Pr("B", 1) \times P('-', 2) = 0.12$ and $Pr^{+}("B", 2) = Pr("blank", 1) \times P('B', 2) = 0.15$, giving $Pr("B", 2) = 0.27$.

标签 " A " "A" "A" " B " "B" "B"的概率最高,以标签’ ′ A ′ ′ 'A'' A " B " "B" "B"进行下一步扩展。

Applying the same formulas at $T=3$, the label $"A"$ attains the highest probability, $0.198$, so $"A"$ is the final output label.

*(figure: the beam search expansion tree for this example)*
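The walkthrough above can be sketched as a small prefix search. A minimal sketch, assuming the probability table reconstructed from the example's arithmetic and the example's convention that every non-blank emission extends the label (so $"BB"$ is a distinct label from $"B"$); the beam is set to $W = 3$ here so that all three $T=1$ prefixes survive pruning, which reproduces $Pr("A", 3) = 0.198$:

```python
# Per-timestep character probabilities P(c, t); '-' is the blank.
# Values are assumptions reconstructed from the worked example.
P = {
    '-': [0.5, 0.4, 0.6],
    'A': [0.2, 0.3, 0.3],
    'B': [0.3, 0.3, 0.1],
}

def beam_search(P, T, W):
    # beam maps a label s to the pair (Pr^-(s, t), Pr^+(s, t)).
    beam = {'': (1.0, 0.0)}  # at T = 0 the label is empty
    for t in range(T):
        nxt = {}
        for s, (pb, pnb) in beam.items():
            pr = pb + pnb  # Pr(s, t-1)
            # keep the label unchanged by emitting a blank
            b, nb = nxt.get(s, (0.0, 0.0))
            nxt[s] = (b + pr * P['-'][t], nb)
            # extend the label with a non-blank character
            # (as in the example, repeats become new labels)
            for c in ('A', 'B'):
                b, nb = nxt.get(s + c, (0.0, 0.0))
                nxt[s + c] = (b, nb + pr * P[c][t])
        # prune to the W most probable labels (the beam width)
        beam = dict(sorted(nxt.items(), key=lambda kv: -sum(kv[1]))[:W])
    return max(beam.items(), key=lambda kv: sum(kv[1]))

label, (pb, pnb) = beam_search(P, T=3, W=3)
print(label, round(pb + pnb, 3))  # A 0.198
```

Shrinking $W$ trades accuracy for speed: with $W = 1$ only the top prefix survives each step, and the search returns the all-blank label with probability 0.12, exactly the greedy result.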

When $W = 1$, beam search degenerates to greedy decoding.

TensorFlow Functions

```python
tf.nn.ctc_greedy_decoder(
    inputs,
    sequence_length,
    merge_repeated=True
)
```

```python
tf.nn.ctc_beam_search_decoder(
    inputs,
    sequence_length,
    beam_width=100,
    top_paths=1,
    merge_repeated=True
)
```

References

  1. A. Graves, *Supervised Sequence Labelling with Recurrent Neural Networks*
  2. https://xiaodu.io/ctc-explained-part2/