首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >实现CTC损失核

实现CTC损失核
EN

Stack Overflow用户
提问于 2021-09-28 12:29:12
回答 1查看 51关注 0票数 0

考虑到您有一个类似于以下内容的基本模型:

代码语言:javascript
复制
input_layer = layers.Input(shape=(50,20))
layer = layers.Dense(123, activation = 'relu')
layer = layers.LSTM(128, return_sequences = True)(layer)
outputs = layers.Dense(20, activation='softmax')(layer)
model = Model(input_layer,outputs)

您将如何实现CTC损失?我在OCR上尝试了keras代码教程中的一些内容,如下所示:

代码语言:javascript
复制
class CTCLayer(layers.Layer):
    def __init__(self, name=None):
        super().__init__(name=name)
        self.loss_fn = keras.backend.ctc_batch_cost

    def call(self, y_true, y_pred):
        # Compute the training-time loss value and add it
        # to the layer using `self.add_loss()`.
        batch_len = tf.cast(tf.shape(y_true)[0], dtype="int64")
        input_length = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_length = tf.cast(tf.shape(y_true)[1], dtype="int64")

        input_length = input_length * tf.ones(shape=(batch_len, 1), dtype="int64")
        label_length = label_length * tf.ones(shape=(batch_len, 1), dtype="int64")

        loss = self.loss_fn(y_true, y_pred, input_length, label_length)
        self.add_loss(loss)

        # At test time, just return the computed predictions
        return y_pred
labels = layers.Input(shape=(None,), dtype="float32")
input_layer = layers.Input(shape=(50,20))
layer = layers.Dense(123, activation = 'relu')
layer = layers.LSTM(128, return_sequences = True)(layer)
outputs = layers.Dense(20, activation='softmax')(layer)
output = CTCLayer()(labels,outputs)
model = Model(input_layer,outputs)

然而,当它涉及到model.fit部分时,它开始崩溃,因为我不知道如何给模型提供“标签”输入层的东西。我认为教程中的方法非常明确,那么实现CTC损失的更好和更有效的方法是什么?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-09-28 12:44:31

你唯一做错的是模型创建model = Model(input_layer,outputs),它应该是model = Model([input_layer,labels],output),如果你不想有2个输入,你也可以用tf.nn.ctc_loss作为损失来编译模型

代码语言:javascript
复制
def my_loss_fn(y_true, y_pred):
  loss_value = tf.nn.ctc_loss(y_true, y_pred, y_true_length, y_pred_length, 
  logits_time_major = False)
  return tf.reduce_mean(loss_value, axis=-1)

model.compile(optimizer='adam', loss=my_loss_fn)

请注意,上面的代码未经过测试,您需要找到y_pred和y_true长度,但您可以像在ctc层中那样执行此操作。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69361779

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档