我正在使用show_prediction包中的eli5函数来理解我的XGBoost分类器是如何得到预测的。由于某种原因,我似乎得到了一个回归得分,而不是我的模型的概率。
下面是一个具有公共数据集的完全可复制的示例。
from sklearn.datasets import load_breast_cancer
from xgboost import XGBClassifier
from sklearn.model_selection import train_test_split
from eli5 import show_prediction
# Load dataset
data = load_breast_cancer()
# Organize our data
label_names = data['target_names']
labels = data['target']
feature_names = data['feature_names']
features = data['data']
# Split the data
train, test, train_labels, test_labels = train_test_split(
features,
labels,
test_size=0.33,
random_state=42
)
# Define the model
xgb_model = XGBClassifier(
n_jobs=16,
eval_metric='auc'
)
# Train the model
xgb_model.fit(
train,
train_labels
)
show_prediction(xgb_model.get_booster(), test[0], show_feature_values=True, feature_names=feature_names)这给了我以下的结果。请注意3.7的分数,这绝对不是一个概率。

然而,官方的eli5 文档正确地显示了一个概率。

丢失的概率似乎与我使用xgb_model.get_booster()有关。看起来官方文档没有使用它,而是以-is传递模型,但是当我这样做时,我得到了TypeError: 'str' object is not callable,所以这似乎不是一种选择。
我还担心eli5没有通过遍历xgboost树来解释这个预测。看起来,我得到的“分数”实际上只是所有特性贡献的总和,就像我所期望的,如果eli5不是真正遍历树而是拟合一个线性模型的话。这是真的吗?我怎样才能让eli5穿越这棵树?
发布于 2018-12-14 17:31:50
解决了我自己的问题。根据这个吉特布问题的说法,eli5只支持较早版本的XGBoost (<=0.6)。我使用的是XGBoost版本0.80和eli5版本0.8。
发布该问题的解决方案:
import eli5
from xgboost import XGBClassifier, XGBRegressor
def _check_booster_args(xgb, is_regression=None):
# type: (Any, bool) -> Tuple[Booster, bool]
if isinstance(xgb, eli5.xgboost.Booster): # patch (from "xgb, Booster")
booster = xgb
else:
booster = xgb.get_booster() # patch (from "xgb.booster()" where `booster` is now a string)
_is_regression = isinstance(xgb, XGBRegressor)
if is_regression is not None and is_regression != _is_regression:
raise ValueError(
'Inconsistent is_regression={} passed. '
'You don\'t have to pass it when using scikit-learn API'
.format(is_regression))
is_regression = _is_regression
return booster, is_regression
eli5.xgboost._check_booster_args = _check_booster_args然后将我问题的最后一行代码片段替换为:
show_prediction(xgb_model, test[0], show_feature_values=True, feature_names=feature_names)解决了我的问题。
https://stackoverflow.com/questions/53783731
复制相似问题