首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow CTC损失: ctc_merge_repeated参数

Tensorflow CTC损失: ctc_merge_repeated参数
EN

Stack Overflow用户
提问于 2017-08-08 12:22:44
回答 1查看 1.3K关注 0票数 1

我使用的是Tensorflow 1.0和它的CTC丢失1。在训练时,我有时会得到“没有找到有效的路径”。警告(对学习有害)。它是,而不是,因为其他Tensorflow用户有时会报告它的高学习率。

在对其进行了一些分析之后,我发现了导致此警告的模式:

  • 将输入序列输入到长度为ctc_loss的seqLen中
  • 输入带有labelLen字符的标签
  • 标签中有numRepeatedChars重复的字符,其中我计算"ab“为0,"aa”为1,"aaa“为2,等等。
  • 警告发生时: seqLen - labelLen < numRepeatedChars

三个例子:

  • Ex.1: label="abb",len(标签)=3,len(inputSequence)=3 => (3-3=0)<1是真实的->警告
  • Ex.2: label="abb",len(标签)=3,len(inputSequence)=4 => (4-3=1)<1是假的->没有警告
  • Ex.3: label="bbb",len(标签)=3,len(inputSequence)=4 => (4-3=1)<2是真实的->警告

当我现在设置ctc_loss参数ctc_merge_repeated=False时,警告将消失。

三个问题:

  • Q1:为什么在出现重复字符时会有警告?我想,只要输入序列不短于目标标签,就没有问题。当重复的字符被合并到标签中时,它就会变得更短,因此输入序列不短的条件仍然有效。
  • Q2:为什么默认设置下的ctc_loss会产生这个警告?重复字符在使用CTCs的领域中很常见,例如手写文本识别(HTR)。
  • Q3:我在做HTR时应该使用哪些设置?当然,标签可以有重复的字符。因此,ctc_merge_repeated=False是有意义的。有什么建议吗?

用于复制警告的Python程序:

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

def createGraph():
    tinputs=tf.placeholder(tf.float32, [100, 1, 65]) # max 100 time steps, 1 batch element, 64+1 classes
    tlabels=tf.SparseTensor(tf.placeholder(tf.int64, shape=[None,2]) , tf.placeholder(tf.int32,[None]), tf.placeholder(tf.int64,[2])) # labels
    tseqLen=tf.placeholder(tf.int32, [None]) # list of sequence length in batch
    tloss=tf.reduce_mean(tf.nn.ctc_loss(labels=tlabels, inputs=tinputs, sequence_length=tseqLen, ctc_merge_repeated=True)) # ctc loss
    return (tinputs, tlabels, tseqLen, tloss)

def getNextBatch(nc): # next batch with given number of chars in label
    indices=[[0,i] for i in range(nc)]
    values=[i%65 for i in range(nc)]
    values[0]=0
    values[1]=0 # TODO: (un)comment this to trigger warning
    shape=[1, nc]
    labels=tf.SparseTensorValue(indices, values, shape)
    seqLen=[nc]
    inputs=np.random.rand(100, 1, 65)
    return (labels, inputs, seqLen) 


(tinputs, tlabels, tseqLen, tloss)=createGraph()

sess=tf.Session()
sess.run(tf.global_variables_initializer())

nc=3 # number of chars in label
print('next batch with 1 element has label len='+str(nc))
(labels, inputs, seqLen)=getNextBatch(nc)
res=sess.run([tloss], { tlabels: labels, tinputs:inputs, tseqLen:seqLen } )

这是来自C++ Tensorflow代码2的警告:

代码语言:javascript
复制
// It is possible that no valid path is found if the activations for the
// targets are zero.
if (log_p_z_x == kLogZero) {
    LOG(WARNING) << "No valid path found.";
    dy_b = y;
    return;
}

1

2

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-10-02 23:29:17

好的,明白了,这不是一个bug,这就是CTC的工作方式:让我们举一个警告发生的例子:输入序列的长度是2,标签是"aa“(也就是长度2)。

现在产生"aa“的最短路径是a->空白->a(长度3)。但对于标签"ab",最短路径是a->b (长度2)。这说明了为什么像"aa“中的重复标签的输入序列必须更长。这只是通过插入空格在CTC中编码重复标签的方式。

因此,当确定输入大小时,标签重复次数会减少允许的标签的最大长度。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45568266

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档