首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何保存Tidymodels模型以供重用

如何保存Tidymodels模型以供重用
EN

Stack Overflow用户
提问于 2022-04-27 10:40:03
回答 1查看 573关注 0票数 3

下面是使用tidymodels模型创建lightgbm工作流的代码。但是,当我试图保存到.rds对象并预测

代码语言:javascript
复制
library(AmesHousing)
library(treesnip)
library(lightgbm)
library(tidymodels)
tidymodels_prefer()

### Model ###

# data
data <- make_ames() %>%
  janitor::clean_names()

data <- subset(data, select = c(sale_price, bedroom_abv_gr, bsmt_full_bath, bsmt_half_bath, enclosed_porch, fireplaces,
                                full_bath, half_bath, kitchen_abv_gr, garage_area, garage_cars, gr_liv_area, lot_area,
                                lot_frontage, year_built, year_remod_add, year_sold))

data$id <- c(1:nrow(data))

data <- data %>%
  mutate(id = as.character(id)) %>%
  select(id, everything())

# model specification

lgbm_model <- boost_tree(
  mtry = 7,
  trees = 347,
  min_n = 10,
  tree_depth = 12,
  learn_rate = 0.0106430579211173,
  loss_reduction = 0.000337948798058139,
) %>%
  set_mode("regression") %>%
  set_engine("lightgbm", objective = "regression")

# recipe and workflow

lgbm_recipe <- recipe(sale_price ~., data = data) %>%
  update_role(id, new_role = "ID") %>%
  step_corr(all_predictors(), threshold = 0.7) %>%
  prep()

lgbm_workflow <- workflow() %>% 
  add_recipe(lgbm_recipe) %>%
  add_model(lgbm_model)  
  
# fit workflow

fit_lgbm_workflow <- lgbm_workflow %>%
  fit(data = data)

# predict

data_predict <- subset(data, select = -c(sale_price))
predict(fit_lgbm_workflow, new_data = data_predict)


### CASE 1: Save the workflow with SaveRDS()

saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")

# Predict - error: Attempting to use a Booster which no longer exists

predict(new_lgbm_workflow, new_data = data_predict)



### CASE 2: Save the workflow and the fitted model separately

fitted_model <- (fit_lgbm_workflow %>% extract_fit_parsnip())$fit
saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
lightgbm::saveRDS.lgb.Booster(object = fitted_model, file = "lgbm_model.rds")


new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")
new_lgbm_model <- lightgbm::readRDS.lgb.Booster(file = "lgbm_model.rds")
new_lgbm_workflow$fit$fit <- new_lgbm_model


# Predict - error: cannot predict on data of class ‘tbl_df’‘tbl’‘data.frame’

predict(new_lgbm_workflow, new_data = data_predict)

只有使用lightgbm模型的工作流似乎存在此问题。对于其他类型的模型(随机森林、xgboost、glm等),我可以使用saveRDS()保存已安装的工作流,用readRDS()读取并使用新的数据进行预测。

对于案例2,显然底层的预测函数将更改为predict.lgb.Booster(),后者以matrix作为输入。但是我的id变量是character格式,而matrix中的所有列都必须具有相同的格式。

有没有办法将整个workflow保存起来供将来使用?

EN

回答 1

Stack Overflow用户

发布于 2022-08-24 17:47:00

我想出了一种解决方案来保存lightgbm以供将来参考。它不使用tidymodel框架,而是被迫首先将其转换为lightgbm模型格式。如果您想要评估变量的重要性,情况也是如此。

根据上述守则:

代码语言:javascript
复制
# Convert to lightgbm booster model
lgb_model <- parsnip::extract_fit_engine(fit_lgbm_workflow) 

# If you want you can now evaluate variable importance. 
# Tidymodels does not support variable importance of lgb via bonsai currently

loss_varimp <- lgb_model %>%
    lgb.importance(.) 

# Save the booster out
lightgbm::lgb.save(lgb_model, filename_x)

# Read the booster in
lightgbm::lgb.load(filename_x)

我还没有弄清楚是否可以将加载的lightgbm合并回tidymodel格式,但现在您至少可以预测、使用和评估,而不必每次都重新运行模型。希望这有帮助,如果你找到一个更清洁/更当前的解决方案,请张贴!

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72027360

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档