首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >从混淆矩阵中(有效地)创建假的真值/预测值

从混淆矩阵中(有效地)创建假的真值/预测值
EN

Stack Overflow用户
提问于 2015-04-27 03:55:07
回答 1查看 1.1K关注 0票数 1

出于测试目的,我需要从混淆矩阵中创建假的真值/预测值。

我的混淆矩阵存储在Pandas DataFrame中,使用:

代码语言:javascript
复制
labels = ['N', 'L', 'R', 'A', 'P', 'V']
df = pd.DataFrame([
    [1971, 19, 1, 8, 0, 1],
    [16, 1940, 2, 23, 9, 10],
    [8, 3, 181, 87, 0, 11],
    [2, 25, 159, 1786, 16, 12],
    [0, 24, 4, 8, 1958, 6],
    [11, 12, 29, 11, 11, 1926] ], columns=labels, index=labels)
df.index.name = 'Actual'
df.columns.name = 'Predicted'

我假设索引是实际值,列是预测值。

这个混淆矩阵看起来像:

代码语言:javascript
复制
Predicted     N     L    R     A     P     V
Actual
N          1971    19    1     8     0     1
L            16  1940    2    23     9    10
R             8     3  181    87     0    11
A             2    25  159  1786    16    12
P             0    24    4     8  1958     6
V            11    12   29    11    11  1926

我正在寻找一种有效的方法来创建2个Numpy数组:y_truey_predict,这将产生这样一个混淆矩阵。

我的第一个想法是首先创建大小合适的Numpy数组。

所以我就这么做了:

代码语言:javascript
复制
N_all = df.sum().sum()

y_true = np.empty(N_all)
y_pred = np.empty(N_all)

但是我不知道如何有效地填充这2个Numpy数组

同样的代码也应该适用于二进制混淆矩阵,如:

代码语言:javascript
复制
labels = [False, True]
df = pd.DataFrame([
    [5, 3],
    [2, 7]], columns=labels, index=labels)
df.index.name = 'Actual'
df.columns.name = 'Predicted'

这个二进制混淆矩阵看起来像:

代码语言:javascript
复制
Predicted  False  True
Actual
False          5      3
True           2      7
EN

回答 1

Stack Overflow用户

发布于 2015-04-27 11:34:56

如果您想准确地重新创建,可以使用以下函数:

代码语言:javascript
复制
def create_arrays(df):
    # Unstack to make tuples of actual,pred,count
    df = df.unstack().reset_index()

    # Pull the value labels and counts
    actual = df['Actual'].values
    predicted = df['Predicted'].values
    totals = df.iloc[:,2].values

    # Use list comprehension to create original arrays
    y_true = [[curr_val]*n for (curr_val, n) in zip(actual, totals)]
    y_predicted = [[curr_val]*n for (curr_val, n) in zip(predicted, totals)]

    # They come nested so flatten them
    y_true = [item for sublist in y_true for item in sublist]
    y_predicted = [item for sublist in y_predicted for item in sublist]

    return y_true, y_predicted

我们可以检查这是否产生了预期的结果:

代码语言:javascript
复制
import pandas as pd
import numpy as np
from sklearn.metrics import confusion_matrix

labels = ['N', 'L', 'R', 'A', 'P', 'V']
df = pd.DataFrame([
    [1971, 19, 1, 8, 0, 1],
    [16, 1940, 2, 23, 9, 10],
    [8, 3, 181, 87, 0, 11],
    [2, 25, 159, 1786, 16, 12],
    [0, 24, 4, 8, 1958, 6],
    [11, 12, 29, 11, 11, 1926] ], columns=labels, index=labels)
df.index.name = 'Actual'
df.columns.name = 'Predicted'

# Recreate the original confusion matrix and check for equality
y_t, y_p = create_arrays(df)
conf_mat = confusion_matrix(y_t,y_p)
check_labels = np.unique(y_t)

df_new = pd.DataFrame(conf_mat, columns=check_labels, index=check_labels).loc[labels, labels]
df_new.index.name = 'Actual'
df_new.columns.name = 'Predicted'

df == df_new

输出:

代码语言:javascript
复制
Predicted     N     L     R     A     P     V
Actual                                       
N          True  True  True  True  True  True
L          True  True  True  True  True  True
R          True  True  True  True  True  True
A          True  True  True  True  True  True
P          True  True  True  True  True  True
V          True  True  True  True  True  True

对于二进制文件:

代码语言:javascript
复制
# And for the binary
labels = ['False', 'True']
df = pd.DataFrame([
    [5, 3],
    [2, 7]], columns=labels, index=labels)
df.index.name = 'Actual'
df.columns.name = 'Predicted'

# Recreate the original confusion matrix and check for equality
y_t, y_p = create_arrays(df)
conf_mat = confusion_matrix(y_t,y_p)
check_labels = np.unique(y_t)

df_new = pd.DataFrame(conf_mat, columns=check_labels, index=check_labels).loc[labels, labels]
df_new.index.name = 'Actual'
df_new.columns.name = 'Predicted'

df == df_new

Predicted False  True
Actual               
False      True  True
True       True  True
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/29882747

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档