How to determine permutation variable importance from a SuperLearner model?

Stack Overflow user
Asked 2020-09-03 21:58:28
1 answer · 258 views · 0 followers · 0 votes

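My overall goal is to determine the variable importance of a Super Learner fitted on the Boston dataset. However, when I try to compute variable importance with the vip package in R, I get the error below. I suspect the prediction wrapper around the SuperLearner object is causing it, but I'm not sure.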

```r
# Call:
# SuperLearner(Y = y_train, X = x_train, family = binomial(),
#     SL.library = c("SL.mean", "SL.glmnet", "SL.ranger"), method = "method.AUC")
#
#                    Risk      Coef
# SL.mean_All   0.55622189 0.3333333
# SL.glmnet_All 0.06240630 0.3333333
# SL.ranger_All 0.02745502 0.3333333
#
# Error in mean(actual == predicted, na.rm = FALSE): (list) object cannot be coerced to type 'double'
# Traceback:
#
# 1. vi_permute(object = sl, method = "permute", feature_names = colnames,
#  .     train = x_train, target = y_holdout, metric = "accuracy",
#  .     type = "difference", nsim = 1, pred_wrapper = pred_wrapper)
# 2. vi_permute.default(object = sl, method = "permute", feature_names = colnames,
#  .     train = x_train, target = y_holdout, metric = "accuracy",
#  .     type = "difference", nsim = 1, pred_wrapper = pred_wrapper)
# 3. mfun(actual = train_y, predicted = pred_wrapper(object, newdata = train_x))
# 4. mean(actual == predicted, na.rm = FALSE)
```

Here is what I ran:

```r
library(MASS)
data(Boston, package = "MASS")

# Extract our outcome variable from the dataframe.
outcome = Boston$medv

# Create a dataframe to contain our explanatory variables.
data = subset(Boston, select = -medv)

set.seed(1)
# Reduce to a dataset of 150 observations to speed up model fitting.
train_obs = sample(nrow(data), 150)

# X is our training sample.
x_train = data[train_obs, ]

# Create a holdout set for evaluating model performance.
x_holdout = data[-train_obs, ]

# Create a binary outcome variable: towns in which median home value is > 22,000.
outcome_bin = as.numeric(outcome > 22)

y_train = outcome_bin[train_obs]
y_holdout = outcome_bin[-train_obs]

library(SuperLearner)
set.seed(1)
sl = SuperLearner(Y = y_train, X = x_train, family = binomial(),
  SL.library = c("SL.mean", "SL.glmnet", "SL.ranger"), method = "method.AUC")
sl

colnames <- colnames(x_train)
pred_wrapper <- function(sl, newdata) {
  predict(sl, x = as.matrix(y_holdout)) %>%
    as.vector()
}

# Plot VI scores
library(vip)
p1 <- vi_permute(object = sl, method = "permute", feature_names = colnames,
                 train = x_train,
                 target = y_holdout,
                 metric = "accuracy",
                 type = "difference",
                 nsim = 1,
                 pred_wrapper = pred_wrapper)
```
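Before handing a wrapper to `vi_permute()`, it may help to check that it returns a plain (atomic) vector with one prediction per row of `train`, and that `target` has the same length as `train`. The sketch below is a minimal base-R version of such a check; `check_vi_inputs()`, `bad_wrapper()`, `good_wrapper()` and the toy data are hypothetical illustrations, not part of vip or SuperLearner.

```r
# Hypothetical pre-flight check (not part of vip): verify that a prediction
# wrapper returns one atomic prediction per row of `train`, and that `target`
# lines up with `train`, before passing them to vi_permute().
check_vi_inputs <- function(pred_wrapper, object, train, target) {
  pred <- pred_wrapper(object, newdata = train)
  if (!is.atomic(pred)) {
    stop("pred_wrapper() must return an atomic vector, not a ", class(pred)[1])
  }
  if (length(pred) != nrow(train)) {
    stop("pred_wrapper() returned ", length(pred),
         " predictions for ", nrow(train), " rows of train")
  }
  if (length(target) != nrow(train)) {
    stop("target has ", length(target), " values but train has ",
         nrow(train), " rows")
  }
  invisible(TRUE)
}

# Toy stand-ins (hypothetical): a 4-row data frame and a 0/1 target.
toy_train  <- data.frame(x = c(0.1, 0.7, 0.4, 0.9))
toy_target <- c(0, 1, 0, 1)

# A wrapper that returns a list (like predict() on a SuperLearner fit) is rejected.
bad_wrapper <- function(object, newdata) list(pred = matrix(newdata$x))
# A wrapper that returns hard 0/1 labels passes.
good_wrapper <- function(object, newdata) as.numeric(newdata$x > 0.5)

try(check_vi_inputs(bad_wrapper, NULL, toy_train, toy_target))
check_vi_inputs(good_wrapper, NULL, toy_train, toy_target)
```

In the setup above, a check like this would flag both the list returned by `predict(sl, ...)` and the pairing of the 150-row `x_train` with the 356 values of `y_holdout`.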

1 answer

Stack Overflow user

Accepted answer

Posted 2020-09-05 03:51:28

For a SuperLearner object, you can see that `predict()` returns a list of probabilities rather than a vector:

```r
predict(sl, x_train[1:2, ])
$pred
          [,1]
[1,] 0.4049966
[2,] 0.1905551

$library.predict
     SL.mean_All SL.glmnet_All SL.ranger_All
[1,]   0.3866667     0.5718232        0.2565
[2,]   0.3866667     0.1082986        0.0767
```
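Why a list breaks the accuracy metric can be reproduced in base R without fitting anything. The sketch assumes a hand-built `mock_pred` list that mimics the two-row output above and hypothetical labels `actual`; per the traceback, vip's accuracy check reduces to `mean(actual == predicted)`, which is what is evaluated here.

```r
# Mock of the list returned by predict() on a SuperLearner fit
# (values copied from the two-row output above; purely illustrative).
mock_pred <- list(
  pred = matrix(c(0.4049966, 0.1905551), ncol = 1),
  library.predict = matrix(
    c(0.3866667, 0.3866667, 0.5718232, 0.1082986, 0.2565, 0.0767),
    ncol = 3,
    dimnames = list(NULL, c("SL.mean_All", "SL.glmnet_All", "SL.ranger_All"))
  )
)

actual <- c(1, 0)  # hypothetical 0/1 labels, stored as double like y_holdout

# as.vector() leaves a list unchanged, so piping into it does not flatten anything
still_a_list <- as.vector(mock_pred)

# Comparing a double vector with a list of matrices fails with a coercion error
err <- tryCatch(mean(actual == still_a_list),
                error = function(e) conditionMessage(e))
print(err)

# Extracting $pred and thresholding gives an atomic 0/1 vector that compares cleanly
labels <- as.numeric(mock_pred$pred > 0.5)
mean(actual == labels)  # 0.5
```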

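From the help page (`?predict.SuperLearner`), I'd guess you want the predictions of the super learner itself, i.e. the `$pred` element. Note also that your wrapper ignores its `newdata` argument, and that `metric = "accuracy"` compares the predictions directly with the observed 0/1 labels (`mean(actual == predicted)` in the traceback), so the wrapper must return class labels, not probabilities. So change the function to extract the probabilities for `newdata` and convert them to labels: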

```r
pred_wrapper <- function(sl, newdata) {
  # Take the ensemble probabilities ($pred) for newdata and turn them into 0/1 labels
  ifelse(predict(sl, newdata)$pred > 0.5, 1, 0)
}
```

A quick check of the result on the holdout set:

```r
table(pred_wrapper(sl, x_holdout), y_holdout)
   y_holdout
      0   1
  0 183  39
  1   9 125
```
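The confusion matrix above also gives the baseline accuracy on the holdout set. Assuming vip scores `type = "difference"` as baseline minus permuted accuracy, each Importance value in the final table is roughly how far this figure drops when one column is shuffled. A small sketch of the arithmetic, with the counts copied from the table:

```r
# Confusion matrix copied from the table() output above
# (rows = predicted label, columns = observed y_holdout)
conf <- matrix(c(183, 9, 39, 125), nrow = 2,
               dimnames = list(predicted = c("0", "1"), y_holdout = c("0", "1")))

# Accuracy = correctly classified / total = (183 + 125) / 356
accuracy <- sum(diag(conf)) / sum(conf)
round(accuracy, 4)  # 0.8652
```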

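Then pass `x_holdout` as `train`, so that it has the same number of rows (356) as `target = y_holdout`; the question paired the 150-row `x_train` with `y_holdout`: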

```r
p1 <- vi_permute(object = sl, method = "permute", feature_names = colnames,
                 train = x_holdout,
                 target = y_holdout,
                 metric = "accuracy",
                 type = "difference",
                 nsim = 5,
                 pred_wrapper = pred_wrapper)
p1

# A tibble: 13 x 3
   Variable Importance   StDev
   <chr>         <dbl>   <dbl>
 1 crim       0.00337  0.00126
 2 zn        -0.000562 0.00235
 3 indus      0.00337  0.00235
 4 chas       0.00674  0.00377
 5 nox        0.00225  0.00235
 6 rm         0.0315   0.0165 
 7 age        0.0213   0.0108 
 8 dis        0.0129   0.00944
 9 rad       -0.00169  0.00377
10 tax        0.00506  0.00126
11 ptratio    0.0174   0.0145 
12 black     -0.00281  0      
13 lstat      0.241    0.0204
```
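To make the output easier to read, here is a from-scratch sketch of the idea behind it: shuffle one column at a time, recompute accuracy, and record the drop from the unshuffled baseline, repeated `nsim` times (Importance is the mean drop, StDev its spread). A negative value, as for `zn`, `rad` or `black`, suggests the variable carries little signal: shuffling it happened to improve accuracy slightly. This is an illustration in base R on hypothetical toy data with a `glm` fit, not vip's actual implementation; `permute_importance()` and `glm_wrapper()` are made-up names.

```r
# Hypothetical base-R sketch of permutation importance with an accuracy metric
# and "difference" scoring (baseline accuracy minus permuted accuracy).
permute_importance <- function(pred_wrapper, object, train, target,
                               nsim = 5, seed = 1) {
  set.seed(seed)  # note: resets the global RNG, acceptable for a sketch
  accuracy <- function(pred) mean(target == pred)
  baseline <- accuracy(pred_wrapper(object, newdata = train))
  scores <- sapply(names(train), function(feature) {
    drops <- replicate(nsim, {
      shuffled <- train
      # Break the link between this feature and the target
      shuffled[[feature]] <- sample(shuffled[[feature]])
      baseline - accuracy(pred_wrapper(object, newdata = shuffled))
    })
    c(Importance = mean(drops), StDev = sd(drops))
  })
  data.frame(Variable = colnames(scores), t(scores), row.names = NULL)
}

# Toy data (hypothetical): x1 drives the outcome, x2 is pure noise.
set.seed(42)
n <- 500
toy <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
y <- rbinom(n, 1, plogis(3 * toy$x1))

fit <- glm(y ~ x1 + x2, data = cbind(toy, y = y), family = binomial())

# Like the answer's wrapper: return hard 0/1 labels, not probabilities
glm_wrapper <- function(object, newdata) {
  as.numeric(predict(object, newdata = newdata, type = "response") > 0.5)
}

vi_toy <- permute_importance(glm_wrapper, fit, toy, y, nsim = 5)
vi_toy
```

On this toy data `x1` should show a large accuracy drop and `x2` a value near zero, mirroring `lstat` versus the near-zero rows in the table above.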
Votes: 0

Original page content provided by Stack Overflow; translation support provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/63725192
