ctc loss and decoder
想要详细的了解ctc loss,建议直接转至大佬的博客 $\rightarrow$ https://xiaodu.io/ctc-explained/. 我个人写这个笔记的目的在于最近要对ctc loss进行魔改时,发现之前的细节都忘了。所以按照自己已有的基础上重新整理了一遍。除此之外,大佬的博客里面并没有代码解析。这里有ctc loss 和 ctc decode的python代码实现,所以想要对ctc loss进行魔改的,可以再过一遍我这篇文章~
Why ctc loss, ctc loss vs cross entropy
现实中有很多任务都可以看作是序列到序列的对齐训练。主要可以分为两类:
NLP领域常见的机器翻译和对话。对于这类任务,在训练阶段,我们通常使用cross-entropy + teacher forcing来训练模型。这类任务的特点是源序列和目标序列没有严格的对齐关系。他们本质上可以看作是 conditional language model. 也就是目标序列作为条件语言模型,更看重连贯性、流利度,其次是与源序列的对应关系(所以他们会有多样性研究)。
识别领域常见的语音识别,OCR,手语识别。对于这类任务,我们则主要使用ctc loss作为损失函数来训练模型。这类任务的特点是源序列和目标序列有严格的对齐关系。对于语音或手语,目标序列有语言模型的特点,但是更看重与源序列的准确的对应关系。
第二类任务其实也可以用cross entropy来训练,但是往往效果不太好。我们可以发现,对于第二类任务,最理想的情况是将源序列先进行分割,这样单独的对某一个音节,手语或者字符进行识别,准确率就会很高了。但是现实是,源序列更多的是未分割的情况。针对这类任务,[Alex Graves, 2006] 提出了Connectionist Temporal Classification.
使用ctc进行训练有两个要求:
- 源序列长度 >> 目标序列长度
- 源序列的order与目标序列的order一致,且存在顺序对齐的关系
ctc training
如何计算 ctc loss, 这篇博客CTC Algorithm Explained Part 1写的非常非常赞(以下简称为ctc explain blog),细节看这个就好了。
这里简单的概括下思想。 给定源序列 $X={x_1,x_2,..,x_T}$ 和目标序列 $Y={y_1,..y_N}$ . 其中 $T>>N$.
根据极大似然估计原理,我们的目标是找到使得 P(Y|X;W) 最大化的W。这里 X,Y 存在多对一的关系。
我们假设存在这样的路径 $\pi={\pi_1, …,\pi_T}$, $\pi_i\in |V’=V+blank|$ 与源序列一一对应(V是词表)。并且存在这样的映射关系 $\beta(\pi) = Y$ , 其映射法则就是去重和去掉blank(这个映射法则有其物理意义:就是一个音节或手语动作会包含多帧,以及存在中间停顿或无意义帧等情况.)
因此,优化的目标模型可以转换为:
$P(Y|X)=\sum_{\pi\in \beta^{-1}(Y)}P(\pi|X;W)$
所以我们的目标现在转换成了最大化满足 $\pi\in \beta^{-1}(Y)$ 的所有路径的概率。现在问题就转变成如何找到 Y 对应的所有路径。这是一个动态规划的问题。
Graves 根据HMM的前向后向算法,利用动态规划的方法来求解。根据目标序列Y和X的帧数构建一个表格:
纵轴是将目标序列扩展为前后都有blank的序列 $l’=(-, l_1, -, l_2,…,- ,l_N, -)$ 。如果 T=N 时,那么Y和X就是一一对应了。X的序列越长,这个表格的搜索空间越大,存在的可能的路径就越多。
如何找到所有的合法路径,先定义路径规则,然后找到递归条件便能通过动态规划的方法解决,具体细节参见ctc explain blog.
路径规则:
- 转换只能往右下方向,其他方向不允许
- 相同的字符之间至少有一个空格
- 非空字符不能被跳过(不然最终就不是apple了
- 起点必须从前两个字符开始
- 重点必须落在结尾两个字符
这里以前向算法为例来解释:
其中的符号:x表示输入序列,z表示输出序列,s表示纵轴的节点(2T+1个)
初始条件, t=1时刻只能是 blank 或 $l^{‘}_ {2}$
- $\alpha_1(1)=y_{-}^1$ 表示 t=1 时刻为blank的概率.
- $\alpha_1(2)=y_{l_2’}^1$ 表示 t=1 时刻为s中的第二个节点 s_2 的概率,也就是输出序列的第一个节点
- $\alpha_1(s)=0, \forall s>2$ 表示t=1时刻其他节点概率为0
- $\alpha_t(s)=0, \forall s < |l’|-2(T-t)-1$ 对于任何时刻都有部分节点是完全不可能的
t=T 时刻,只有最后两个节点可行。
t=0 时刻,对于节点 $s<|l’|-2T -1(|l’|=2N+1)$ . 其概率为 0.
如果输入序列的长度 T=N(与label等长),则 $|l’|=2T+1$ .
一般情况下 T>>N,s<0
0<t<T 时刻,以特例 T=N 为例, $s<2N+1-2N+2t -1 \rightarrow s<2t$ . 也就是 s<2t 的节点概率都为0.
- 前向递推公式,如果 t 时刻为 s 节点,那么 t-1 时刻可能的节点与 s 节点是否为 blank 有关。
- 如果s节点为blank. 不能跳过非空字符,所以 $\alpha_t$ 仅依赖 $\alpha_{t-1}(s)$, $\alpha_{t-1}(s-1)$ , 不依赖于 $\alpha_{t-1}(s-2)$
- 如果 s = s-2. 相同字符之间必须有空格。公式同上。
- 不属于上述两种情况, $\alpha_t$ 也能依赖于 $\alpha_{t-1}(s-2)$
最终通过公式计算loss, 通过迭代的方法计算T时刻最后两个节点的概率:
$-ln(p(l|x)) = -ln(\alpha_T(|l’|) + \alpha_T(|l’|-1))$
1 | 代码解析 |
References
- Sequence Modeling With CTC, Awni Hannun
CTC算法详解之训练篇 - CTC Algorithm Explained Part 2:Decoding the Network(CTC算法详解之解码篇
ctc loss and decoder