首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何正确保存mlr3 lightgbm模型?

如何正确保存mlr3 lightgbm模型?
EN

Stack Overflow用户
提问于 2021-11-01 07:20:24
回答 1查看 69关注 0票数 2

我有一些下面的代码。保存训练好的模型时遇到错误。只有当我使用lightgbm时才会出错。

代码语言:javascript
复制
library(mlr3)
library(mlr3pipelines)
library(mlr3extralearners)

data = tsk("german_credit")$data()
data = data[, c("credit_risk", "amount", "purpose", "age")]
task = TaskClassif$new("boston", backend = data, target = "credit_risk")

g = po("imputemedian") %>>%
  po("imputeoor") %>>%
  po("fixfactors") %>>%
  po("encodeimpact") %>>% 
  lrn("classif.lightgbm")

gl = GraphLearner$new(g)

gl$train(task)

# predict 
newdata <- data[1,]
gl$predict_newdata(newdata) 
saveRDS(gl, "gl.rds")
代码语言:javascript
复制
# read model from disk ----------------
gl <- readRDS("gl.rds")
newdata <- data[1,]

# error when predict ------------------
gl$predict_newdata(newdata)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-11-01 08:17:06

lightgbmsaveread模型使用特殊函数。您必须在保存之前提取模型,并在加载后将其添加到图形学习器中。然而,这对于基准测试来说可能并不实用。我们会对此进行调查。

代码语言:javascript
复制
library(mlr3)
library(mlr3pipelines)
library(mlr3extralearners)
library(lightgbm)

data = tsk("german_credit")$data()
data = data[, c("credit_risk", "amount", "purpose", "age")]
task = TaskClassif$new("boston", backend = data, target = "credit_risk")

g = po("imputemedian") %>>%
  po("imputeoor") %>>%
  po("fixfactors") %>>%
  po("encodeimpact") %>>% 
  lrn("classif.lightgbm")

gl = GraphLearner$new(g)

gl$train(task)

# save model
saveRDS.lgb.Booster(gl$model$classif.lightgbm$model, "model.rda")

# save graph learner
saveRDS(gl, "gl.rda")

# load model
model = readRDS.lgb.Booster("model.rda")

# load graph learner
gl = readRDS("gl.rda")

# add model to graph learner
gl$state$model$classif.lightgbm$model = model

# predict
newdata <- data[1,]
gl$predict_newdata(newdata)
票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69793716

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档