首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >混淆矩阵没有给出神经网络sklearn Python的全部数据

混淆矩阵没有给出神经网络sklearn Python的全部数据
EN

Stack Overflow用户
提问于 2021-06-25 03:41:36
回答 1查看 28关注 0票数 0

我只是想得到我的整个数据集的混淆矩阵。

代码语言:javascript
复制
import pandas as pd

from sklearn.model_selection import train_test_split
from sklearn import metrics
from neupy import algorithms

df = pd.read_csv('my_data.csv', header=None)

df = df.rename(columns={0: 'season_at_test', 
                        1: 'age',
                        2: 'child',
                        3: 'trauma',
                        4: 'surgery',
                        5: 'fever',
                        6: 'alcohol',
                        7: 'smoking','
                       })

df['smoking'] = df['smoking'].map({'N': 1, 'O':0})

data = df.iloc[:, :-1]
target = df['diagnosis']

X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=303)


pnn = algorithms.PNN(std=10, verbose=False)

pnn.train(X_train, y_train)

y_pred = pnn.predict(X_test)

print("Accuracy:",metrics.accuracy_score(y_test, y_pred))

pnn = algorithms.PNN(std=10, verbose=False)

pnn.train(X_train, y_train)

y_pred = pnn.predict(X_test)

print("Accuracy:",metrics.accuracy_score(y_test, y_pred))

metrics.confusion_matrix(y_test, y_pred)

它给了我;

代码语言:javascript
复制
Accuracy: 0.7
array([[ 0,  1],
       [ 5, 14]], dtype=int64)

此输出。此外,我需要运行2次工作,它给了我一个错误,在第一次运行。我的混淆矩阵应该类似于下面的结果,因为我有100个样本,而不是20个。

代码语言:javascript
复制
[ 58,  30]
[ 5, 7]

如果我尝试添加像这样的东西

代码语言:javascript
复制
y_pred = pnn.predict(X_test)
x_pred = pnn.predict(data)

metrics.confusion_matrix(x_pred, y_test)

它给出了"ValueError:找到样本数量不一致的输入变量: 100,20“

我如何才能对我的所有数据都起作用呢?我希望我的100个样本都有一个混淆矩阵。

EN

回答 1

Stack Overflow用户

发布于 2021-06-25 04:04:40

当您在有100个条目的data上执行train_test_split时,您将数据拆分为两部分(训练和测试)。根据test_size=0.2的定义,其中20%进入测试集,因此y_test将拥有20%的data,这相当于20个条目。这是意料之中的行为。

如果你想得到整个数据集的混淆矩阵,你应该这样做:

代码语言:javascript
复制
y_pred = pnn.predict(data)
metrics.confusion_matrix(target, y_pred)

这种方法是而不是推荐的,因为它不能代表模型在不可见数据上的真实性能。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68121910

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档