首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >绘制阈值(precision_recall曲线) matplotlib/sklearn.metrics

绘制阈值(precision_recall曲线) matplotlib/sklearn.metrics
EN

Stack Overflow用户
提问于 2021-01-31 12:44:57
回答 2查看 1.2K关注 0票数 3

我正在尝试绘制我的查准率/召回率曲线的阈值。我只是使用了MNSIT的数据,示例来自于“使用scikit学习机器学习-学习,keras和TensorFlow”这本书。尝试训练模型来检测5的图像。我不知道你需要看到多少代码。我已经为训练集建立了混淆矩阵,并计算了精确度和召回值,以及阈值。我已经绘制了pre/rec曲线,书中的示例说明要添加轴标签、ledged、网格和高亮显示阈值,但在书中我在下面放置了一个星号的代码被删掉了。除了如何将阈值显示在图中之外,我能够计算出所有的阈值。我已经在书中添加了一张图表与我所拥有的图表的图片。这就是这本书所展示的:

vs我的图表:

我不能让有两个阈值的红色点线出现。有人知道我会怎么做吗?下面是我的代码:

代码语言:javascript
复制
from sklearn.metrics import precision_recall_curve

precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)

def plot_precision_recall_vs_thresholds(precisions, recalls, thresholds):
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision")
    plt.plot(thresholds, recalls[:-1], "g--", label="Recall")
    plt.xlabel("Threshold")
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
    plt.grid(b=True, which="both", axis="both", color='gray', linestyle='-', linewidth=1)

plot_precision_recall_vs_thresholds(precisions, recalls, thresholds)
plt.show()

我知道这里有相当多关于sklearn的问题,但似乎没有人能覆盖到红线的出现。我将非常感谢你的帮助!

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-01-31 13:44:10

可以使用以下代码绘制水平线和垂直线:

代码语言:javascript
复制
plt.axhline(y_value, c='r', ls=':')
plt.axvline(x_value, c='r', ls=':')
票数 3
EN

Stack Overflow用户

发布于 2021-04-06 03:59:19

这应该以正确的方式工作:

代码语言:javascript
复制
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    recall_80_precision = recalls[np.argmax(precisions >= 0.80)]
    threshold_80_precision = thresholds[np.argmax(precisions >= 0.80)]
    
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
    plt.xlabel("Threshold")
    plt.plot([threshold_80_precision, threshold_80_precision], [0., 0.8], "r:")
    plt.axis([-4, 4, 0, 1])
    plt.plot([-4, threshold_80_precision], [0.8, 0.8], "r:")
    plt.plot([-4, threshold_80_precision], [recall_80_precision, recall_80_precision], "r:")
    plt.plot([threshold_80_precision], [0.8], "ro") 
    plt.plot([threshold_80_precision], [recall_80_precision], "ro")
    plt.grid(True)
    plt.legend()
    plt.show()

我在尝试复制本书中的代码时遇到了这段代码。原来@ageron把所有的资源都放在了他的github页面上。你可以在here上查看

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65975815

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档