首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >ValueError: classification_report的未知标签类型

ValueError: classification_report的未知标签类型
EN

Stack Overflow用户
提问于 2019-06-30 23:56:08
回答 1查看 478关注 0票数 0

我正在尝试使用sklean包的classification_report模块来评估多类分类模型。

Y_pred维度:(1000,36) y_test维度:(1000,36)

我尝试在两个阵列上调用classification_report,即y_test和y_pred

代码语言:javascript
复制
def display_results(y_test,y_pred,column_name=labels):
    print(classification_report(y_test,y_pred,target_names=labels))

使用下面的代码,我得到:

代码语言:javascript
复制
ValueError: Unknown label type: (array([[1, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 1, 1, 0],
       ...,
       [1, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0]]), array([[1, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0],
       ...,
       [1, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0]]))

我希望得到基于传递给函数的标签的精确度、召回率、F1和所有列的总平均指标。

EN

回答 1

Stack Overflow用户

发布于 2019-10-28 01:10:05

对于您的错误,需要np.hstack

在分类器具有多类多输出的情况下效果最佳

代码语言:javascript
复制
from sklearn.utils.multiclass import type_of_target
type_of_target(y_test)
type_of_target(y_pred)

>>'multiclass-multioutput'

所以,你的解决方案是

代码语言:javascript
复制
print(classification_report(np.hstack(y_test),np.hstack(y_pred)))
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/56826216

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档