首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Caret和rpart定义方法

Caret和rpart定义方法
EN

Stack Overflow用户
提问于 2015-05-21 02:04:54
回答 2查看 2.2K关注 0票数 0

我正在试着熟悉一下插入符号包。我以前会直接使用rpart -例如,使用以下语法

代码语言:javascript
复制
fit_rpart=rpart(y~.,data=dt1,method="anova"). 

我指定了anova,因为我的目标是回归(而不是分类)

使用插入符号-I将使用以下语法:

代码语言:javascript
复制
rpart_fit <- train(y ~ ., data = dt1, method = "rpart",trControl=fitControl)

我的问题是,由于方法槽已经被使用,我在哪里/如何仍然可以指定method="anova"?

首先要感谢大家!

EN

回答 2

Stack Overflow用户

发布于 2015-05-27 19:04:52

您可以使用当前的rpart代码创建custom method。首先,获取当前代码:

代码语言:javascript
复制
library(caret)
rpart_code <- getModelInfo("rpart", regex = FALSE)[[1]]

然后,您只需在代码中添加额外的选项。这个方法有点复杂,因为它处理了一堆不同的情况,但这里是编辑:

代码语言:javascript
复制
rpart_code$fit <- function(x, y, wts, param, lev, last, classProbs, ...) { 
  cpValue <- if(!last) param$cp else 0
  theDots <- list(...)
  if(any(names(theDots) == "control")) {
    theDots$control$cp <- cpValue
    theDots$control$xval <- 0 
    ctl <- theDots$control
    theDots$control <- NULL
  } else ctl <- rpart.control(cp = cpValue, xval = 0)   

  ## check to see if weights were passed in (and availible)
  if(!is.null(wts)) theDots$weights <- wts    

  modelArgs <- c(list(formula = as.formula(".outcome ~ ."),
                      data = if(is.data.frame(x)) x else as.data.frame(x),
                      control = ctl,
                      method = "anova"),
                 theDots)
  modelArgs$data$.outcome <- y

  out <- do.call("rpart", modelArgs)

  if(last) out <- prune.rpart(out, cp = param$cp)
  out           
}

然后测试:

代码语言:javascript
复制
library(rpart)
set.seed(445)
mod <- train(pgstat ~ age + eet + g2 + grade + gleason + ploidy, 
             data = stagec,
             method = rpart_code,
             tuneLength = 8)

最大值

票数 1
EN

Stack Overflow用户

发布于 2015-05-27 19:50:30

在插入符号'method‘指的是你想要使用的模型类型,例如rpart或lm (线性回归)或rf (随机森林)。

您所指的内容在插入符号中定义为“指标”。如果您的y变量是连续变量,则度量将默认设置为最大化RMSE。所以你不需要做任何事。

您也可以通过以下方式显式指定:

代码语言:javascript
复制
rpart_fit <- train(y ~ ., data = dt1, method = "rpart",trControl=fitControl, metric="RMSE")
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/30357212

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档