首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何从插入符号::训练对象绘制最终的c50决策树模型(库C50)

如何从插入符号::训练对象绘制最终的c50决策树模型(库C50)
EN

Stack Overflow用户
提问于 2020-04-06 13:41:51
回答 1查看 1.2K关注 0票数 2

我使用来自插入符号库的train函数训练了决策树模型:

代码语言:javascript
复制
gr = expand.grid(trials = c(1, 10, 20), model = c("tree", "rules"), winnow = c(TRUE, FALSE))
dt = train(y ~ ., data = train, method = "C5.0", trControl = trainControl(method = 'cv', number = 10), tuneGrid = gr)

现在,我想为最终模型绘制决策树。但是这个命令不起作用:

代码语言:javascript
复制
plot(dt$finalModel)

Error in data.frame(eval(parse(text = paste(obj$call)[xspot])), eval(parse(text = paste(obj$call)[yspot])),  : 
  arguments imply differing number of rows: 4160, 208, 0

有人已经在这里问过了:主题

建议的解决方案是使用bestTune列车对象中手动定义相应的c5.0模型。然后绘制c5.0模型,通常如下:

代码语言:javascript
复制
c5model = C5.0(x = x, y = y, trials = dt$bestTune$trials, rules = dt$bestTune$model == "rules", control = C5.0Control(winnow = dt$bestTune$winnow))
plot(c5model)

我试过这样做。是的,它使得绘制c5.0模型成为可能,但是预测了训练对象和手动重新创建c5.0模型的概率不匹配。

因此,我的问题是:是否可以从caret::train对象中提取最终的c5.0模型并绘制此决策树

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-04-06 15:32:35

预测的概率应相同,见下文:

代码语言:javascript
复制
library(MASS)
library(caret)
library(C50)
library(partykit)

traindata = Pima.tr
testdata = Pima.te

gr = expand.grid(trials = c(1, 2), 
model = c("tree"), winnow = c(TRUE, FALSE))

dt = train(x = traindata[,-ncol(testdata)], y = traindata[,ncol(testdata)], 
method = "C5.0",trControl = trainControl(method = 'cv', number=3),tuneGrid=gr)

c5model = C5.0.default(x = traindata[,-ncol(testdata)], y = traindata[,ncol(testdata)], 
trials = dt$bestTune$trials, rules = dt$bestTune$model == "rules", 
control = C5.0Control(winnow = dt$bestTune$winnow))

all.equal(predict(c5model,testdata[,-ncol(testdata)],type="prob"),
predict(dt$finalModel,testdata[,-ncol(testdata)],type="prob"))
[1] TRUE

所以我建议你再检查一下预测是否相同。

从插入符号中看到的绘制最终模型的错误来自于存储在$call下的内容(这很奇怪),我们可以用一个调用来替换它,该调用可以用于绘图:

代码语言:javascript
复制
plot(c5model)

代码语言:javascript
复制
finalMod = dt$finalModel
finalMod$call = c5model$call
plot(finalMod)

或者你可以用你的训练结果重写它,但是你可以看到它的表达变得有点复杂(或者至少我对它不太擅长):

代码语言:javascript
复制
newcall = substitute(C5.0.default(x = X, y = Y, trials = ntrials, rules = RULES, control = C5.0Control(winnow = WINNOW)),
list(
X = quote(traindata[, -ncol(traindata)]),
Y = quote(traindata[, ncol(traindata)]),
RULES = dt$bestTune$model == "rules",
ntrials = dt$bestTune$trials,
WINNOW = dt$bestTune$winnow)
)

finalMod = dt$finalModel
finalMod$call = newcall
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61061218

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档