首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >具有多类召回自定义度量的Tensorflow 2 ModelCheckpoint回调

具有多类召回自定义度量的Tensorflow 2 ModelCheckpoint回调
EN

Stack Overflow用户
提问于 2020-06-03 09:48:59
回答 1查看 277关注 0票数 0

我正在为多类分类任务(num_classes=7)构建CNN分类器.由于不平衡和主题区域,我对此任务的目标度量是跨类的宏观平均召回。

随着模型的训练,我想通过在每个时代结束时保存模型来检查它,如果验证多类宏召回被评估为高于以前在整个时代中看到的最高值。我相信这会分两个阶段进行:

  1. 创建一个自定义度量,用于计算多类场景中每个时代结束时验证数据的平均召回量(
  2. ),创建一个跟踪自定义度量的ModelCheckpoint回调,并在模型超过以前的max.

时保存该模型。

有人会有这样或类似的例子吗?我对宏平均多类召回的自定义度量的实现更感兴趣,因为我相信,一旦在model.compile()中定义了这个度量,回调就可以很容易地完成。

EN

回答 1

Stack Overflow用户

发布于 2020-06-05 09:07:28

我使用this post实现了自定义度量,并做了一些调整,例如计算了running mean。以下是自定义度量的代码:

代码语言:javascript
复制
import tensorflow.keras.backend as K
from tensorflow.keras.metrics import Metric

class MacroAverageRecall( Metric ):
    """Custom metric for calculating multiclass recall during         
training"""
    def __init__(self,
                 num_classes,
                 batch_size,
                 name='multiclass_recall',
                 **kwargs):
        super( MacroAverageRecall, self ).__init__( name=name, **kwargs )
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.num_batches = 0
        self.average_recall = self.add_weight( name="recall", initializer="zeros" )

    def update_state(self, y_true, y_pred, sample_weight=None):
        recall = 0
        pred = K.argmax( y_pred, axis=-1 )
        true = K.argmax( y_true, axis=-1 )

        for i in range( self.num_classes ):
            # Find where the pred equals the class
            predicted_instances_bool = K.equal(
                pred,
                i
            )
            # Find where the labels equals the class
            true_instances_bool = K.equal(
                true,
                i
            )
            # Converting tensors of bools to int (1,0)
            predicted_instances = K.cast(
                predicted_instances_bool,
                'float32'
            )
            true_instances = K.cast(
                true_instances_bool,
                'float32'
            )
            # Reshaping tensors
            true_reshaped = K.reshape(
                true_instances,
                (1, -1)
            )
            predicted_reshaped = K.reshape(
                predicted_instances,
                (-1, 1)
            )
            # Find true positives
            true_positives = K.dot(
                true_reshaped,
                predicted_reshaped
            )
            # Compute the true positive
            pred_true_pos = K.sum(
                true_positives
            )
            # divide by all positives in t
            all_true_positives = (K.sum( true_instances ) + K.epsilon())
            class_recall = pred_true_pos / all_true_positives
            recall += class_recall

        self.num_batches += 1
        avg_recall = recall / self.num_classes
        recall_update = (avg_recall - self.average_recall) / self.num_batches
        self.average_recall.assign_add( recall_update )

    def result(self):
        return self.average_recall

    def reset_states(self):
        # The state of the metric will be reset at the start of each epoch.
        self.average_recall.assign( 0. )

以及在模型培训中使用的检查点:

代码语言:javascript
复制
callbacks.ModelCheckpoint(
            filepath=os.path.join(
                self._metadata['checkpoint_directory'],
                f'checkpoint-{self._metadata["create_time"]}.h5' ),
            save_best_only=True if self._val else False,
            monitor='val_multiclass_recall',
            mode='max',
            verbose=1 )
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62169893

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档