首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在一幅图中绘制多个查准率曲线

在一幅图中绘制多个查准率曲线
EN

Data Science用户
提问于 2020-09-08 14:34:26
回答 1查看 4.1K关注 0票数 3

我有一个不平衡的数据集,我正在阅读这个文章,它检查SMOTE和RUS以解决不平衡问题。因此,我定义了以下三个模型:

代码语言:javascript
复制
    # AdaBoost
    ada = AdaBoostClassifier(n_estimators=100, random_state=42)
    ada.fit(X_train,y_train)
    y_pred_baseline = ada.predict(X_test) 
    
    # SMOTE    
    sm = SMOTE(random_state=42)
    X_train_sm, y_train_sm = sm.fit_sample(X_train, y_train)
    ada_sm = AdaBoostClassifier(n_estimators=100, random_state=42)
    ada_sm.fit(X_train_sm,y_train_sm)
    y_pred_sm = ada_sm.predict(X_test) 
    
    #RUS
    rus = RandomUnderSampler(random_state=42)
    X_train_rus, y_train_rus = rus.fit_resample(X, y)
    ada_rus = AdaBoostClassifier(n_estimators=100, random_state=42)
    ada_rus.fit(X_train_rus,y_train_rus)
    y_pred_rus = ada_rus.predict(X_test) 

然后,我绘制了这3种型号的精确召回曲线。我选择了这条曲线,因为我想想象模型的表现,我对真正的负面(负面类是多数类)不太感兴趣。

为了绘制曲线,我使用了山猫学习的plot_precision_recall_curve方法,如下所示:

代码语言:javascript
复制
    from sklearn.metrics import precision_recall_curve
    from sklearn.metrics import plot_precision_recall_curve
    import matplotlib.pyplot as plt
    
    disp = plot_precision_recall_curve(ada, X_test, y_test)
    disp.ax_.set_title('Precision-Recall curve')

    disp = plot_precision_recall_curve(ada_sm, X_test, y_test)
    disp.ax_.set_title('Precision-Recall curve')

    disp = plot_precision_recall_curve(ada_rus, X_test, y_test)
    disp.ax_.set_title('Precision-Recall curve')

这导致了3个单独的地块。

然而,我想有这3条曲线在一个图表,以便他们可以很容易地比较。所以我想要一个类似于文章中的情节:

但我不知道如何做到这一点,因为plot_precision_recall_curve方法只使用一个分类器作为输入。

如果能提供一些帮助,我们将不胜感激。

EN

回答 1

Data Science用户

回答已采纳

发布于 2020-09-08 17:54:27

尝试以这种方式使用Matplotlib gca()方法,您可以指示要绘制的轴。

代码语言:javascript
复制
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import plot_precision_recall_curve
import matplotlib.pyplot as plt

plot_precision_recall_curve(ada, X_test, y_test, ax = plt.gca(),name = "AdaBoost")

plot_precision_recall_curve(ada_sm, X_test, y_test, ax = plt.gca(),name = "SMOTE")

plot_precision_recall_curve(ada_rus, X_test, y_test, ax = plt.gca(),name = "RUS")

plt.title('Precision-Recall curve')
票数 3
EN
页面原文内容由Data Science提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://datascience.stackexchange.com/questions/81389

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档