I'm working through this tutorial https://medium.com/@sarahcy/read-this-how-winners-create-life-changing-habits-that-actually-work-atomic-habits-by-james-ac7a3c6df911, and I'm currently trying to run the model evaluation section:
class_labels = list(set(labels))
meu.display_model_performance_metrics(true_labels=y_test, predicted_labels=predictions, classes=class_labels)

I get this error:
Prediction Confusion Matrix:
------------------------------
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/blabla/XAI/model_evaluation_utils.py", line 87, in display_model_performance_metrics
    classes=classes)
  File "/blabla/XAI/model_evaluation_utils.py", line 62, in display_confusion_matrix
    labels=level_labels),
TypeError: __new__() got an unexpected keyword argument 'labels'

How can I fix this?
BR
Edit: here is the code that runs before the failing line:
# part1
import pandas as pd
import numpy as np
import model_evaluation_utils as meu
import matplotlib.pyplot as plt
from collections import Counter
import shap
import eli5
import warnings
warnings.filterwarnings('ignore')
plt.style.use('fivethirtyeight')
shap.initjs()
#part 2
data, labels = shap.datasets.adult(display=True)
labels = np.array([int(label) for label in labels])
print(data.shape , labels.shape)
data.head()
#part 3
Counter(labels)
#part 4
cat_cols = data.select_dtypes(['category']).columns
data[cat_cols] = data[cat_cols].apply(lambda x: x.cat.codes)
data.head()
#part 5
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.3, random_state=42)
print(X_train.shape, X_test.shape)
X_train.head(3)
data_disp, labels_disp = shap.datasets.adult(display=True)
X_train_disp, X_test_disp, y_train_disp, y_test_disp = train_test_split(data_disp, labels_disp, test_size=0.3, random_state=42)
print(X_train_disp.shape, X_test_disp.shape)
X_train_disp.head(3)
#part 6
import xgboost as xgb
xgc = xgb.XGBClassifier(n_estimators=500, max_depth=5, base_score=0.5,
                        objective='binary:logistic', random_state=42)
xgc.fit(X_train, y_train)
#part 7
predictions = xgc.predict(X_test)
predictions[:10]
#part 8
class_labels = list(set(labels))
meu.display_model_performance_metrics(true_labels=y_test, predicted_labels=predictions, classes=class_labels)

Answered 2021-10-14 10:10:01
Delete this function, including its whole body, from model_evaluation_utils.py:
def display_confusion_matrix(true_labels, predicted_labels, classes=[1,0]):
    total_classes = len(classes)
    level_labels = [total_classes*[0], list(range(total_classes))]
    cm = metrics.confusion_matrix(y_true=true_labels,
                                  y_pred=predicted_labels, labels=classes)
    cm_frame = pd.DataFrame(data=cm,
                            columns=pd.MultiIndex(levels=[['Predicted:'],
                                                          classes],
                                                  labels=level_labels),
                            index=pd.MultiIndex(levels=[['Actual:'], classes],
                                                labels=level_labels))
    print(cm_frame)
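Why the error happens: in pandas, the `labels` argument of `pd.MultiIndex` was renamed to `codes` in version 0.24, and the old name was removed in pandas 1.0, so a `model_evaluation_utils.py` written for older pandas fails on a current install. If you would rather keep the confusion matrix than delete the function, here is a minimal sketch of a fixed version, assuming pandas 0.24 or later. To stay self-contained it counts the (actual, predicted) pairs with plain NumPy instead of `metrics.confusion_matrix`; in your own copy of the module, changing `labels=level_labels` to `codes=level_labels` in both `pd.MultiIndex` calls should be the only edit needed. The toy labels at the end are made up for illustration.

```python
import numpy as np
import pandas as pd


def display_confusion_matrix(true_labels, predicted_labels, classes=(1, 0)):
    """Print and return a confusion matrix with 'Actual:'/'Predicted:' headers."""
    classes = list(classes)
    total_classes = len(classes)
    # One integer array per MultiIndex level: every entry points at the single
    # outer label (position 0) and at its own class in the inner level.
    level_codes = [total_classes * [0], list(range(total_classes))]

    # Rows = actual class, columns = predicted class, both in `classes` order.
    # Pairs whose labels are not in `classes` are skipped, mirroring
    # sklearn.metrics.confusion_matrix(..., labels=classes).
    position = {label: i for i, label in enumerate(classes)}
    cm = np.zeros((total_classes, total_classes), dtype=int)
    for actual, predicted in zip(true_labels, predicted_labels):
        if actual in position and predicted in position:
            cm[position[actual], position[predicted]] += 1

    cm_frame = pd.DataFrame(
        data=cm,
        # `codes=` is the current name of the removed `labels=` keyword
        columns=pd.MultiIndex(levels=[['Predicted:'], classes], codes=level_codes),
        index=pd.MultiIndex(levels=[['Actual:'], classes], codes=level_codes),
    )
    print(cm_frame)
    return cm_frame


# Toy example with made-up labels
cm_frame = display_confusion_matrix([1, 0, 1, 1, 0], [1, 0, 0, 1, 1], classes=[1, 0])
```

Returning the frame as well as printing it is a small addition over the original, so the result can be inspected further in the notebook.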
Then go to the notebook. In the last line of code:
meu.display_model_performance_metrics(true_labels=y_test, predicted_labels=predictions, classes=class_labels)
you'll notice that the function being called is "display_model_performance_metrics".
So we go back to our model_evaluation_utils and find that function:
def display_model_performance_metrics(true_labels, predicted_labels, classes=[1,0]):
    print('Model Performance metrics:')
    print('-'*30)
    get_metrics(true_labels=true_labels, predicted_labels=predicted_labels)
    print('\nModel Classification report:')
    print('-'*30)
    display_classification_report(true_labels=true_labels, predicted_labels=predicted_labels,
                                  classes=classes)
    print('\nPrediction Confusion Matrix:')
    print('-'*30)
    display_confusion_matrix(true_labels=true_labels, predicted_labels=predicted_labels,
                             classes=classes)

Remove its last line:
display_confusion_matrix(true_labels=true_labels, predicted_labels=predicted_labels, classes=classes)
so that the function becomes simpler, like this:
def display_model_performance_metrics(true_labels, predicted_labels, classes=[1,0]):
    print('Model Performance metrics:')
    print('-'*30)
    get_metrics(true_labels=true_labels, predicted_labels=predicted_labels)
    print('\nModel Classification report:')
    print('-'*30)
    display_classification_report(true_labels=true_labels, predicted_labels=predicted_labels,
                                  classes=classes)
    print('\nPrediction Confusion Matrix:')
    print('-'*30)

Finally, if you still want to use:
display_confusion_matrix(true_labels=true_labels,predicted_labels=predicted_labels,classes=classes)
you can use it directly in your notebook rather than in the .py (Python) file :)
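If all you need is to see the matrix in the notebook, `pd.crosstab` is a swapped-in alternative (not what the tutorial uses) that builds the table without constructing a `MultiIndex` by hand. A minimal sketch, where the short lists are hypothetical stand-ins for the real `y_test` and `predictions`:

```python
import pandas as pd

# Hypothetical stand-ins; in the notebook use the real y_test and predictions
y_test = [1, 0, 1, 1, 0]
predictions = [1, 0, 0, 1, 1]

# Rows are the actual classes, columns the predicted classes
cm = pd.crosstab(
    pd.Series(y_test, name='Actual'),
    pd.Series(predictions, name='Predicted'),
)
print(cm)
```

Unlike the `MultiIndex` version, `crosstab` only shows classes that actually occur in the data, in sorted order, so you cannot force a `[1, 0]` ordering through a `classes` argument.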
https://stackoverflow.com/questions/60957452