
Is there a way to get the Gini index value for each node in an rpart model?

Data Science user
Asked 2020-07-07 11:45:11
1 answer · 994 views · 0 followers · 4 votes
library(tibble)
library(rpart)

df <- tibble(x=factor(c("A", "B")), y=factor(c(1, 0)))
model <- rpart(formula=y~., data=df, method="class", control=rpart.control(minsplit=2))

The model here will have one parent node and two child nodes. How can I get the Gini index values of these nodes from the rpart model object?


1 Answer

Data Science user

Answered 2022-02-16 13:58:47

The following code should compute the Gini index for every node of an rpart classification tree with any number of classes:

library(rpart)

gini <- function(tree) {
  # Calculate the Gini index for every node of an `rpart` classification tree.
  ylevels <- attributes(tree)[["ylevels"]]
  nclass <- length(ylevels)
  # yval2 columns: predicted class index, nclass counts,
  # nclass probabilities, node probability
  yval2 <- tree[["frame"]][["yval2"]]
  vars <- tree[["frame"]][["var"]]
  labls <- labels(tree)
  df <- data.frame(matrix(nrow = length(labls), ncol = 5))
  colnames(df) <- c("Name", "GiniIndex", "Class", "Items", "ItemProbs")

  for (i in seq_along(vars)) {
    row <- yval2[i, ]
    node.class <- ylevels[row[1]]  # map the class index to its level label
    j <- 2
    node.class_counts <- row[j:(j + nclass - 1)]
    j <- j + nclass
    node.class_probs <- row[j:(j + nclass - 1)]

    # Gini index: 1 minus the sum of squared class probabilities
    node.gini <- round(1 - sum(node.class_probs^2), 5)
    name <- paste(vars[i], " (", labls[i], ")")
    df[i, ] <- c(name, node.gini, node.class,
                 toString(round(node.class_counts, 5)),
                 toString(round(node.class_probs, 5)))
  }
  return(df)
}


> df <- data.frame(x=factor(c("A", "B", "C", "C", "D")), y=factor(c(1, 2, 3, 3, 4)))
> model <- rpart(formula=y~., data=df, method="class", control=rpart.control(minsplit=2))
> gini(model)
             Name GiniIndex Class      Items                    ItemProbs
1     x  ( root )      0.72     3 1, 1, 2, 1           0.2, 0.2, 0.4, 0.2
2    x  ( x=abd )   0.66667     1 1, 1, 0, 1 0.33333, 0.33333, 0, 0.33333
3 <leaf>  ( x=a )         0     1 1, 0, 0, 0                   1, 0, 0, 0
4     x  ( x=bd )       0.5     2 0, 1, 0, 1               0, 0.5, 0, 0.5
5 <leaf>  ( x=b )         0     2 0, 1, 0, 0                   0, 1, 0, 0
6 <leaf>  ( x=d )         0     4 0, 0, 0, 1                   0, 0, 0, 1
7 <leaf>  ( x=c )         0     3 0, 0, 2, 0                   0, 0, 1, 0


# don't know how to publish plots on StackExchange:
# rpart.plot(model, extra=104, box.palette="Blues", fallen.leaves=FALSE)
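
As a sanity check on the table above, the same values can be recomputed directly from the class labels, without going through the rpart object. This is only a minimal sketch in base R: `gini_impurity` is a hypothetical helper name (not part of rpart), and the node memberships are read off the split labels in the output (`x=abd`, `x=bd`).

```r
# Hypothetical helper (not part of rpart): Gini index of a vector of
# class labels, computed as 1 - sum(p_k^2) over the class proportions p_k.
gini_impurity <- function(labels) {
  p <- table(labels) / length(labels)  # class proportions in this node
  1 - sum(p^2)
}

# Same data as in the example session above
x <- c("A", "B", "C", "C", "D")
y <- factor(c(1, 2, 3, 3, 4))

root_gini <- gini_impurity(y)                           # root: all five rows
abd_gini  <- gini_impurity(y[x %in% c("A", "B", "D")])  # node x=abd
bd_gini   <- gini_impurity(y[x %in% c("B", "D")])       # node x=bd

print(round(c(root = root_gini, abd = abd_gini, bd = bd_gini), 5))
```

These should agree with the GiniIndex column for the root, `x=abd` and `x=bd` rows (0.72, 0.66667 and 0.5). Pure leaves contain a single class, so their index is 0.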
Votes: 0
Original page content provided by Data Science (Stack Exchange).
Original link:

https://datascience.stackexchange.com/questions/77302
