ctc loss and decoder

想要详细的了解ctc loss,建议直接转至大佬的博客 $\rightarrow$ https://xiaodu.io/ctc-explained/. 我个人写这个笔记的目的在于最近要对ctc loss进行魔改时,发现之前的细节都忘了。所以按照自己已有的基础上重新整理了一遍。除此之外,大佬的博客里面并没有代码解析。这里有ctc loss 和 ctc decode的python代码实现,所以想要对ctc loss进行魔改的,可以再过一遍我这篇文章~
Why ctc loss, ctc loss vs cross entropy

现实中有很多任务都可以看作是序列到序列的对齐训练。主要可以分为两类:

  1. NLP领域常见的机器翻译和对话。对于这类任务,在训练阶段,我们通常使用cross-entropy + teacher forcing来训练模型。这类任务的特点是源序列和目标序列没有严格的对齐关系。他们本质上可以看作是 conditional language model. 也就是目标序列作为条件语言模型,更看重连贯性、流利度,其次是与源序列的对应关系(所以他们会有多样性研究)。

  2. 识别领域常见的语音识别,OCR,手语识别。对于这类任务,我们则主要使用ctc loss作为损失函数来训练模型。这类任务的特点是源序列和目标序列有严格的对齐关系。对于语音或手语,目标序列有语言模型的特点,但是更看重与源序列的准确的对应关系。
    第二类任务其实也可以用cross entropy来训练,但是往往效果不太好。我们可以发现,对于第二类任务,最理想的情况是将源序列先进行分割,这样单独的对某一个音节,手语或者字符进行识别,准确率就会很高了。但是现实是,源序列更多的是未分割的情况。针对这类任务,[Alex Graves, 2006] 提出了Connectionist Temporal Classification.

使用ctc进行训练有两个要求:

  1. 源序列长度 >> 目标序列长度
  2. 源序列的order与目标序列的order一致,且存在顺序对齐的关系

ctc training

如何计算 ctc loss, 这篇博客CTC Algorithm Explained Part 1写的非常非常赞(以下简称为ctc explain blog),细节看这个就好了。

这里简单的概括下思想。 给定源序列 $X={x_1,x_2,..,x_T}$ 和目标序列 $Y={y_1,..y_N}$ . 其中 $T>>N$.
根据极大似然估计原理,我们的目标是找到使得 P(Y|X;W) 最大化的W。这里 X,Y 存在多对一的关系。
我们假设存在这样的路径 $\pi={\pi_1, …,\pi_T}$, $\pi_i\in |V’=V+blank|$ 与源序列一一对应(V是词表)。并且存在这样的映射关系 $\beta(\pi) = Y$ , 其映射法则就是去重和去掉blank(这个映射法则有其物理意义:就是一个音节或手语动作会包含多帧,以及存在中间停顿或无意义帧等情况.)

因此,优化的目标模型可以转换为:
$P(Y|X)=\sum_{\pi\in \beta^{-1}(Y)}P(\pi|X;W)$

所以我们的目标现在转换成了最大化满足 $\pi\in \beta^{-1}(Y)$ 的所有路径的概率。现在问题就转变成如何找到 Y 对应的所有路径。这是一个动态规划的问题。
Graves 根据HMM的前向后向算法,利用动态规划的方法来求解。根据目标序列Y和X的帧数构建一个表格:

纵轴是将目标序列扩展为前后都有blank的序列 $l’=(-, l_1, -, l_2,…,- ,l_N, -)$ 。如果 T=N 时,那么Y和X就是一一对应了。X的序列越长,这个表格的搜索空间越大,存在的可能的路径就越多。

如何找到所有的合法路径,先定义路径规则,然后找到递归条件便能通过动态规划的方法解决,具体细节参见ctc explain blog.

路径规则:

  • 转换只能往右下方向,其他方向不允许
  • 相同的字符之间至少有一个空格
  • 非空字符不能被跳过(不然最终就不是apple了
  • 起点必须从前两个字符开始
  • 重点必须落在结尾两个字符

这里以前向算法为例来解释:

其中的符号:x表示输入序列,z表示输出序列,s表示纵轴的节点(2T+1个)
初始条件, t=1时刻只能是 blank 或 $l^{‘}_ {2}$

  • $\alpha_1(1)=y_{-}^1$ 表示 t=1 时刻为blank的概率.
  • $\alpha_1(2)=y_{l_2’}^1$ 表示 t=1 时刻为s中的第二个节点 s_2 的概率,也就是输出序列的第一个节点
  • $\alpha_1(s)=0, \forall s>2$ 表示t=1时刻其他节点概率为0
  1. $\alpha_t(s)=0, \forall s < |l’|-2(T-t)-1$ 对于任何时刻都有部分节点是完全不可能的
  • t=T 时刻,只有最后两个节点可行。

  • t=0 时刻,对于节点 $s<|l’|-2T -1(|l’|=2N+1)$ . 其概率为 0.

    • 如果输入序列的长度 T=N(与label等长),则 $|l’|=2T+1$ .

    • 一般情况下 T>>N,s<0

  • 0<t<T 时刻,以特例 T=N 为例, $s<2N+1-2N+2t -1 \rightarrow s<2t$ . 也就是 s<2t 的节点概率都为0.

  1. 前向递推公式,如果 t 时刻为 s 节点,那么 t-1 时刻可能的节点与 s 节点是否为 blank 有关。
  • 如果s节点为blank. 不能跳过非空字符,所以 $\alpha_t$ 仅依赖 $\alpha_{t-1}(s)$, $\alpha_{t-1}(s-1)$ , 不依赖于 $\alpha_{t-1}(s-2)$

  • 如果 s = s-2. 相同字符之间必须有空格。公式同上。

  • 不属于上述两种情况, $\alpha_t$ 也能依赖于 $\alpha_{t-1}(s-2)$

最终通过公式计算loss, 通过迭代的方法计算T时刻最后两个节点的概率:
$-ln(p(l|x)) = -ln(\alpha_T(|l’|) + \alpha_T(|l’|-1))$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
代码解析
import numpy as np
from six.moves import xrange


import numpy as np
import editDistance as ed
import heapq as hq
from six.moves import xrange


def ctc_loss(params, seq, blank=0, is_prob=True):
"""
params: [vocab_size, T], logits.softmax(-1). T 是输入序列的长度,vocab_size是词表大小。
seq: [seq_len] 输出序列的长度。

CTC loss function.
params - n x m matrix of n-D probability distributions over m frames.
seq - sequence of phone id's for given example.
is_prob - whether params have already passed through a softmax
Returns objective and gradient.
"""
seqLen = seq.shape[0] # Length of label sequence (# phones)
numphones = params.shape[0] # Number of labels
L = 2 * seqLen + 1 # Length of label sequence with blanks, 拓展后的 l'.
T = params.shape[1] # Length of utterance (time)

# 建立表格 l' x T.
alphas = np.zeros((L, T)) # 前向概率
betas = np.zeros((L, T)) # 后向概率

# 这里dp的map:
# 横轴为 2*seq_len+1, 也就是 ground truth label中每个token前后插入 blank
# 纵轴是 T frames

# logits 转换为概率
if not is_prob:
# if not probs, params is logits without softmax.
params = params - np.max(params, axis=0)
params = np.exp(params)
params = params / np.sum(params, axis=0)

# Initialize alphas and forward pass

# 初始条件:T=0时,只能为 blank 或 seq[0]
alphas[0, 0] = params[blank, 0]
alphas[1, 0] = params[seq[0], 0]
# T=0, alpha[:, 0] 其他的全部为 0


c = np.sum(alphas[:, 0])
alphas[:, 0] = alphas[:, 0] / c # 这里 T=0 时刻所有可能节点的概率要归一化

llForward = np.log(c) # 转换为log域

for t in xrange(1, T):
# 第一个循环: 计算每个时刻所有可能节点的概率和
start = max(0, L - 2 * (T - t)) # 对于时刻 t, 其可能的节点.与公式2一致。
end = min(2 * t + 2, L) # 对于时刻 t,最大节点范围不可能超过 2t+2
for s in xrange(start, L):
l = (s - 1) / 2
# blank,节点s在偶数位置,意味着s为blank
if s % 2 == 0:
if s == 0: # 初始位置,单独讨论
alphas[s, t] = alphas[s, t - 1] * params[blank, t]
else:
alphas[s, t] = (alphas[s, t - 1] + alphas[s - 1, t - 1]) * params[blank, t]
# s为奇数,非空
# l = (s-1/2) 就是 s 所对应的 lable 中的字符。
# ((s-2)-1)/2 = (s-1)/2-1 = l-1 就是 s-2 对应的lable中的字符
elif s == 1 or seq[l] == seq[l - 1]:
alphas[s, t] = (alphas[s, t - 1] + alphas[s - 1, t - 1]) * params[seq[l], t]
else:
alphas[s, t] = (alphas[s, t - 1] + alphas[s - 1, t - 1] + alphas[s - 2, t - 1]) \
* params[seq[l], t]

# normalize at current time (prevent underflow)
c = np.sum(alphas[start:end, t])
alphas[start:end, t] = alphas[start:end, t] / c
llForward += np.log(c)
return llForward
ctc_beam_search 解码
def ctc_beam_search_decode(probs, beam_size=5, blank=0):
"""
:param probs: The output probabilities (e.g. post-softmax) for each
time step. Should be an array of shape (time x output dim).
:param beam:
:param blank:
:return:
"""
# T表示时间,S表示词表大小
T, S = probs.shape

# 求概率的对数
probs = np.log(probs)

# Elements in the beam are (prefix, (p_blank, p_no_blank))
# Initialize the beam with the empty sequence, a probability of
# 1 for ending in blank and zero for ending in non-blank
# (in log space).
# 每次总是保留beam_size条路径
beam = [(tuple(), ((0.0, NEG_INF), tuple()))]

for t in range(T): # Loop over time
# A default dictionary to store the next step candidates.
next_beam = make_new_beam()

for s in range(S): # Loop over vocab
# print(s)
p = probs[t, s] # t时刻,符号为s的概率

# The variables p_b and p_nb are respectively the
# probabilities for the prefix given that it ends in a
# blank and does not end in a blank at this time step.
for prefix, ((p_b, p_nb), prefix_p) in beam: # Loop over beam
# p_b表示前缀最后一个是blank的概率,p_nb是前缀最后一个非blank的概率
# If we propose a blank the prefix doesn't change.
# Only the probability of ending in blank gets updated.

if s == blank:
# 增加的字母是blank
# 先取出对应prefix的两个概率,然后更后缀为blank的概率n_p_b
(n_p_b, n_p_nb), _ = next_beam[prefix] # -inf, -inf
n_p_b = logsumexp(n_p_b, p_b + p, p_nb + p) # 更新后缀为blank的概率
next_beam[prefix] = ((n_p_b, n_p_nb), prefix_p) # s=blank, prefix不更新,因为blank要去掉的。
# print(next_beam[prefix])
continue

# Extend the prefix by the new character s and add it to
# the beam. Only the probability of not ending in blank
# gets updated.
end_t = prefix[-1] if prefix else None
n_prefix = prefix + (s,) # 更新 prefix, 它是一个tuple
n_prefix_p = prefix_p + (p,)
# 先取出对应 n_prefix 的两个概率, 这个是更新了blank概率之后的 new 概率
(n_p_b, n_p_nb), _ = next_beam[n_prefix] # -inf, -inf

if s != end_t:
# 如果s不和上一个不重复,则更新非空格的概率
n_p_nb = logsumexp(n_p_nb, p_b + p, p_nb + p)
else:
# 如果s和上一个重复,也要更新非空格的概率
# We don't include the previous probability of not ending
# in blank (p_nb) if s is repeated at the end. The CTC
# algorithm merges characters not separated by a blank.
n_p_nb = logsumexp(n_p_nb, p_b + p)

# If s is repeated at the end we also update the unchanged
# prefix. This is the merging case.
if s == end_t:
(n_p_b, n_p_nb), n_prefix_p = next_beam[prefix]
n_p_nb = logsumexp(n_p_nb, p_nb + p)
# 如果是s=end_t,则prefix不更新
next_beam[prefix] = ((n_p_b, n_p_nb), n_prefix_p)
else:
# *NB* this would be a good place to include an LM score.
next_beam[n_prefix] = ((n_p_b, n_p_nb), n_prefix_p)
# print(t, next_beam.keys())
# Sort and trim the beam before moving on to the
# next time-step.
# 根据概率进行排序,每次保留概率最高的beam_size条路径
beam = sorted(next_beam.items(),
key=lambda x: logsumexp(*x[1][0]),
reverse=True)
beam = beam[:beam_size]

# best = beam[0]
# return best[0], -logsumexp(*best[1][0]), best[1][1]

pred_lens = [len(beam[i][0]) for i in range(beam_size)]
max_len = max(pred_lens)
pred_seq, scores, pred_pobs = np.zeros((beam_size, max_len), dtype=np.int32), \
[], np.zeros((beam_size, max_len))
for bs in range(beam_size):
pred_seq[bs][:pred_lens[bs]] = beam[bs][0]
scores.append(-logsumexp(*beam[bs][1][0]))
pred_pobs[bs][:pred_lens[bs]] = np.exp(beam[bs][1][1])
return pred_seq, scores, pred_pobs


# 因为代码中为了避免数据下溢,都采用的是对数概率,所以看起来比较繁琐
def logsumexp(*args):
"""
Stable log sum exp.
"""
if all(a == NEG_INF for a in args):
return NEG_INF
a_max = max(args)
lsp = math.log(sum(math.exp(a - a_max) for a in args)) # 概率相加再取log,为避免数值下溢
return a_max + lsp


# 创建一个新的beam
def make_new_beam():
fn = lambda: ((NEG_INF, NEG_INF), tuple())
return collections.defaultdict(fn)

if __name__ == "__main__":
import ctcdecode, time

np.random.seed(3)

seq_len = 50
output_dim = 20

probs = np.random.rand(seq_len, output_dim)
# probs = np.random.rand(time, output_dim)
# probs = np.random.rand(time, output_dim)
probs = probs / np.sum(probs, axis=1, keepdims=True)

start_time = time.time()
labels, score, labels_p = MPGenerate.ctc_beam_search_decode(probs, beam_size=5, blank=0)
print("labels:", labels[0], len(labels[0]))
print("labels_p: ", labels_p[0], len(labels_p[0]))
print("Score {:.3f}".format(score[0]))
print("First method time: ", time.time() - start_time)

dec_logits = torch.FloatTensor(probs).unsqueeze(0)
len_video = torch.LongTensor([seq_len])
decoder_vocab = [chr(x) for x in range(20000, 20000 + output_dim)]

second_time = time.time()
decoder = ctcdecode.CTCBeamDecoder(decoder_vocab, beam_width=5, blank_id=0, num_processes=10)

pred_seq, scores, _, out_seq_len = decoder.decode(dec_logits, len_video)

# pred_seq: [batch, beam, length]
# out_seq_len: [batch, beam]
print(pred_seq[0, 0, :][:out_seq_len[0, 0]])
print(out_seq_len[0, 0])
print("Score {:.3f}".format(scores[0, 0]))
print("Second method time: ", time.time() - second_time)

References

作者

Xie Pan

发布于

2020-09-13

更新于

2021-07-01

许可协议

评论