首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Tensorflow估计器:切换到careful_interpolation以得到模型的正确的PR-AUC

Tensorflow估计器:切换到careful_interpolation以得到模型的正确的PR-AUC
EN

Stack Overflow用户
提问于 2018-06-14 05:34:12
回答 1查看 2.6K关注 0票数 4

在我的项目中,我使用的是预先制作的估值器DNNClassifier。这是我的估计数:

代码语言:javascript
复制
model = tf.estimator.DNNClassifier(
        hidden_units=network,
        feature_columns=feature_cols,
        n_classes= 2,
        activation_fn=tf.nn.relu,
        optimizer=tf.train.ProximalAdagradOptimizer(
            learning_rate=0.1,
            l1_regularization_strength=0.001
        ),
        config=chk_point_run_config,
        model_dir=MODEL_CHECKPOINT_DIR
    )

当我使用eval_res = model.evaluate(..)对模型进行评估时,我会得到以下警告:

警告:tensorflow:已知梯形规则会产生错误的PR-AUCs;请改为"careful_interpolation“。

如何切换到careful_interpolation以从evaluate()方法获得正确的结果?

Tensorflow版本:1.8

EN

回答 1

Stack Overflow用户

发布于 2018-07-11 15:01:13

不幸的是,使用预先制定的估计器,几乎没有定制评估过程的自由。目前,DNNClassifier似乎没有提供一种方法来调整评估指标,其他估计器也是如此。

虽然不太理想,但一种解决方案是使用tf.contrib.metrics.add_metrics来使用所需的度量来增强估计器,如果将完全相同的密钥分配给新的度量,则将替换旧的度量:

如果在此和现有的评估器之间存在名称冲突,这将覆盖现有的度量。

它的优点是为产生概率预测的任何估计器工作,而牺牲的是仍然计算每个评估的覆盖度量。DNNClassifier估计器在键'logistic' (罐装估计器中的可能密钥列表为这里)下提供逻辑值(介于0到1之间)。对于其他估值头来说,情况可能并不总是这样,但也可能有其他选择:在用tf.contrib.estimator.multi_label_head构建的多标签分类器中,logistic是不可用的,但是可以使用probabilities

因此,代码将如下所示:

代码语言:javascript
复制
def metric_auc(labels, predictions):
    return {
        'auc_precision_recall': tf.metrics.auc(
            labels=labels, predictions=predictions['logistic'], num_thresholds=200,
            curve='PR', summation_method='careful_interpolation')
    }

estimator = tf.estimator.DNNClassifier(...)
estimator = tf.contrib.estimator.add_metrics(estimator, metric_auc)

评估时,仍然会出现警告信息,但随后将调用带有仔细插值的AUC。将此度量分配给不同的键也将允许您检查两个求和方法之间的差异。我对多标签逻辑回归任务的测试表明,测量值可能确实略有不同: auc_precision_recall = 0.05173396,auc_precision_recall_careful = 0.05059402。

还有一个原因,为什么默认的求和方法仍然是'trapezoidal',尽管文档建议谨慎的插值是“严格的首选”。作为在拉请求#19079中进行注释,更改将明显向后不兼容。随后对同一拉请求的评论建议上述解决办法。

票数 7
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50850258

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档