首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tensorflow角多类分类的召回与精度度量

tensorflow角多类分类的召回与精度度量
EN

Stack Overflow用户
提问于 2022-09-01 05:01:17
回答 2查看 431关注 0票数 0

我们在分类上遇到了问题。我们想找出每个类的召回、精确度量。我们发现在tf.keras.metrics中有内置的精确性和召回度量。但这些指标似乎只适用于二进制分类。在我们的模型中,最后一层是具有活动函数“softmax”的密集层。损失函数是sparse_categorical_crossentropy,因为我们对y使用了类标签。

代码语言:javascript
复制
output = Dense(3, activation='softmax')(attention_mul)
model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy, optimizer='Adam', metrics=['accuracy'])

预测结果的输出是每类概率的向量,例如0.3,0.5,0.2。为了得到类标签,我们需要对预测结果应用np.argmax()。而内置的召回和精确度量则接受输入的类标签。

代码语言:javascript
复制
m = tf.keras.metrics.Recall()
m.update_state([0, 1, 1, 1], [1, 0, 1, 1])
m.result().numpy()

是否有任何解决方案来获得精确性和召回指标,并在培训的每一个阶段进行监控?

EN

回答 2

Stack Overflow用户

发布于 2022-09-01 07:49:14

对于多类分类问题,Keras中的精确性和召回是不可用的原因。由于度量是按批处理计算的,因此这两个度量的结果可能不准确。实际上,Keras有一个精确和召回的实现,正是因为这个原因才决定删除。

但是,如果您真的想要的话,您可以为精确性创建自定义度量标准,并将这些指标回忆起来并传递给编译。

Keras GitHub中删除的指标如下:

代码语言:javascript
复制
def precision(y_true, y_pred):
    """Precision metric.
    Only computes a batch-wise average of precision.
    Computes the precision, a metric for multi-label classification of
    how many selected items are relevant.
    """
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    precision = true_positives / (predicted_positives + K.epsilon())
    return precision


def recall(y_true, y_pred):
    """Recall metric.
    Only computes a batch-wise average of recall.
    Computes the recall, a metric for multi-label classification of
    how many relevant items are selected.
    """
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    recall = true_positives / (possible_positives + K.epsilon())
    return recall

将指标添加到compile

代码语言:javascript
复制
model.compile(loss=tf.keras.losses.sparse_categorical_crossentropy, optimizer='Adam', metrics=['accuracy', precision, recall])

这样,您就可以按照您的要求在每个时代监视这两个指标。

票数 0
EN

Stack Overflow用户

发布于 2022-09-07 09:31:30

有keras度量项目https://github.com/netrack/keras-metrics。但是,目前版本的Tensorflow (如2.7 )没有得到维护,而且已经过时。在这个项目的启发下,我终于找到了解决方案:我们可以定制度量函数。以下是代码:

代码语言:javascript
复制
    def recall(y_true, y_pred, c):
        y_true = K.flatten(y_true)
        pred_c = K.cast(K.equal(K.argmax(y_pred, axis=-1), c), K.floatx())
        true_c = K.cast(K.equal(y_true, c), K.floatx())
        true_positives = K.sum(pred_c * true_c)
        possible_postives = K.sum(true_c)
        return true_positives / (possible_postives + K.epsilon())


    def precision(y_true, y_pred, c):
        y_true = K.flatten(y_true)
        pred_c = K.cast(K.equal(K.argmax(y_pred, axis=-1), c), K.floatx())
        true_c = K.cast(K.equal(y_true, c), K.floatx())
        true_positives = K.sum(pred_c * true_c)
        pred_positives = K.sum(pred_c)
        return true_positives / (pred_positives + K.epsilon())

    def recall_c1(y_true, y_pred):
        return recall(y_true, y_pred, 1)

    def precision_c1(y_true, y_pred):
        return precision(y_true, y_pred, 1)
    
    def recall_c2(y_true, y_pred):
        return recall(y_true, y_pred, 2)

    def precision_c2(self, y_true, y_pred):
        return precision(y_true, y_pred, 2)

我们可以使用precision_c1,recall_c1来表示类别1的精确性和召回度量,类别2的precision_c2,recall_c2。更多的类别也可以通过将class_id值c传递给函数class_id()和精度()来支持。下面是模型培训期间的示例输出:

代码语言:javascript
复制
Epoch 2/2000
24/24 - 35s - loss: 1.1322 - accuracy: 0.0675 - recall_c1: 0.9962 - precision_c1: 0.0676 - recall_c2: 0.0054 - precision_c2: 0.0402 - val_loss: 1.1263 - val_accuracy: 0.0357 - val_recall_c1: 1.0000 - val_precision_c1: 0.0344 - val_recall_c2: 0.0000e+00 - val_precision_c2: 0.0000e+00 - 35s/epoch - 1s/step
Epoch 3/2000
24/24 - 35s - loss: 1.1321 - accuracy: 0.0678 - recall_c1: 0.9873 - precision_c1: 0.0679 - recall_c2: 0.0178 - precision_c2: 0.0876 - val_loss: 1.1254 - val_accuracy: 0.0382 - val_recall_c1: 0.8761 - val_precision_c1: 0.0346 - val_recall_c2: 0.2432 - val_precision_c2: 0.0948 - 35s/epoch - 1s/step
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73564461

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档