首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Scikit平均精度分数输入形状差

Scikit平均精度分数输入形状差
EN

Stack Overflow用户
提问于 2018-04-12 13:46:53
回答 1查看 500关注 0票数 0

我正在绘制一条精确/回忆分数曲线。这是我的代码:

代码语言:javascript
复制
    lbl_enc = preprocessing.LabelEncoder()
    labels = lbl_enc.fit_transform(test_tags)

    y_score = clf.predict_proba(test_set)

    average_precision = average_precision_score(labels, y_score)
    print('Average precision-recall score: {0:0.2f}'.format(average_precision))

    precision, recall, _ = precision_recall_curve(labels, y_score)

    plt.step(recall, precision, color='b', alpha=0.2,
             where='post')
    plt.fill_between(recall, precision, step='post', alpha=0.2,
                     color='b')

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.ylim([0.0, 1.05])
    plt.xlim([0.0, 1.0])
    plt.title('2-class Precision-Recall curve: Average P-R = {0:0.2f}'.format(
        average_precision))

在计算average_precision_score时,我得到了由"y_score“变量引起的"ValueError:坏输入形状(119,2)”。

y_score的格式如下:

代码语言:javascript
复制
array([[0.45953712, 0.54046288],
   [0.78289908, 0.21710092],
   [0.13488789, 0.86511211],
   [0.56162583, 0.43837417],
   (...)
   [0.4595595 , 0.5404405 ]])

而标签是这样的:

代码语言:javascript
复制
array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
   1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
   1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
   1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
   1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
   1, 1, 1, 1, 1, 1, 1, 1, 1])

我如何使这项工作,以计算avg的精度得分?提前谢谢。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2018-04-12 14:11:37

文档中,它说:

y_score :数组,shape = n_samples或n_samples,n_classes 目标分数,既可以是概率估计的正类、置信度值,也可以是非阈值的决策测度(由“decision_function”对某些分类器返回)。

因此,我相信你只需要做:

代码语言:javascript
复制
average_precision  = average_precision_score(labels, y_score[:,1])
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/49798269

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档