首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >生成NA结果的GBM模型

生成NA结果的GBM模型
EN

Stack Overflow用户
提问于 2014-01-18 07:41:15
回答 2查看 6K关注 0票数 4

我正在尝试运行一个简单的GBM分类模型,以针对随机森林和支持向量机对性能进行基准测试,但我在让模型正确评分时遇到了问题。它不会抛出错误,但预测都是NaN。我用的是mlbench的乳腺癌数据。代码如下:

代码语言:javascript
复制
library(gbm)
library(mlbench)
library(caret)
library(plyr)
library(ada)
library(randomForest)

data(BreastCancer)
bc <- BreastCancer
rm(BreastCancer)

bc$Id <- NULL
bc$Class <- as.factor(mapvalues(bc$Class, c("benign", "malignant"), c("0","1")))

index <- createDataPartition(bc$Class, p = 0.7, list = FALSE)
bc.train <- bc[index, ]
bc.test <- bc[-index, ]

model.gbm <- gbm(Class ~ ., data = bc.train, n.trees = 500)

pred.gbm <- predict(model.gbm, bc.test.ind, n.trees = 500, type = "response")

有没有人能帮我解决我做错了什么?另外,我是否必须转换预测函数的输出?我读到这似乎是GBM预测的一个问题。谢谢。

EN

回答 2

Stack Overflow用户

发布于 2014-11-12 23:49:41

我以前遇到过将因子变量赋给gbm的问题。您可以强制Class变量为字符类型,而不是因子类型,这应该可以做到这一点。

代码语言:javascript
复制
bc$Class <- as.factor(mapvalues(bc$Class, c("benign", "malignant"), c("0","1")))
bc$Class <- as.character(bc$Class)

您的代码应该可以很好地运行,只要确保在predict中调用bc.test (而不是bc.test.ind)即可。

以下是我在进行这些更改后获得的预测值的摘要

代码语言:javascript
复制
> summary(pred.gbm)
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max. 
  0.222   0.222   0.231   0.346   0.573   0.579 

最后一件事,我建议在调用createDataPartition()之前设置一个种子(例如使用set.seed())。否则,您每次运行代码时都会得到不同的训练和测试集。

票数 6
EN

Stack Overflow用户

发布于 2020-06-26 04:40:30

您可以只将标签转换为0和1,但首先存储标签以进行比较:

代码语言:javascript
复制
library(gbm)
library(mlbench)
library(caret)

data(BreastCancer)
bc <- BreastCancer

bc$Id <- NULL
# store the actual labels
labels = bc$Class
bc$Class <- as.numeric(bc$Class)-1
index <- createDataPartition(bc$Class, p = 0.7, list = FALSE)
bc.train <- bc[index, ]
bc.test <- bc[-index, ]

model.gbm <- gbm(Class ~ ., data = bc.train, n.trees = 500,distribution = "bernoulli")

pred.gbm <- predict(model.gbm, bc.test, n.trees = 500, type = "response")

由于只有两个类,我们可以通过调用标签的第一级if p <= 0.5来取回标签,反之亦然:

代码语言:javascript
复制
predicted_labels = levels(labels)[1+(pred.gbm>0.5)]

我们拿出实际的测试标签,并制作一个混淆矩阵,以查看它是否正常工作:

代码语言:javascript
复制
test_labels = labels[-index]

table(predicted_labels,test_labels)
                test_labels
predicted_labels benign malignant
       benign       129         2
       malignant      3        75
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/21198007

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档