首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >运行CTC损失函数

运行CTC损失函数
EN

Stack Overflow用户
提问于 2020-08-18 07:04:38
回答 1查看 510关注 0票数 0

我想尝试一下莎士比亚数据集上的CTC损失函数,在计算损失时,预测的张量形状是(64,100,65),这与标签形状(64,100 )不匹配。所以我用一些数学来转换维数,但有一个错误。

代码语言:javascript
复制
def loss(labels, logits):
  return tf.keras.losses.categorical_crossentropy(labels, logits)

example_batch_loss  = loss(labels=target_example_batch, logits=tf.math.argmax(tf.convert_to_tensor(example_batch_predictions), axis=-1, output_type=tf.int64))

误差

不能将Mul计算为输入#1(基于零的),它应该是一个int64张量,但它是一个双张量Op:Mul。

请帮助我找到一个解决办法,使用反恐委员会损失。

EN

回答 1

Stack Overflow用户

发布于 2020-08-18 08:13:09

您正在输入模型输出的are最大值,即输出值最高的索引。CTC损失(与大多数损失函数一样)适用于logits,即该模型产生的非归一化概率分布。因此,对形状(64,100,65)和目标(64,100)进行预测是没有错的。

但是,请注意,CTC旨在处理比目标更长的模型输出时的情况。典型的用例是语音识别,在这种情况下,有大量的信号窗口与相对较少的音素匹配。如果输出长度和目标长度相同,则CTC退化为标准交叉熵。

假设example_batch_predictions是您的模型输出,然后再由softmax规范它,那么您应该这样做:

代码语言:javascript
复制
example_batch_loss  = loss(labels=target_example_batch, logits=example_batch_predictions, axis=-1, output_type=tf.int64))
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63463448

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档