首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >R中带有潮汐模型的catboost模型的总结形状图

R中带有潮汐模型的catboost模型的总结形状图
EN

Stack Overflow用户
提问于 2022-03-05 05:05:55
回答 2查看 476关注 0票数 2

我正试图在tidymodel框架内构建一个catboost模型。下面给出了最小可重现性的例子。我可以使用DALEXmodelStudio软件包来获得模型解释,但我想为这个catboost模型创建VIP情节、像这样和概要图形像这样。我试过像fastshapSHAPforxgboost这样的包,没有任何运气。我意识到,我必须从model对象中提取变量重要性和形状值,并使用它们生成这些图,但不知道如何做到这一点。有没有办法在R里完成这件事?

代码语言:javascript
复制
library(tidymodels)
library(treesnip)
library(catboost)
library(modelStudio)
library(DALEXtra)
library(DALEX)

data <- structure(list(Age = c(74, 60, 57, 53, 72, 72, 71, 77, 50, 66), StatusofNation0developed = structure(c(2L, 2L, 2L, 2L, 2L, 
                                                                                                       1L, 2L, 1L, 1L, 2L), .Label = c("0", "1"), class = "factor"), 
               treatment = structure(c(2L, 1L, 2L, 2L, 2L, 1L, 1L, 3L, 1L, 
                                       2L), .Label = c("0", "1", "2"), class = "factor"), InHospitalMortalityMortality = c(0, 
                                                                                                                           0, 1, 1, 1, 0, 0, 1, 1, 0)), row.names = c(NA, 10L), class = "data.frame")
split <- initial_split(data, strata = InHospitalMortalityMortality)
train <- training(split)
test <- testing(split)

train$InHospitalMortalityMortality <- as.factor(train$InHospitalMortalityMortality)

rec <- recipe(InHospitalMortalityMortality ~ ., data = train)

clf <- boost_tree() %>%
  set_engine("catboost") %>%
  set_mode("classification")

wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(clf)

model <- wflow %>% fit(data = train)

explainer <- explain_tidymodels(model,
                                data = test,
                                y = test$InHospitalMortalityMortality,
                                label = "catboost")
new_observation <- test[1:2,]
modelStudio(explainer, new_observation)
EN

回答 2

Stack Overflow用户

发布于 2022-04-20 20:17:51

上面的链接提供了一个答案,但它是不完整的。在这里,它是按照相同的工作流完成的。

如前所述:首先,安装R包{reticulate}和和{网状}。接下来,为python使用{网状}设置一个虚拟环境。在使用RStudio时,设置虚拟环境相对简单。请检查他们的参考资料一步一步的指示。

然后,pip在venv中安装{shap}和{ matplotlib } --注意,matplotlib 3.2.2对于摘要图似乎是必要的(更详细的信息请参见GitHub问题)。

工作流(来自treesnip文档):

代码语言:javascript
复制
library(tidymodels)
library(treesnip)

data("diamonds", package = "ggplot2")
diamonds <- diamonds %>% sample_n(1000)

#vfolds resamples 
diamond_splits <- vfold_cv(diamonds, v = 5)

model_spec <- boost_tree(mtry = 5, trees = 500) %>% set_mode("regression")

#model specifications
lightgbm_model <- model_spec %>% 
    set_engine("lightgbm", nthread = 4)

#workflow
lightgbm_workflow <- workflow() %>% 
    add_model(lightgbm_model)

rec_ordered <- recipe( 
    price ~ .
    ,data = diamonds
)

lightgbm_fit_ordered <- fit_resamples( 
    add_recipe(
        lightgbm_workflow, rec_ordered
    ), resamples = diamond_splits
)

配合工作流程:

代码语言:javascript
复制
fit_lightgbm_workflow <- lightgbm_workflow %>%
    add_recipe(rec_ordered) %>%
    fit(data = diamonds)

使用fit工作流,我们现在可以通过{reticulate}创建shap值,并用{reticulate}和{网状}绘图。

首先,力图:要做到这一点,我们需要为pred_wrapper参数创建一个预测函数。

代码语言:javascript
复制
predict_function_gbm <- function(model, newdata){
    predict(model, newdata) %>% pull(., 1) # 
}

现在我们需要基线参数的平均预测值。

代码语言:javascript
复制
mean_preds <- mean( 
    predict_function_gbm(
      fit_lightgbm_workflow, diamonds %>% select(-price)
      ) 
)

在这里,创建shap值:

代码语言:javascript
复制
fastshap::explain( 
  fit_lightgbm_workflow, 
  X = as.data.frame(diamonds %>% select(-price)), 
  pred_wrapper = predict_function_gbm, 
  nsim= 10
) -> gbm_explained

现在,关于作用力图:

代码语言:javascript
复制
fastshap::force_plot( 
  object = gbm_explained[1, ],
  feature_values = as.data.frame(diamonds %>% select(-price))[1, ],
  display = "viewer", # or "html" depending on rendering preference
  baseline = mean_preds
)

# For classification, add: link = "logit"
# For vertical stacking, change: [1, ] to [1:20, ] for example. 
# this may or may not throw error depending on version of shap used.
# see {fastshap} issues.

现在,对于摘要图:使用{网状}直接访问函数:

代码语言:javascript
复制
library(reticulate)
shap = import("shap")
np = import("numpy")

shap$summary_plot( 
  data.matrix(gbm_explained), 
  data.matrix(diamonds %>% select(-price))
)

例如,依赖情节也是如此。

代码语言:javascript
复制
shap$dependence_plot( 
  "rank(1)",
  data.matrix(gbm_explained), 
  data.matrix(diamonds %>% select(-price))
)

最后注意:重复渲染会导致错误的可视化。在dependence_plot中直接命名一个特性(即“剪切”)给我带来了一个错误。

票数 4
EN

Stack Overflow用户

发布于 2022-04-27 04:23:30

首先,我们需要从模型对象中提取工作流并使用它来预测测试集。(可选)使用catboost.load_pool函数创建池对象。

代码语言:javascript
复制
predict(model$.workflow[[1]], test[])
pool = catboost.load_pool(dataset, label = label_values, cat_features = NULL)

然后利用catboost.get_feature_importance函数得到模型对象的特征重要性评分。

代码语言:javascript
复制
catboost.get_feature_importance(extract_fit_engine(model),
                                pool = NULL,
                                type = 'FeatureImportance',
                                thread_count = -1)

然后,我们可以使用函数type = 'ShapValues'选项获得shapvalue。

代码语言:javascript
复制
shapvalue <- catboost.get_feature_importance(extract_fit_engine(model),
                                             pool = pool,
                                             type = 'ShapValues',
                                             thread_count = -1)
shapvalue <- data.frame(shapvalue)
shap_long_game <- shap.prep(shap_contrib = shapvalue, X_train = dataset)

最后绘制形状值

代码语言:javascript
复制
shap_summplot <- shap.plot.summary(shap_long_game, scientific = F) 
shap_summplot + 
  scale_y_continuous(labels = comma)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71359666

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档