首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >仅用插入符号为最终模型指定选项

仅用插入符号为最终模型指定选项
EN

Stack Overflow用户
提问于 2018-11-20 12:31:15
回答 1查看 958关注 0票数 1

上下文

我正在使用caret来拟合和优化模型。通常,最佳参数是使用交叉验证等重采样方法找到的。一旦选择了最优的参数,就会使用最优的参数集对整个训练数据进行拟合。

除了要调优的参数(通过tuneGrid传递)外,还可以将参数传递给正在调用的底层算法,方法是将它们传递给train

我的问题

是否有任何方法可以指定特定于模型的选项,仅用于最终模型?

为了更清晰起见:我确实想对所有中间模型进行拟合(以获得可靠的性能评估),但我想用不同的参数(除了最佳参数)对最终模型进行拟合。

具体用例

假设我想在一些数据中安装一个bartMachine,然后在生产中使用最后的模型。我通常会将调优模型保存到磁盘,并根据需要加载它。但是,我只能保存/加载已序列化的bartMachine模型,即需要通过caret::trainserialize=T传递给bartMachine

但这将序列化所有的模型,这是非常不切实际的。我只需要序列化最终的模型。有什么办法吗?

代码语言:javascript
复制
library("caret")
library("bartMachine")
tgrid <- expand.grid(num_trees = 100,
                       k = c(2, 3),
                       alpha = 0.95, 
                       beta = 2,
                       nu =  3)
# The printed log shows that all intermediate models are being serialized
fit <- train(hp ~ ., 
             data=mtcars, 
             method="bartMachine",
             serialize=T,
             tuneGrid=tgrid,
             trControl = trainControl(method="cv", 5, verboseIter=T))
EN

回答 1

Stack Overflow用户

发布于 2018-11-20 12:54:06

若要在不进行参数调整或重采样的情况下将模型与整个数据集相匹配,请将列车控制方法修改为“无”:

代码语言:javascript
复制
tgrid <- expand.grid(num_trees = 100,
                     k = 2,
                     alpha = 0.95, 
                     beta = 2,
                     nu =  3)
fit <- train(hp ~ ., 
             data=mtcars, 
             method="bartMachine",
             serialize=TRUE,
             tuneGrid=tgrid,
             trControl = trainControl(method="none"))

注意,我已经删除了问题代码中的两个k值中的一个。否则就会出现一个错误:Only one model should be specified in tuneGrid with no resampling。我建议用另一个k值建立一个单独的模型。

上面的代码提供了以下输出:

代码语言:javascript
复制
bartMachine initializing with 100 trees...
bartMachine vars checked...
bartMachine java init...
bartMachine factors created...
bartMachine before preprocess...
bartMachine after preprocess... 11 total features...
bartMachine sigsq estimated...
bartMachine training data finalized...
Now building bartMachine for regression ...
building BART with mem-cache speedup...
Iteration 100/1250  mem: 17.6/477.1MB
Iteration 200/1250  mem: 25.1/477.1MB
Iteration 300/1250  mem: 30.8/477.1MB
Iteration 400/1250  mem: 39.9/477.1MB
Iteration 500/1250  mem: 19/477.1MB
Iteration 600/1250  mem: 59.6/477.1MB
Iteration 700/1250  mem: 39.6/477.1MB
Iteration 800/1250  mem: 79.8/477.1MB
Iteration 900/1250  mem: 119.9/477.1MB
Iteration 1000/1250  mem: 40.7/477.1MB
Iteration 1100/1250  mem: 80.8/477.1MB
Iteration 1200/1250  mem: 121/477.1MB
done building BART in 1.289 sec 

burning and aggregating chains from all threads... done
evaluating in sample data...done
serializing in order to be saved for future R sessions...done

fit$finalModel中,序列化参数设置为TRUE。

代码语言:javascript
复制
fit$finalModel$serialize
[1] TRUE

值得注意的是,bartMachine内部check_serialization函数不提供任何警告或错误(或任何其他输出):

代码语言:javascript
复制
bartMachine:::check_serialization(fit$finalModel)

我不清楚如何从fit$finalModel中提取序列化对象。我假设它存储在fit$finalModel$java_bart_machine中,其中包含一个rJava指针。使用rJava包( bartMachine所依赖的)可能会获得进一步的洞察力。

更新:@antoine在下面的注释中声明:"serialize=T不会导致模型被保存,而是将样本序列化到模型中,这意味着当模型被写入磁盘时会保存它们“。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/53393050

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档