Two Decoding Methods for CTC
CTC loss is used during the training of image text recognition. At inference time, given an input $x$, we want the label $l$ that maximizes $p\left( l \mid x \right)$. In sequence learning, this problem is called decoding: finding, within limited time, the sequence $l^{*}$ with the largest conditional probability:

$$l^{*} = \arg\max_{l} p\left( l \mid x \right)$$
Assume the character list $\left( '-', 'A', 'B' \right)$, $T=3$ timesteps, and let $P\left( c, t \right)$ denote the probability that character $c$ appears at time $t$. With the timesteps on the horizontal axis and the character list on the vertical axis, the table below gives these probabilities (the values are recovered from the calculations later in this article); our goal is to search this two-dimensional space for the most probable label $l^{*}$.

| | $t=1$ | $t=2$ | $t=3$ |
| --- | --- | --- | --- |
| $'-'$ | 0.5 | 0.4 | 0.6 |
| $'A'$ | 0.2 | 0.3 | 0.3 |
| $'B'$ | 0.3 | 0.3 | 0.1 |
Greedy Decode
The greedy strategy always takes the best choice at each step: at every timestep we pick the character with the highest probability, and the $T$ chosen characters are joined into a label. As shown in the figure below, $P\left( '-', 1 \right)$ is the largest at $T=1$, $P\left( '-', 2 \right)$ is the largest at $T=2$, and $P\left( '-', 3 \right)$ is the largest at $T=3$. The output label is therefore $"blank"$, with $p\left( l = "blank" \mid x \right) = 0.5 \times 0.4 \times 0.6 = 0.12$.
![greedy decoding](./greedy.png)
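A minimal sketch of greedy decoding over the example's per-timestep probabilities, in plain Python. The variable names are ours, and the $t=3$ values for $'A'$ and $'B'$ are inferred from the calculations later in the article.

```python
chars = ['-', 'A', 'B']
# P(c, t) for t = 1..3, taken from the article's example; the t=3
# values for 'A' and 'B' are inferred from the later calculations.
probs = [
    [0.5, 0.2, 0.3],   # t = 1
    [0.4, 0.3, 0.3],   # t = 2
    [0.6, 0.3, 0.1],   # t = 3
]

# Greedy decoding: take the most probable character at each timestep.
path, score = '', 1.0
for col in probs:
    i = max(range(len(chars)), key=lambda k: col[k])
    path += chars[i]
    score *= col[i]

# Collapse the path: merge repeats, then drop blanks.
label = ''
prev = None
for c in path:
    if c != prev and c != '-':
        label += c
    prev = c

print(path, repr(label), round(score, 2))   # --- '' 0.12
```

Every timestep picks $'-'$, so the collapsed output is the empty ("blank") label with probability 0.12, matching the figure.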
Greedy decoding considers only a single path. In the CTC algorithm, however, we defined a many-to-one mapping $B$ that merges paths with the same output label. As shown below, when the output label is $"A"$, there are three paths, $"A--"$, $"-A-"$ and $"--A"$:

$$B\left( "A--" \right) = B\left( "-A-" \right) = B\left( "--A" \right) = "A"$$

so $p\left( l = "A" \mid x \right)$ should be the sum of the probabilities of these three paths.
$$p\left( l = "A" \mid x \right) = p\left( l = "A--" \mid x \right) + p\left( l = "-A-" \mid x \right) + p\left( l = "--A" \mid x \right) = 0.198$$
Since $p\left( l = "A" \mid x \right) = 0.198$ is far larger than $p\left( l = "blank" \mid x \right) = 0.12$, $"A"$ is the better output label.
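As a check, we can enumerate all $3^3$ paths and sum the probabilities of paths mapping to the same label. This sketch uses the simplified mapping of this example, which just drops blanks (the full CTC collapse rule would also merge repeated characters), so it reproduces exactly the three paths listed above for label $"A"$.

```python
import itertools

chars = ['-', 'A', 'B']
P = {                      # P(c, t) from the article's example
    1: {'-': 0.5, 'A': 0.2, 'B': 0.3},
    2: {'-': 0.4, 'A': 0.3, 'B': 0.3},
    3: {'-': 0.6, 'A': 0.3, 'B': 0.1},
}

# Sum path probabilities per output label; the mapping here simply
# drops blanks, matching the three paths listed for label "A".
label_prob = {}
for path in itertools.product(chars, repeat=3):
    p = 1.0
    for t, c in enumerate(path, start=1):
        p *= P[t][c]
    label = ''.join(c for c in path if c != '-')
    label_prob[label] = label_prob.get(label, 0.0) + p

print(round(label_prob['A'], 3))   # 0.198
print(round(label_prob[''], 3))    # 0.12 (the "blank" label)
```

The per-label probabilities sum to 1 over all 27 paths, and label $"A"$ indeed collects $0.048 + 0.09 + 0.06 = 0.198$.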
Beam Search
Let $Pr\left( s, t \right)$ be the probability that the network output up to time $t$ corresponds to label $s$, let $Pr^{-}\left( s, t \right)$ be the probability of $s$ with a blank emitted at time $t$, and let $Pr^{+}\left( s, t \right)$ be the probability of $s$ with a non-blank character emitted at time $t$. Then $Pr\left( s, t \right) = Pr^{-}\left( s, t \right) + Pr^{+}\left( s, t \right)$. At each step, beam search expands only the $W$ nodes with the largest $Pr\left( s, t \right)$; $W$ is called the beam width.
In the example below we take $W=2$. At $T=0$ the label is empty.
At $T=1$, the probabilities of labels $"A"$, $"B"$ and $"blank"$ are:

$$Pr\left( "A", 1 \right) = 0.2 \\ Pr\left( "B", 1 \right) = 0.3 \\ Pr\left( "blank", 1 \right) = 0.5$$
Labels $"B"$ and $"blank"$ have the highest probabilities, so they are expanded at the next step.
At $T=2$, the probability of label $"BB"$ is:

$$Pr^{-}\left( "BB", 2 \right) = 0 \\ Pr^{+}\left( "BB", 2 \right) = Pr\left( "B", 1 \right) \times P\left( 'B', 2 \right) = 0.09$$

$$Pr\left( "BB", 2 \right) = Pr^{-}\left( "BB", 2 \right) + Pr^{+}\left( "BB", 2 \right) = 0.09$$

Here $Pr^{-}\left( "BB", 2 \right) = 0$ because two timesteps cannot emit both characters of $"BB"$ and end in a blank.
The probabilities of labels $"BA"$ and $"blank"$ can be computed in the same way.
However, when computing the probability of label $"A"$ at $T=2$, both $'-'$ and $'A'$ may appear at time 2, so $Pr^{-}\left( "A", 2 \right)$ is nonzero, and the probability of label $"A"$ is computed as:

$$Pr^{-}\left( "A", 2 \right) = Pr\left( "A", 1 \right) \times P\left( '-', 2 \right) = 0.08 \\ Pr^{+}\left( "A", 2 \right) = Pr\left( "blank", 1 \right) \times P\left( 'A', 2 \right) = 0.15$$
Thus:

$$Pr\left( "A", 2 \right) = Pr^{+}\left( "A", 2 \right) + Pr^{-}\left( "A", 2 \right) = 0.23$$
Similarly, $Pr\left( "B", 2 \right) = 0.27$.
Labels $"A"$ and $"B"$ now have the highest probabilities, so $"A"$ and $"B"$ are expanded at the next step.
Applying the same formulas at $T=3$, label $"A"$ has the highest probability, $0.198$, so $"A"$ becomes the output label.
When $W=1$, beam search reduces to greedy decoding.
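The whole walkthrough can be condensed into a small prefix-style beam search. This is a minimal sketch under the simplified mapping of this example (blanks are dropped and repeated characters are kept as separate labels, so the numbers match the text); `beam_search` and its `W` parameter are names chosen here, not a library API.

```python
chars = ['A', 'B']             # non-blank characters; '-' is the blank
T = 3
P = {                          # P(c, t) from the article's example
    1: {'-': 0.5, 'A': 0.2, 'B': 0.3},
    2: {'-': 0.4, 'A': 0.3, 'B': 0.3},
    3: {'-': 0.6, 'A': 0.3, 'B': 0.1},
}

def beam_search(W):
    beams = {'': 1.0}          # label prefix s -> Pr(s, t)
    for t in range(1, T + 1):
        new = {}
        for s, pr in beams.items():
            # Pr^-(s, t): a blank is emitted, the label is unchanged.
            new[s] = new.get(s, 0.0) + pr * P[t]['-']
            # Pr^+(s + c, t): a non-blank is emitted, the label grows.
            for c in chars:
                new[s + c] = new.get(s + c, 0.0) + pr * P[t][c]
        # Keep only the W most probable prefixes.
        beams = dict(sorted(new.items(), key=lambda kv: -kv[1])[:W])
    return max(beams.items(), key=lambda kv: kv[1])

best, p = beam_search(9)       # W large enough that pruning drops nothing
print(best, round(p, 3))       # A 0.198

best1, p1 = beam_search(1)     # W = 1 matches greedy decoding here
print(repr(best1), round(p1, 3))   # '' 0.12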
TensorFlow functions
tf.nn.ctc_greedy_decoder(
inputs,
sequence_length,
merge_repeated=True
)
tf.nn.ctc_beam_search_decoder(
inputs,
sequence_length,
beam_width=100,
top_paths=1,
merge_repeated=True
)
References
- Alex Graves, Supervised Sequence Labelling with Recurrent Neural Networks
- https://xiaodu.io/ctc-explained-part2/