700字范文,内容丰富有趣,生活中的好帮手!
700字范文 > CTC束搜索解码原理和Pytorch实现(CTC Prefix BeamSearch Decode)

CTC束搜索解码原理和Pytorch实现(CTC Prefix BeamSearch Decode)

时间:2019-01-25 22:01:31

相关推荐

CTC束搜索解码原理和Pytorch实现(CTC Prefix BeamSearch Decode)

CTC解码在推断时,同一个标签序列对应的原生序列的结尾会有两种情况:1.以字符结尾;2.以blank结尾。不同的结尾往下增长时的缩放策略不同,比如以字符结尾:*a遇到a会缩放为*a;以blank(用“-”表示)结尾:*a-遇到a会被缩放为*aa。所以在增长过程的每一步,标签序列的概率都会使用两个变量存储,一个负责累加以字符结尾的原生序列概率,另一个负责累加以blank结尾的原生序列概率,两者相互独立,互无交集。增长后,再将这两个概率相加(log_sum_exp)表示这一个标签序列的总概率。然后取top beam_size后再往下增长。

序列增长时会有四种情况:

原生序列结尾任意,当前值为blank, 标签序列不变, 更新以blank结尾的概率;原生序列结尾为blank,当前值为相同字符(指与目前标签序列的最后一个字符相同), 标签序列更新, 更新非blank概率;原生序列结尾为字符,当前值为相同字符, 标签序列不变, 更新非blank概率;原生序列结尾任意,当前值为不同字符, 标签序列更新, 更新非blank概率。

注:

1.原生序列是指未缩放的序列,如aa-bbc-,aabbcc 对应的标签序列都为abc。

2.这里的概率指得都是对数概率:lp=log(softmax(logits))。所以原生序列增长时,其概率lp用“+”更新,相当于概率积后取log。而原生序列和标签序列是多对一关系,同一个标签序列的概率用其对应的多个原生序列概率的log_sum_exp表示(log(exp(lp1)+exp(lp2),...exp(lpk)),相当于概率和后再规范为对数概率表示。

import mathdef log_sum_exp(lps):_inf = -float('inf')if all(lp == _inf for lp in lps):return _infmlp = max(lps)return mlp + math.log(sum(math.exp(lp - mlp) for lp in lps))def beam_search_ctc(probs,bms=10,blank=0):'''probs: 概率空间,shape为[sequence_len,vocab_size]的torch tensorbms: beam_sizeblank: blank index'''_inf = -float("inf")seqs =[((idx.item(),),(lp.item(),_inf)) if idx.item()!=blankelse (tuple(),(_inf,lp.item()))for lp,idx in zip(*probs[0].topk(bms))]for i in range(1,probs.size(0)):new_seqs = {}for seq,(lps,blps) in seqs: last = seq[-1] if len(seq) > 0 else Nonefor lp, idx in zip(*probs[i].topk(bms)):lp=lp.item()idx=idx.item() if idx == blank :nlps,nblps= new_seqs.get(seq,(_inf,_inf))new_seqs[seq]=(nlps,log_sum_exp([nblps,lps+lp,blps+lp]))elif idx ==last:#aanlps,nblps= new_seqs.get(seq,(_inf,_inf))new_seqs[seq]=(log_sum_exp([nlps,lps+lp]),nblps)#a-anew_seq = seq + (idx,)nlps,nblps= new_seqs.get(new_seq,(_inf,_inf))new_seqs[new_seq]=(log_sum_exp([nlps,blps+lp]),nblps)else:new_seq = seq + (idx,)nlps,nblps= new_seqs.get(new_seq,(_inf,_inf))new_seqs[new_seq]=(log_sum_exp([nlps,lps+lp,blps+lp]),nblps)new_seqs = sorted(new_seqs.items(),key=lambda x: log_sum_exp(list(x[1])),reverse=True)seqs = new_seqs[:bms]return seqs

本内容不代表本网观点和政治立场,如有侵犯你的权益请联系我们处理。
网友评论
网友评论仅供其表达个人看法,并不表明网站立场。