首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在使用插入符号包运行的GBM中,我希望使用AUPRC作为性能度量。如何使用自定义度量(如auprc )?

在使用插入符号包运行的GBM中,我希望使用AUPRC作为性能度量。如何使用自定义度量(如auprc )?
EN

Stack Overflow用户
提问于 2018-01-02 01:05:39
回答 2查看 874关注 0票数 3

因为我有不平衡的分类器,所以我尝试使用AUPRC作为我的gbm模型适合的自定义度量。但是,当我试图合并自定义度量时,我将得到代码中提到的以下错误。不知道我做错了什么。

此外,当我内联地运行auprcSummary()时,它可以独立工作。当我试图将它合并到列车()中时,它给了我一个错误。

代码语言:javascript
复制
     library(dplyr) # for data manipulation
     library(caret) # for model-building
     library(pROC) # for AUC calculations
     library(PRROC) # for Precision-Recall curve calculations

    auprcSummary <- function(data, lev = NULL, model = NULL){
      index_class2 <- data$Class == "Class2"
      index_class1 <- data$Class == "Class1"
      the_curve <- pr.curve(data$Class[index_class2],
                    data$Class[index_class1],
                    curve = FALSE)
      out <- the_curve$auc.integral
      names(out) <- "AUPRC"
      out
      }

    ctrl <- trainControl(method = "repeatedcv",
                 number = 10,
                 repeats = 5,
                 summaryFunction = auprcSummary,
                 classProbs = TRUE)

    set.seed(5627)
    orig_fit <- train(Class ~ .,
              data = toanalyze.train,
              method = "gbm",
              verbose = FALSE,
              metric = "AUPRC",
              trControl = ctrl)

这是我正在犯的错误:

代码语言:javascript
复制
     Error in order(scores.class0) : argument 1 is not a vector  

是因为pr.curve()只接受数字向量作为输入(分数/概率)吗?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-01-02 08:22:01

我认为这种方法产生了一个适当的自定义摘要函数:

代码语言:javascript
复制
library(caret) 
library(pROC) 
library(PRROC)
library(mlbench) #for the data set

data(Ionosphere)

pr.curve函数中,可以为每个类的数据点分别提供分类分数,例如,作为来自正/前景类的数据点的scores.class0,以及作为负值/背景类的数据点的scores.class1;或者提供所有数据点的分类分数为scores.class0,并以数值形式提供标签(正类为1,负值类为0)作为weights.class0 (我在函数的帮助下复制了这一点,如果不清楚的话)。

我选择为scores.class0中的所有人提供稍后的概率,在weights.class0中提供类分配的概率。

卡雷特指出,如果将classProbs对象的trainControl参数设置为TRUE,数据中将出现包含类概率的其他列。因此,对于Ionosphere数据列,应该存在goodbad

代码语言:javascript
复制
levels(Ionosphere$Class)
#output
[1] "bad"  "good"

要转换为0/1标记,只需:

代码语言:javascript
复制
as.numeric(Ionosphere$Class) - 1

good将成为1

bad将成为0

现在我们有了自定义函数的所有数据。

代码语言:javascript
复制
auprcSummary <- function(data, lev = NULL, model = NULL){
  prob_good <- data$good #take the probability of good class
  the_curve <- pr.curve(scores.class0 = prob_good,
                        weights.class0 = as.numeric(data$obs)-1, #provide the class labels as 0/1
                        curve = FALSE)
  out <- the_curve$auc.integral
  names(out) <- "AUPRC"
  out
}

不需要单独使用data$good来处理这个数据集,可以提取类名并使用类名获取所需的列:

代码语言:javascript
复制
  lvls <- levels(data$obs)
  prob_good <- data[,lvls[2]]

每次更新summaryFunction时都需要注意更新trainControl对象,这一点很重要。

代码语言:javascript
复制
ctrl <- trainControl(method = "repeatedcv",
                     number = 10,
                     repeats = 5,
                     summaryFunction = auprcSummary,
                     classProbs = TRUE)

orig_fit <- train(y = Ionosphere$Class, x = Ionosphere[,c(1,3:34)], #omit column 2 to avoid a bunch of warnings related to the data set
                  method = "gbm",
                  verbose = FALSE,
                  metric = "AUPRC",
                  trControl = ctrl)

orig_fit$results
#output
  shrinkage interaction.depth n.minobsinnode n.trees     AUPRC    AUPRCSD
1       0.1                 1             10      50 0.9722775 0.03524882
4       0.1                 2             10      50 0.9758017 0.03143379
7       0.1                 3             10      50 0.9739880 0.03316923
2       0.1                 1             10     100 0.9786706 0.02502183
5       0.1                 2             10     100 0.9817447 0.02276883
8       0.1                 3             10     100 0.9772322 0.03301064
3       0.1                 1             10     150 0.9809693 0.02078601
6       0.1                 2             10     150 0.9824430 0.02284361
9       0.1                 3             10     150 0.9818318 0.02287886

似乎是合理的

票数 1
EN

Stack Overflow用户

发布于 2018-01-02 23:07:53

caret有一个名为prSummary的内置函数,用于为您计算该函数。你不必自己写。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/48054520

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档