下面是使用tidymodels模型创建lightgbm工作流的代码。但是,当我试图保存到.rds对象并预测
library(AmesHousing)
library(treesnip)
library(lightgbm)
library(tidymodels)
tidymodels_prefer()
### Model ###
# data
data <- make_ames() %>%
janitor::clean_names()
data <- subset(data, select = c(sale_price, bedroom_abv_gr, bsmt_full_bath, bsmt_half_bath, enclosed_porch, fireplaces,
full_bath, half_bath, kitchen_abv_gr, garage_area, garage_cars, gr_liv_area, lot_area,
lot_frontage, year_built, year_remod_add, year_sold))
data$id <- c(1:nrow(data))
data <- data %>%
mutate(id = as.character(id)) %>%
select(id, everything())
# model specification
lgbm_model <- boost_tree(
mtry = 7,
trees = 347,
min_n = 10,
tree_depth = 12,
learn_rate = 0.0106430579211173,
loss_reduction = 0.000337948798058139,
) %>%
set_mode("regression") %>%
set_engine("lightgbm", objective = "regression")
# recipe and workflow
lgbm_recipe <- recipe(sale_price ~., data = data) %>%
update_role(id, new_role = "ID") %>%
step_corr(all_predictors(), threshold = 0.7) %>%
prep()
lgbm_workflow <- workflow() %>%
add_recipe(lgbm_recipe) %>%
add_model(lgbm_model)
# fit workflow
fit_lgbm_workflow <- lgbm_workflow %>%
fit(data = data)
# predict
data_predict <- subset(data, select = -c(sale_price))
predict(fit_lgbm_workflow, new_data = data_predict)
### CASE 1: Save the workflow with SaveRDS()
saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")
# Predict - error: Attempting to use a Booster which no longer exists
predict(new_lgbm_workflow, new_data = data_predict)
### CASE 2: Save the workflow and the fitted model separately
fitted_model <- (fit_lgbm_workflow %>% extract_fit_parsnip())$fit
saveRDS(object = fit_lgbm_workflow, file = "lgbm_workflow.rds")
lightgbm::saveRDS.lgb.Booster(object = fitted_model, file = "lgbm_model.rds")
new_lgbm_workflow <- readRDS(file = "lgbm_workflow.rds")
new_lgbm_model <- lightgbm::readRDS.lgb.Booster(file = "lgbm_model.rds")
new_lgbm_workflow$fit$fit <- new_lgbm_model
# Predict - error: cannot predict on data of class ‘tbl_df’‘tbl’‘data.frame’
predict(new_lgbm_workflow, new_data = data_predict)只有使用lightgbm模型的工作流似乎存在此问题。对于其他类型的模型(随机森林、xgboost、glm等),我可以使用saveRDS()保存已安装的工作流,用readRDS()读取并使用新的数据进行预测。
对于案例2,显然底层的预测函数将更改为predict.lgb.Booster(),后者以matrix作为输入。但是我的id变量是character格式,而matrix中的所有列都必须具有相同的格式。
有没有办法将整个workflow保存起来供将来使用?
发布于 2022-08-24 17:47:00
我想出了一种解决方案来保存lightgbm以供将来参考。它不使用tidymodel框架,而是被迫首先将其转换为lightgbm模型格式。如果您想要评估变量的重要性,情况也是如此。
根据上述守则:
# Convert to lightgbm booster model
lgb_model <- parsnip::extract_fit_engine(fit_lgbm_workflow)
# If you want you can now evaluate variable importance.
# Tidymodels does not support variable importance of lgb via bonsai currently
loss_varimp <- lgb_model %>%
lgb.importance(.)
# Save the booster out
lightgbm::lgb.save(lgb_model, filename_x)
# Read the booster in
lightgbm::lgb.load(filename_x)我还没有弄清楚是否可以将加载的lightgbm合并回tidymodel格式,但现在您至少可以预测、使用和评估,而不必每次都重新运行模型。希望这有帮助,如果你找到一个更清洁/更当前的解决方案,请张贴!
https://stackoverflow.com/questions/72027360
复制相似问题