Using thresholds in tf.keras.metrics.Recall

Stack Overflow user
Asked on 2019-09-20 07:36:09
Answers: 1 · Views: 600 · Followers: 0 · Votes: 1

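I'd like to know how recall is computed when multiple thresholds are specified. Below is a snippet of the description from docs/python/tf/keras/metrics/Recall: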

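thresholds: (Optional) A float value or a python list/tuple of float threshold values. A threshold is compared with prediction values to determine the truth value of predictions (i.e., above the threshold is true, below is false). One metric value is generated for each threshold value. If neither thresholds nor top_k are set, the default is to calculate recall with thresholds=0.5.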
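The thresholding rule in this description can be sketched in NumPy. This is an illustration under two assumptions: that `Recall` keeps one true-positive and one false-negative count per threshold, and that a prediction counts as positive only when it is strictly above the threshold. `recall_per_threshold` is a hypothetical helper, not part of TensorFlow.

```python
import numpy as np


def recall_per_threshold(y_true, y_pred, thresholds=0.5):
    """Return one recall value per threshold (hypothetical helper).

    A prediction counts as positive when it is strictly above a threshold,
    following the "above the threshold is true" wording of the docs.
    """
    y_true = np.asarray(y_true, dtype=bool).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    thresholds = np.atleast_1d(np.asarray(thresholds, dtype=float))
    # Row i holds the binarized predictions for thresholds[i]
    pred_pos = y_pred[None, :] > thresholds[:, None]
    tp = np.sum(pred_pos & y_true[None, :], axis=1)   # true positives per threshold
    fn = np.sum(~pred_pos & y_true[None, :], axis=1)  # false negatives per threshold
    denom = tp + fn
    # Recall is 0 for a threshold when there are no positive labels at all
    return np.divide(tp, denom, out=np.zeros(len(thresholds)), where=denom > 0)


y_true = [1, 1, 1, 1, 0, 0]
y_pred = [0.1, 0.3, 0.5, 0.9, 0.6, 0.2]
print(recall_per_threshold(y_true, y_pred, [0.2, 0.4, 0.8]))  # recalls 0.75, 0.5, 0.25
print(recall_per_threshold(y_true, y_pred))                   # default 0.5: recall 0.25
```

Under these assumptions, three thresholds give a vector of three recalls, which is what the documentation's "one metric value for each threshold" describes.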

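I tried passing a list of 3 thresholds. Based on the description, I expected this to generate 3 recall values (one per threshold), but it doesn't work that way: only one recall metric is produced.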

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Embedding, LSTM

# tokens, embedding_dim, MAX_TEXT_LEN and embedding_matrix are defined
# elsewhere in the asker's code.
model = Sequential()
model.add(Embedding(len(tokens) + 1, embedding_dim,
                    input_length=MAX_TEXT_LEN, weights=[embedding_matrix]))
model.add(LSTM(128))
model.add(Dropout(0.5))
model.add(Dense(9, activation='sigmoid'))  # 9 independent sigmoid outputs
opt = tf.keras.optimizers.Adam(learning_rate=0.0001)
model.compile(
    loss='binary_crossentropy',
    optimizer=opt,
    # Expected one recall value per threshold, but training shows only one
    metrics=[tf.keras.metrics.Recall(thresholds=[0.2, 0.4, 0.8])],
)
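The model above ends in `Dense(9, activation='sigmoid')`, so every sample has nine independent labels. Assuming that, when `class_id` is not set, `Recall` flattens all label columns into one binary problem (a micro-average), each threshold still yields a single number covering all nine labels rather than one per label. The NumPy sketch below, with hypothetical helpers `pooled_recall` and `per_label_recall`, shows how that pooled value can differ from per-label recall.

```python
import numpy as np


def pooled_recall(y_true, y_pred, threshold=0.5):
    """Micro-averaged recall: every (sample, label) cell is one binary decision."""
    t = np.asarray(y_true, dtype=bool).ravel()
    p = np.asarray(y_pred, dtype=float).ravel() > threshold
    return (p & t).sum() / t.sum()


def per_label_recall(y_true, y_pred, threshold=0.5):
    """Recall computed separately for each label column."""
    t = np.asarray(y_true, dtype=bool)
    p = np.asarray(y_pred, dtype=float) > threshold
    return (p & t).sum(axis=0) / t.sum(axis=0)


y_true = [[1, 0], [1, 1], [1, 1]]          # 3 samples, 2 labels
y_pred = [[0.9, 0.1], [0.3, 0.7], [0.8, 0.4]]
print(pooled_recall(y_true, y_pred))       # 3 of 5 positive cells found: 0.6
print(per_label_recall(y_true, y_pred))    # label 0: 2/3, label 1: 1/2
```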

1 Answer

Stack Overflow user

Answered on 2022-02-13 07:17:23

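Update

To see the metric value for each threshold during training, you can write a custom callback that logs the per-threshold values at the end of every epoch.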

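from tensorflow.keras import callbacks


class CustCallback(callbacks.Callback):
    """Print the multi-threshold validation recall at the end of every epoch."""

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # 'val_recall_mult_thr' is the name given to the multi-threshold
        # Recall metric in model.compile() in the example below; the entry
        # holds one value per threshold in thresh_list.
        print(
            "The validation recall for epoch {} is {} ".format(
                epoch, logs.get('val_recall_mult_thr')
            )
        )


# model, X_train, y_train, X_test and y_test are defined in the example below
history = model.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),
    callbacks=[CustCallback()],
    batch_size=64, epochs=3, verbose=2
)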
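A raw array of 19 recalls is hard to read in a log. As a sketch, assuming `logs['val_recall_mult_thr']` holds one value per threshold (worth confirming by printing its shape), a hypothetical helper `format_per_threshold` can pair each value with its threshold:

```python
import numpy as np


def format_per_threshold(name, thresholds, values):
    """Build one log line per threshold from a vector-valued metric entry."""
    values = np.atleast_1d(np.asarray(values, dtype=float))
    if len(values) != len(thresholds):
        raise ValueError("expected one metric value per threshold")
    return [f"{name} @ {t:.2f}: {v:.4f}" for t, v in zip(thresholds, values)]


# Inside on_epoch_end one could call, for example:
#   for line in format_per_threshold("val_recall", thresh_list,
#                                    logs["val_recall_mult_thr"]):
#       print(line)
for line in format_per_threshold("val_recall", [0.2, 0.4, 0.8], [0.75, 0.5, 0.25]):
    print(line)  # e.g. "val_recall @ 0.20: 0.7500"
```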

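The documentation's wording, "One metric value is generated for each threshold value", confused me as well, because it is not what you see in the training output: the single number reported for the metric is the arithmetic mean of the metric values for all the thresholds in your list. Below is a binary classification example that, for both true positives and recall, produces one metric at the 0.5 decision threshold and one metric for the supplied list of thresholds.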
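A plausible explanation, stated here as an assumption about how Keras formats its progress-bar line, is that the metric itself returns a vector with one entry per threshold and the printed line collapses it with a mean. That would also explain why `tp_mult_thr` in the output below shows a fractional "count" such as 60321.4727: it is an average of integer counts. A minimal sketch of that collapsing step, using a hypothetical `displayed_value` helper:

```python
import numpy as np


def displayed_value(metric_result):
    """Reduce a metric result to the single number printed in a log line.

    Assumption: vector-valued results are averaged; scalars pass through.
    """
    return float(np.mean(metric_result))


# Illustrative per-threshold true-positive counts for three thresholds
tp_per_threshold = [120, 95, 60]
print(displayed_value(tp_per_threshold))  # 91.666..., a fractional "count"
print(displayed_value(0.9163))            # scalar metric is unchanged
```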

import numpy as np

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

from tensorflow import keras
from tensorflow.keras import (
    callbacks,
    initializers,
    layers,
    metrics,
    optimizers,
)

from sklearn import __version__ as sk_version
from tensorflow import __version__ as tf_version
print(f"The sklearn version is {sk_version}.")
# The sklearn version is 1.0.2.
print(f"The tensorflow version is {tf_version}.")
# The tensorflow version is 2.7.0.

X1, Y1 = make_classification(
    n_samples=10**6, n_features=10, n_redundant=0, n_informative=4,
    n_clusters_per_class=1, n_classes=2, weights=[0.8, 0.2], random_state=42,
)

X_train, X_test, y_train, y_test = train_test_split(
    X1, Y1, test_size=10**6 // 3, random_state=42
)

model = keras.Sequential()
ki = initializers.RandomNormal(mean=0.0, stddev=0.05, seed=123)
model.add(
    layers.Dense(
        5,
        kernel_initializer=ki,
        bias_initializer="zeros",
        input_shape=(X_train.shape[1],),
        activation="relu",
    )
)
model.add(
    layers.Dense(
        1, activation="sigmoid"
    )
)

opt = optimizers.Adam(learning_rate=0.01)
thresh_list = list(np.arange(0.05, 1, 0.05))

model.compile(
    loss='binary_crossentropy',
    optimizer=opt,
    metrics=[
        metrics.TruePositives(thresholds=0.5, name="tp_0_5"),
        metrics.TruePositives(thresholds=thresh_list, name='tp_mult_thr'),
        metrics.FalsePositives(name="fp"),
        metrics.TrueNegatives(name="tn"),
        metrics.FalseNegatives(name="fn"),
        metrics.Recall(thresholds=0.5, name='recall_0_5'),
        metrics.Recall(
            thresholds=thresh_list,
            name="recall_mult_thr",
        ),
    ],
)

model.fit(
    X_train, y_train,
    validation_data=(X_test, y_test),
    batch_size=64, epochs=3, verbose=2
)

# Epoch 1/3
# 10417/10417 - 23s - loss: 0.0857 - tp_0_5: 123905.0000 - tp_mult_thr: 121617.1016 - 
# fp: 5823.0000 - tn: 525617.0000 - fn: 11322.0000 - recall_0_5: 0.9163 - 
# recall_mult_thr: 0.8994 - val_loss: 0.0809 - val_tp_0_5: 61862.0000 - 
# val_tp_mult_thr: 60766.5781 - val_fp: 2210.0000 - val_tn: 263435.0000 - 
# val_fn: 5826.0000 - val_recall_0_5: 0.9139 - val_recall_mult_thr: 0.8977 - 
# 23s/epoch - 2ms/step
# Epoch 2/3
# 10417/10417 - 24s - loss: 0.0806 - tp_0_5: 124853.0000 - tp_mult_thr: 122652.4766 - 
# fp: 5573.0000 - tn: 525867.0000 - fn: 10374.0000 - recall_0_5: 0.9233 - 
# recall_mult_thr: 0.9070 - val_loss: 0.0789 - val_tp_0_5: 62789.0000 - 
# val_tp_mult_thr: 61644.4219 - val_fp: 2889.0000 - val_tn: 262756.0000 - 
# val_fn: 4899.0000 - val_recall_0_5: 0.9276 - val_recall_mult_thr: 0.9107 - 
# 24s/epoch - 2ms/step
# Epoch 3/3
# 10417/10417 - 25s - loss: 0.0777 - tp_0_5: 125268.0000 - tp_mult_thr: 123142.8984 - 
# fp: 5556.0000 - tn: 525884.0000 - fn: 9959.0000 - recall_0_5: 0.9264 - 
# recall_mult_thr: 0.9106 - val_loss: 0.0781 - val_tp_0_5: 61261.0000 - 
# val_tp_mult_thr: 60321.4727 - val_fp: 1638.0000 - val_tn: 264007.0000 - 
# val_fn: 6427.0000 - val_recall_0_5: 0.9050 - val_recall_mult_thr: 0.8912 - 
# 25s/epoch - 2ms/step

The validation recall at the 0.5 threshold (0.9050) is as expected:

from sklearn.metrics import recall_score

pred_probs = model.predict(X_test)

for t in thresh_list:
    print(
        f"threshold: {t:.2f}, "
        f"recall: {recall_score(y_test, (pred_probs >= t).astype('int8'))}"
    )
# threshold: 0.05, recall: 0.9748847653941615
# threshold: 0.10, recall: 0.9671876846708427
# threshold: 0.15, recall: 0.9597565299609975
# threshold: 0.20, recall: 0.9526947169365323
# threshold: 0.25, recall: 0.9446726155300792
# threshold: 0.30, recall: 0.9371971398179885
# threshold: 0.35, recall: 0.9298546271126344
# threshold: 0.40, recall: 0.9214779576882165
# threshold: 0.45, recall: 0.9133967616121026
# threshold: 0.50, recall: 0.905049639522515  <---
# threshold: 0.55, recall: 0.8952990190284836
# threshold: 0.60, recall: 0.8855483985344522
# threshold: 0.65, recall: 0.8741283536225033
# threshold: 0.70, recall: 0.8619105306701336
# threshold: 0.75, recall: 0.8483039829807352
# threshold: 0.80, recall: 0.8315654177993145
# threshold: 0.85, recall: 0.8107640940787141
# threshold: 0.90, recall: 0.7819258952842454
# threshold: 0.95, recall: 0.7366002836544143
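One detail differs between the two computations: the check above uses `pred_probs >= t`, while the documentation says predictions *above* the threshold count as true, which suggests a strict `>`. With continuous sigmoid outputs, a prediction landing exactly on a threshold is very unlikely, so the numbers still agree. The plain-NumPy sketch below (hypothetical `recall_at` helper) shows the only situation where the choice matters.

```python
import numpy as np


def recall_at(y_true, y_pred, threshold, inclusive):
    """Recall at one threshold; inclusive=True uses >=, False uses >."""
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=float)
    pred_pos = y_pred >= threshold if inclusive else y_pred > threshold
    return (pred_pos & y_true).sum() / y_true.sum()


y_true = [1, 1, 1, 1]
y_pred = [0.2, 0.5, 0.7, 0.9]  # one prediction sits exactly on the threshold
print(recall_at(y_true, y_pred, 0.5, inclusive=True))   # 0.75
print(recall_at(y_true, y_pred, 0.5, inclusive=False))  # 0.5
```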

However, the validation recall for the list of thresholds (0.8912) is the mean of the recalls across all thresholds:

np.mean([recall_score(y_test, (pred_probs >= t).astype('int8')) for t in thresh_list])
# 0.8911693902052141

The same holds for true positives:

from sklearn.metrics import confusion_matrix

for t in thresh_list:
    print(
        f"threshold: {t:.2f}, "
        f"TPs: {confusion_matrix(y_test, (pred_probs >= t).astype('int8'))[1, 1]}"
    )
# threshold: 0.05, TPs: 65988
# threshold: 0.10, TPs: 65467
# threshold: 0.15, TPs: 64964
# threshold: 0.20, TPs: 64486
# threshold: 0.25, TPs: 63943
# threshold: 0.30, TPs: 63437
# threshold: 0.35, TPs: 62940
# threshold: 0.40, TPs: 62373
# threshold: 0.45, TPs: 61826
# threshold: 0.50, TPs: 61261 <---
# threshold: 0.55, TPs: 60601
# threshold: 0.60, TPs: 59941
# threshold: 0.65, TPs: 59168
# threshold: 0.70, TPs: 58341
# threshold: 0.75, TPs: 57420
# threshold: 0.80, TPs: 56287
# threshold: 0.85, TPs: 54879
# threshold: 0.90, TPs: 52927
# threshold: 0.95, TPs: 49859

And:

tp_list = list()
for t in thresh_list:
    tp_list.append(
        confusion_matrix(y_test, (pred_probs >= t).astype('int8'))[1, 1]
    )

print(f"Avg TPs across all thresholds: {np.mean(tp_list)}")
# Avg TPs across all thresholds: 60321.47368421053
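The two averages are tied together by simple arithmetic: recall's denominator, the number of actual positives P = TP + FN, does not depend on the threshold, so the mean recall over the list equals the mean TP count divided by P. The validation counts reported for epoch 3 at the 0.5 threshold give a quick consistency check:

```python
# Validation counts reported for epoch 3 at the 0.5 threshold
val_tp_0_5 = 61261
val_fn = 6427
positives = val_tp_0_5 + val_fn    # actual positives in y_test

avg_tp = 60321.47368421053         # mean TP over thresh_list (computed above)
avg_recall = 0.8911693902052141    # mean recall over thresh_list (computed above)

# Every threshold shares the same denominator, so
# mean(TP_t / P) == mean(TP_t) / P
print(positives)                   # 67688
print(avg_tp / positives)          # about 0.89117, matching avg_recall
```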
Votes: 0
Original page content provided by Stack Overflow; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/58023853
