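I'm working on a multiclass classification problem and trying to use generalized boosted models (the gbm package in R). The problem I'm facing: caret's train function with method="gbm" doesn't seem to handle multiclass data properly. A simple example is shown below.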
library(gbm)
library(caret)
data(iris)
fitControl <- trainControl(method="repeatedcv",
number=5,
repeats=1,
verboseIter=TRUE)
set.seed(825)
gbmFit <- train(Species ~ ., data=iris,
method="gbm",
trControl=fitControl,
verbose=FALSE)
gbmFit
The output is
+ Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150
predictions failed for Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150
- Fold1.Rep1: interaction.depth=1, shrinkage=0.1, n.trees=150
+ Fold1.Rep1: interaction.depth=2, shrinkage=0.1, n.trees=150
...
+ Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150
predictions failed for Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150
- Fold5.Rep1: interaction.depth=3, shrinkage=0.1, n.trees=150
Aggregating results
Selecting tuning parameters
Fitting interaction.depth = numeric(0), n.trees = numeric(0), shrinkage = numeric(0) on full training set
Error in if (interaction.depth < 1) { : argument is of length zero
However, if I use gbm without the caret wrapper, I get good results.
set.seed(1365)
train <- createDataPartition(iris$Species, p=0.7, list=F)
train.iris <- iris[train,]
valid.iris <- iris[-train,]
gbm.fit.iris <- gbm(Species ~ ., data=train.iris, n.trees=200, verbose=FALSE)
gbm.pred <- predict(gbm.fit.iris, valid.iris, n.trees=200, type="response")
gbm.pred <- as.factor(colnames(gbm.pred)[max.col(gbm.pred)]) ##!
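confusionMatrix(gbm.pred, valid.iris$Species)$overall
FYI, the line of code marked with ##! converts the matrix of class probabilities returned by predict.gbm into a factor of the most probable classes. The output is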
      Accuracy          Kappa  AccuracyLower  AccuracyUpper   AccuracyNull AccuracyPValue  McnemarPValue
  9.111111e-01   8.666667e-01   7.877883e-01   9.752470e-01   3.333333e-01   8.467252e-16            NaN
Any suggestions on how to make caret work properly with gbm on multiclass data?
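The conversion done by the line marked ##! can be tried in isolation. Below is a minimal sketch in base R, assuming a plain two-dimensional probability matrix with class names as column names; the matrix `probs` and its values are made up for illustration and are not output from gbm.

```r
# Toy class-probability matrix: one row per observation, one column per class.
# The numbers are invented for illustration only.
probs <- matrix(c(0.80, 0.15, 0.05,
                  0.10, 0.70, 0.20,
                  0.05, 0.25, 0.70),
                nrow = 3, byrow = TRUE,
                dimnames = list(NULL, c("setosa", "versicolor", "virginica")))

# max.col() returns, for each row, the index of the column holding the largest value.
# ties.method = "first" makes ties deterministic (the default breaks them at random).
best <- max.col(probs, ties.method = "first")

# Map the column indices back to class names, keeping every class as a level.
pred <- factor(colnames(probs)[best], levels = colnames(probs))
print(pred)
```

Setting `levels` explicitly is a small departure from the `as.factor()` call in the question: it keeps a class in the factor even if it is never predicted, which should make the prediction easier to compare against the reference factor.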
UPD:
sessionInfo()
R version 2.15.3 (2013-03-01)
Platform: x86_64-pc-linux-gnu (64-bit)
locale:
[1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C LC_TIME=en_US.UTF-8 LC_COLLATE=en_US.UTF-8
[5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8 LC_PAPER=C LC_NAME=C
[9] LC_ADDRESS=C LC_TELEPHONE=C LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
attached base packages:
[1] splines stats graphics grDevices utils datasets methods base
other attached packages:
[1] e1071_1.6-1 class_7.3-5 gbm_2.0-8 survival_2.36-14 caret_5.15-61 reshape2_1.2.2 plyr_1.8
[8] lattice_0.20-13 foreach_1.4.0 cluster_1.14.3 compare_0.2-3
loaded via a namespace (and not attached):
[1] codetools_0.2-8 compiler_2.15.3 grid_2.15.3 iterators_1.0.6 stringr_0.6.2 tools_2.15.3
Answered 2013-03-23 21:08:39
This is an issue I'm working on now.
It would help if you posted the result of sessionInfo().
Also, getting the latest gbm from https://code.google.com/p/gradientboostedmodels/ might solve the problem.
Max
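Answered 2014-08-10 06:15:32
Update: caret can do multiclass classification.
You should make sure the class labels are in alphanumeric format (starting with a letter).
For example: if your data has the labels "1", "2", "3", change them to "Seg1", "Seg2", and "Seg3"; otherwise caret will fail.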
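The relabeling described above could be sketched as follows in base R. The vector `y` is a hypothetical stand-in for a real response column, and the `make.names()` check rests on the assumption that what matters is each level being a syntactically valid R name; the answer itself only says the labels should start with a letter.

```r
# Hypothetical numeric-looking class labels, as in the "1", "2", "3" example.
y <- factor(c("1", "2", "3", "2", "1"))

# Prefix each level with letters so that every label starts with a letter.
# "Seg" is simply the prefix used in the answer; any leading letter would do.
levels(y) <- paste0("Seg", levels(y))

# Sanity check (assumption: valid R names are what is required):
# make.names() leaves syntactically valid names unchanged.
stopifnot(identical(levels(y), make.names(levels(y))))
print(levels(y))
```

Renaming the levels in place keeps the observations and their order untouched, so the relabeled factor can be dropped back into the data frame before calling train.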
Answered 2015-01-02 08:07:04
Update: the original code does run and produces the following output
+ Fold1.Rep1: shrinkage=0.1, interaction.depth=1, n.trees=150
- Fold1.Rep1: shrinkage=0.1, interaction.depth=1, n.trees=150
...
...
...
+ Fold5.Rep1: shrinkage=0.1, interaction.depth=3, n.trees=150
- Fold5.Rep1: shrinkage=0.1, interaction.depth=3, n.trees=150
Aggregating results
Selecting tuning parameters
Fitting n.trees = 50, interaction.depth = 2, shrinkage = 0.1 on full training set
> gbmFit
Stochastic Gradient Boosting
150 samples
4 predictor
3 classes: 'setosa', 'versicolor', 'virginica'
No pre-processing
Resampling: Cross-Validated (5 fold, repeated 1 times)
Summary of sample sizes: 120, 120, 120, 120, 120
Resampling results across tuning parameters:
  interaction.depth  n.trees  Accuracy   Kappa  Accuracy SD  Kappa SD
  1                   50      0.9400000  0.91   0.04346135   0.06519202
  1                  100      0.9400000  0.91   0.03651484   0.05477226
  1                  150      0.9333333  0.90   0.03333333   0.05000000
  2                   50      0.9533333  0.93   0.04472136   0.06708204
  2                  100      0.9533333  0.93   0.05055250   0.07582875
  2                  150      0.9466667  0.92   0.04472136   0.06708204
  3                   50      0.9333333  0.90   0.03333333   0.05000000
  3                  100      0.9466667  0.92   0.04472136   0.06708204
  3                  150      0.9400000  0.91   0.03651484   0.05477226
Tuning parameter 'shrinkage' was held constant at a value of 0.1
Accuracy was used to select the optimal model using the
largest value.
The final values used for the model were n.trees =
50, interaction.depth = 2 and shrinkage = 0.1.
> summary(gbmFit)
var rel.inf
Petal.Length Petal.Length 74.1266408
Petal.Width Petal.Width 22.0668983
Sepal.Width Sepal.Width 3.2209288
Sepal.Length Sepal.Length 0.5855321
https://stackoverflow.com/questions/15585501