首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用tf.nn.ctc_loss计算全空白序列的CTC损失?

如何使用tf.nn.ctc_loss计算全空白序列的CTC损失?
EN

Stack Overflow用户
提问于 2017-10-10 02:22:38
回答 1查看 1.2K关注 0票数 5

手工计算具有所有空白的序列的CTC损失是简单的。但是,我找不到一种使用tf.nn.ctc_loss API来完成此任务的方法。我是否遗漏了什么,或者tf.nn.ctc_loss实现缺少此功能?当批处理中的少数序列没有输出符号时,此功能是必需的。

I reported this on github,它是关闭的,没有应答。

环境: tf版本1.3,CPU版本;python 3.5/3.6;Win10/Ubuntu 16.04。

首先,我们从代码开始:

代码语言:javascript
复制
import tensorflow as tf
num_classes, batch_size, seq_len = 3, 1, 2
labels = tf.SparseTensor(indices=[[0,0]], values=[0], dense_shape=[1,1])
inputs = tf.zeros([seq_len, batch_size, num_classes])
loss = tf.nn.ctc_loss(labels, inputs, [seq_len])
print(tf.InteractiveSession().run(loss))

tf.nn.ctc_loss按照预期运行,并打印正确答案: 1.09861231

问题一:

如何计算全空序列的ctc损失?tf.nn.ctc_loss应用程序接口要求值< num_labels,所以我们没有办法实现它?如果我将上面示例中的值更改为num_classes -1(保留的空白ID),tf.nn.ctc_loss没有任何抱怨,并返回错误的答案: 0.81093025!正确答案是2*log(3)。重现第一个问题的代码如下:

代码语言:javascript
复制
import tensorflow as tf
num_classes, batch_size, seq_len = 3, 1, 2
labels = tf.SparseTensor(indices=[[0,0]], values=[2], dense_shape=[1,1])
inputs = tf.zeros([seq_len, batch_size, num_classes])
loss = tf.nn.ctc_loss(labels, inputs, [seq_len])
print(tf.InteractiveSession().run(loss))

问题二:

让我们将序列长度改为1,如下所示

代码语言:javascript
复制
import tensorflow as tf
num_classes, batch_size, seq_len = 3, 1, 1
labels = tf.SparseTensor(indices=[[0,0]], values=[2], dense_shape=[1,1])
inputs = tf.zeros([seq_len, batch_size, num_classes])
loss = tf.nn.ctc_loss(labels, inputs, [seq_len])
print(tf.InteractiveSession().run(loss))

然后再次运行代码。这段代码在Ubuntu中给出了正确的答案log(3),但在Win10中崩溃,并显示消息:内核已死,正在重新启动。

EN

回答 1

Stack Overflow用户

发布于 2019-10-03 22:06:22

我不知道TF v1.3,但在TF v2.0中,包含所有空白的序列的CTC损失可以通过使用密集张量而不是稀疏张量来计算。在TFv1.x中,有ctc_loss_v2,我相信它具有相同的行为。

代码语言:javascript
复制
y_true = np.array([[0, 0, 0, 0]])
y_pred = np.zeros((1, 2, 3), np.float32)
input_length = np.array([2])
label_length = np.array([0])
cost = tf.nn.ctc_loss(y_true, y_pred, label_length, input_length)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/46652720

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档