首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何创建自定义模型(使用插入符号中的循环/子模型技巧)

如何创建自定义模型(使用插入符号中的循环/子模型技巧)
EN

Stack Overflow用户
提问于 2018-08-27 05:48:43
回答 1查看 397关注 0票数 2

我与这个问题斗争了令人尴尬的很长时间。我觉得自己绝对是个笨蛋,因为答案可能是显而易见的,但我找不到一个线程来解释如何做到这一点。

关于自定义模型创建的文档部分对我来说就像this一样。我感觉在我的教育期间,我错过了一些非常具体的课程,现在每个人都记得了,除了我,因为我找到的只是“是的,只需创建一个自定义模型,然后完成”。

这里有一些实际问题:

我希望在caret中获得gbm的每一次迭代的预测结果。例如,在gbm中,我可以在predict(..., n.trees = 1:100)中使用n.trees,这就完成了。

caret中,显然我需要使用子模型技巧,这意味着-如果我理解正确的话-我必须创建自己的自定义模型。

但是我可以在getModelInfo('gbm')中看到,有某种循环函数!

代码语言:javascript
复制
$gbm$loop
function (grid) 
{
    loop <- plyr::ddply(grid, c("shrinkage", "interaction.depth", 
        "n.minobsinnode"), function(x) c(n.trees = max(x$n.trees)))
    submodels <- vector(mode = "list", length = nrow(loop))
    for (i in seq(along = loop$n.trees)) {
        index <- which(grid$interaction.depth == loop$interaction.depth[i] & 
            grid$shrinkage == loop$shrinkage[i] & grid$n.minobsinnode == 
            loop$n.minobsinnode[i])
        trees <- grid[index, "n.trees"]
        submodels[[i]] <- data.frame(n.trees = trees[trees != 
            loop$n.trees[i]])
    }
    list(loop = loop, submodels = submodels)

我该如何使用它?为什么它在默认情况下不工作?我真的需要创建自定义模型吗?或许不需要?

免责声明1:我不想使用任何交叉验证。我只想为单个gbm运行的每一次迭代提取预测。

免责声明2:我不想在$finalModel上使用predict.gbm(),因为我还想测试一些其他算法,这些算法也使用子模型技巧。我不想使用所有不同的特定于算法的predict()函数,因为那样我为什么还要费心使用插入符号。

我甚至不知道我应该把什么作为一个可复制的例子。代码没有问题。我只是不知道这东西是怎么工作的。

EN

回答 1

Stack Overflow用户

发布于 2018-08-27 17:56:14

下面是一个关于如何为每棵树的测试数据提取所需预测的示例:

代码语言:javascript
复制
library(caret)
library(mlbench) #for the data set
data(Sonar) #some data set I always use on stack overflow

res <- train(Class~.,
             data = Sonar,
             method = "gbm",
             trControl = trainControl(method = "cv", #some evaluations scheme
                                      number = 5,
                                      savePredictions = "all"), #tell caret you would like to save all,
             tuneGrid = expand.grid(shrinkage = 0.01,
                                    interaction.depth = 2, 
                                    n.minobsinnode = 10,
                                    n.trees = 1:100)) #some random values and all the trees

res$pred #results are stored in here

基本上,您在文章中显示的代码告诉脱字符不要调优所有的n.tree模型,而是只对每个超参数组合使用max(n.trees)调优一个模型,然后使用它来获得n.trees < max(n.trees)的预测

一些情节

代码语言:javascript
复制
library(ggplot2)

ggplot(res$results)+
  geom_line(aes(x = n.trees, y = Accuracy))

您也可以选择不使用savePredictions = "all",因为这会导致需要大量内存的序列对象。而是使用res$results,您可以在其中计算所有需要的指标。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52030499

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档