首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用trafo调优SMOTE的K失败:warning("k should be less than sample size!")

使用trafo调优SMOTE的K失败:warning("k should be less than sample size!")
EN

Stack Overflow用户
提问于 2020-05-13 18:24:17
回答 1查看 331关注 0票数 1

我在对SMOTE {smotefamily}的K参数使用trafo函数时遇到了问题:当最近邻数K大于或等于样本大小时,会触发warning("k should be less than sample size!")并报错,调优过程随之终止。

在内部重采样过程中,用户无法保证K小于样本大小,这必须在内部加以控制。例如,如果某个K值经过变换后满足trafo_K = 2 ^ K >= sample_size,则应改为trafo_K = sample_size - 1。

我想知道这个问题是否有解决方案,或者已经有了解决方案?

代码语言:r
复制
library("mlr3") # mlr3 base package
library("mlr3misc") # contains some helper functions
library("mlr3pipelines") # create ML pipelines
library("mlr3tuning") # tuning ML algorithms
library("mlr3learners") # additional ML algorithms
library("mlr3viz") # autoplot for benchmarks
library("paradox") # hyperparameter space
library("OpenML") # to obtain data sets
library("smotefamily") # SMOTE algorithm for imbalance correction

# Query OpenML for the curated binary classification data sets
# (see https://arxiv.org/abs/1708.03731v2): two classes, 1 to 100 features
# and 5000 to 10000 instances.
ds <- listOMLDataSets(
  number.of.classes = 2,
  number.of.features = c(1, 100),
  number.of.instances = c(5000, 10000)
)

# Keep the imbalanced data sets only (minority share below 20%). The filter
# also requires exactly one symbolic feature, because SMOTE cannot handle
# categorical features (presumably the single symbolic column is the target).
ds <- subset(
  ds,
  minority.class.size / number.of.instances < 0.2 &
    number.of.symbolic.features == 1
)
ds

# Download OpenML data set 980 and display it.
d <- getOMLDataSet(980)
d

# Turn the OpenML data set into a data frame, coerce the target column to a
# factor, and wrap the result in an mlr3 classification task named after the
# data set.
data <- as.data.frame(d)
data[[d$target.features]] <- as.factor(data[[d$target.features]])
task <- TaskClassif$new(
  id = d$desc$name,
  backend = data,
  target = d$target.features
)
task

# Code above copied from https://mlr3gallery.mlr-org.com/posts/2020-03-30-imbalanced-data/

# Class frequencies of the task's target.
class_counts <- table(task$truth())

# Imbalance ratio: size of the largest class divided by size of the smallest.
# max()/min() always give one plain number. Subsetting the table with
# `== max(...)` / `== min(...)` instead returns a named table element, and
# returns several elements when class sizes are tied.
majority_to_minority_ratio <- max(class_counts) / min(class_counts)

# SMOTE step; dup_size is set to the rounded majority/minority ratio
po_smote <- po("smote", dup_size = round(majority_to_minority_ratio))

# Random forest (ranger) learner that predicts class probabilities
rf <- lrn("classif.ranger", predict_type = "prob")

# Chain the SMOTE step into the random forest and plot the resulting graph
graph <- po_smote %>>% po("learner", rf, id = "rf")
graph$plot()

# Wrap the whole graph as a single learner that predicts probabilities
rf_smote <- GraphLearner$new(graph, predict_type = "prob")
rf_smote$predict_type <- "prob"

# Parameter set of the graph learner in data table format
ps_table <- as.data.table(rf_smote$param_set)
View(ps_table[, 1:4])

# Define parameter search space for the SMOTE parameters: one integer
# parameter for dup_size and one for K, both ranging from 1 to the rounded
# majority/minority ratio. Every other parameter id yields NULL and is
# dropped below.
param_set <- lapply(
  ps_table$id,
  function(id) {
    # Literal prefix/suffix tests: in a regular expression an unescaped "."
    # matches any character, so patterns such as 'smote.' or '.K' could also
    # match unrelated parameter ids.
    is_smote <- startsWith(id, "smote.")
    is_tuned <- endsWith(id, ".dup_size") || endsWith(id, ".K")
    if (is_smote && is_tuned) {
      ParamInt$new(id, lower = 1, upper = round(majority_to_minority_ratio))
    } else {
      NULL
    }
  }
)
param_set <- Filter(Negate(is.null), param_set)
param_set <- ParamSet$new(param_set)

# Apply transformation function on SMOTE's K (= The number of nearest
# neighbors used for sampling new values. See SMOTE().)
#
# x:         named list of parameter values; the entry whose name ends in ".K"
#            is replaced by round(3 ^ K), all other entries are left untouched.
# param_set: not used inside the function body.
# Returns the (possibly modified) list x.
param_set$trafo <- function(x, param_set) {
  # "\\.K$" matches only names ending in the literal ".K"; the unescaped
  # pattern '.K' would match any character followed by "K" anywhere in a name.
  index <- grep("\\.K$", names(x))
  if (length(index) > 0) {
    x[[index]] <- round(3 ^ x[[index]]) #  Intentionally define a trafo that won't work
  }
  x
}

# Two-fold cross-validation to be applied within the pipeline; the splits are
# instantiated on the task up front.
cv <- rsmp("cv", folds = 2)
cv$instantiate(task)

# Random search tuner
tuner <- TunerRandomSearch$new()

# Tuning instance: the SMOTE + random forest learner is evaluated on the task
# with the "classif.bbrier" measure, and the terminator is set to 3
# evaluations. The search space `param_set` is passed positionally.
instance <- TuningInstance$new(
  task = task,
  learner = rf_smote,
  resampling = cv,
  measures = msr("classif.bbrier"),
  param_set,
  terminator = term("evals", n_evals = 3),
  store_models = TRUE
)

# Tune pipe learner to find optimal SMOTE parameter values
tuner$optimize(instance)

下面是发生的事情

代码语言:javascript
复制
INFO  [11:00:14.904] Benchmark with 2 resampling iterations 
INFO  [11:00:14.919] Applying learner 'smote.rf' on task 'optdigits' (iter 2/2) 
Error in get.knnx(data, query, k, algorithm) : ANN: ERROR------->
In addition: Warning message:
In get.knnx(data, query, k, algorithm) : k should be less than sample size!

会话信息

代码语言:javascript
复制
R version 3.6.2 (2019-12-12)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 16299)

Matrix products: default

locale:
[1] LC_COLLATE=English_United Kingdom.1252  LC_CTYPE=English_United Kingdom.1252   
[3] LC_MONETARY=English_United Kingdom.1252 LC_NUMERIC=C                           
[5] LC_TIME=English_United Kingdom.1252    

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
 [1] smotefamily_1.3.1        OpenML_1.10              mlr3viz_0.1.1.9002      
 [4] mlr3tuning_0.1.2-9000    mlr3pipelines_0.1.2.9000 mlr3misc_0.2.0          
 [7] mlr3learners_0.2.0       mlr3filters_0.2.0.9000   mlr3_0.2.0-9000         
[10] paradox_0.2.0            yardstick_0.0.5          rsample_0.0.5           
[13] recipes_0.1.9            parsnip_0.0.5            infer_0.5.1             
[16] dials_0.0.4              scales_1.1.0             broom_0.5.4             
[19] tidymodels_0.0.3         reshape2_1.4.3           janitor_1.2.1           
[22] data.table_1.12.8        forcats_0.4.0            stringr_1.4.0           
[25] dplyr_0.8.4              purrr_0.3.3              readr_1.3.1             
[28] tidyr_1.0.2              tibble_3.0.1             ggplot2_3.3.0           
[31] tidyverse_1.3.0         

loaded via a namespace (and not attached):
  [1] utf8_1.1.4              tidyselect_1.0.0        lme4_1.1-21            
  [4] htmlwidgets_1.5.1       grid_3.6.2              ranger_0.12.1          
  [7] pROC_1.16.1             munsell_0.5.0           codetools_0.2-16       
 [10] bbotk_0.1               DT_0.12                 future_1.17.0          
 [13] miniUI_0.1.1.1          withr_2.2.0             colorspace_1.4-1       
 [16] knitr_1.28              uuid_0.1-4              rstudioapi_0.10        
 [19] stats4_3.6.2            bayesplot_1.7.1         listenv_0.8.0          
 [22] rstan_2.19.2            lgr_0.3.4               DiceDesign_1.8-1       
 [25] vctrs_0.2.4             generics_0.0.2          ipred_0.9-9            
 [28] xfun_0.12               R6_2.4.1                markdown_1.1           
 [31] mlr3measures_0.1.3-9000 rstanarm_2.19.2         lhs_1.0.1              
 [34] assertthat_0.2.1        promises_1.1.0          nnet_7.3-12            
 [37] gtable_0.3.0            globals_0.12.5          processx_3.4.1         
 [40] timeDate_3043.102       rlang_0.4.5             workflows_0.1.1        
 [43] BBmisc_1.11             splines_3.6.2           checkmate_2.0.0        
 [46] inline_0.3.15           yaml_2.2.1              modelr_0.1.5           
 [49] tidytext_0.2.2          threejs_0.3.3           crosstalk_1.0.0        
 [52] backports_1.1.6         httpuv_1.5.2            rsconnect_0.8.16       
 [55] tokenizers_0.2.1        tools_3.6.2             lava_1.6.6             
 [58] ellipsis_0.3.0          ggridges_0.5.2          Rcpp_1.0.4.6           
 [61] plyr_1.8.5              base64enc_0.1-3         visNetwork_2.0.9       
 [64] ps_1.3.0                prettyunits_1.1.1       rpart_4.1-15           
 [67] zoo_1.8-7               haven_2.2.0             fs_1.3.1               
 [70] furrr_0.1.0             magrittr_1.5            colourpicker_1.0       
 [73] reprex_0.3.0            GPfit_1.0-8             SnowballC_0.6.0        
 [76] packrat_0.5.0           matrixStats_0.55.0      tidyposterior_0.0.2    
 [79] hms_0.5.3               shinyjs_1.1             mime_0.8               
 [82] xtable_1.8-4            XML_3.99-0.3            tidypredict_0.4.3      
 [85] shinystan_2.5.0         readxl_1.3.1            gridExtra_2.3          
 [88] rstantools_2.0.0        compiler_3.6.2          crayon_1.3.4           
 [91] minqa_1.2.4             StanHeaders_2.21.0-1    htmltools_0.4.0        
 [94] later_1.0.0             lubridate_1.7.4         DBI_1.1.0              
 [97] dbplyr_1.4.2            MASS_7.3-51.4           boot_1.3-23            
[100] Matrix_1.2-18           cli_2.0.1               parallel_3.6.2         
[103] gower_0.2.1             igraph_1.2.4.2          pkgconfig_2.0.3        
[106] xml2_1.2.2              foreach_1.4.7           dygraphs_1.1.1.6       
[109] prodlim_2019.11.13      farff_1.1               rvest_0.3.5            
[112] snakecase_0.11.0        janeaustenr_0.1.5       callr_3.4.1            
[115] digest_0.6.25           cellranger_1.1.0        curl_4.3               
[118] shiny_1.4.0             gtools_3.8.1            nloptr_1.2.1           
[121] lifecycle_0.2.0         nlme_3.1-142            jsonlite_1.6.1         
[124] fansi_0.4.1             pillar_1.4.3            lattice_0.20-38        
[127] loo_2.2.0               fastmap_1.0.1           httr_1.4.1             
[130] pkgbuild_1.0.6          survival_3.1-8          glue_1.4.0             
[133] xts_0.12-0              FNN_1.1.3               shinythemes_1.1.2      
[136] iterators_1.0.12        class_7.3-15            stringi_1.4.4          
[139] memoise_1.1.0           future.apply_1.5.0     

非常感谢。

EN

回答 1

Stack Overflow用户

发布于 2020-05-15 19:38:14

我找到了一种变通方法。

如前所述,问题在于SMOTE {smotefamily}的K不能大于或等于样本大小。

我深入追踪了这一过程,发现SMOTE {smotefamily}调用knearest {smotefamily},后者调用knnx.index {FNN},而knnx.index又调用get.knnx {FNN};get.knnx触发了warning("k should be less than sample size!")并随后报错,从而终止了mlr3中的调优过程。

在SMOTE {smotefamily}中,传给knearest {smotefamily}的三个参数依次是P_set、P_set和K。从mlr3重采样的角度来看,数据框P_set是训练数据中某一交叉验证折(fold)的子集,经过滤后只包含少数类的记录。错误信息中所说的“样本大小”就是P_set的行数。

因此,随着K通过诸如some_integer ^ K (例如2 ^ K)的trafo而增加,K >= nrow(P_set)变得更有可能。

我们需要确保K永远不会大于或等于P_set的行数。

以下是我提出的解决方案:

  1. 在定义trafo之前,先使用rsmp()定义并实例化CV重采样策略:
     • 定义变量cv_folds,然后在rsmp()中以folds = cv_folds定义CV重采样策略。
     • 实例化CV。此时,数据集在每一折(fold)中都被划分为训练数据和测试/验证数据。
  2. 在所有折的训练数据中找出少数类的最小样本量,并将其设为K的阈值:

代码语言:r
复制
# Threshold for SMOTE's K: the smallest minority-class count found in any of
# the CV training sets. K has to stay strictly below this value.
smote_k_thresh <- min(vapply(
  seq_len(cv_folds), # seq_len() is safe where 1:cv_folds is not (cv_folds == 0)
  function(fold) {
    # Request the target values by row id through the task itself, instead of
    # using the row ids as positions into a data frame copy of the task.
    fold_truth <- task$truth(cv$train_set(fold))
    # Size of the smallest class in this training set
    min(table(fold_truth))
  },
  numeric(1)
))

  3. 现在将trafo定义如下:

代码语言:r
复制
# Transformation of SMOTE's K that keeps it below smote_k_thresh.
#
# x:         named list of parameter values; the entry whose name ends in ".K"
#            is replaced, all other entries are left untouched.
# param_set: not used inside the function body.
# Returns the (possibly modified) list x.
#
# Reads the global smote_k_thresh, which must be defined before the trafo runs.
param_set$trafo <- function(x, param_set) {
  # "\\.K$" matches only names ending in the literal ".K"; the unescaped
  # pattern '.K' would match any character followed by "K" anywhere in a name.
  index <- grep("\\.K$", names(x))
  if (length(index) > 0) {
    aux <- round(2 ^ x[[index]])
    if (aux < smote_k_thresh) {
      # The transformed K is still below the threshold: keep it
      x[[index]] <- aux
    } else {
      # Otherwise draw a random K from 1 .. smote_k_thresh - 1. sample.int()
      # raises an error when that range is empty, whereas sample(0, 1) would
      # silently return K = 0.
      x[[index]] <- sample.int(smote_k_thresh - 1, 1)
    }
  }
  x
}

换句话说,当经过trafo变换后的K仍然小于样本大小时,保留该值;否则,将其设置为1到smote_k_thresh - 1之间的一个随机整数。

Implementation

原始代码稍作修改,以适应建议的调整:

代码语言:r
复制
library("mlr3learners") # additional ML algorithms
library("mlr3viz") # autoplot for benchmarks
library("paradox") # hyperparameter space
library("OpenML") # to obtain data sets
library("smotefamily") # SMOTE algorithm for imbalance correction

# Query OpenML for the curated binary classification data sets
# (see https://arxiv.org/abs/1708.03731v2): two classes, 1 to 100 features
# and 5000 to 10000 instances.
ds <- listOMLDataSets(
  number.of.classes = 2,
  number.of.features = c(1, 100),
  number.of.instances = c(5000, 10000)
)

# Keep the imbalanced data sets only (minority share below 20%). The filter
# also requires exactly one symbolic feature, because SMOTE cannot handle
# categorical features (presumably the single symbolic column is the target).
ds <- subset(
  ds,
  minority.class.size / number.of.instances < 0.2 &
    number.of.symbolic.features == 1
)
ds

# Download OpenML data set 980 and display it.
d <- getOMLDataSet(980)
d

# Turn the OpenML data set into a data frame, coerce the target column to a
# factor, and wrap the result in an mlr3 classification task named after the
# data set.
data <- as.data.frame(d)
data[[d$target.features]] <- as.factor(data[[d$target.features]])
task <- TaskClassif$new(
  id = d$desc$name,
  backend = data,
  target = d$target.features
)
task

# Code above copied from https://mlr3gallery.mlr-org.com/posts/2020-03-30-imbalanced-data/

# Class frequencies of the task's target.
class_counts <- table(task$truth())

# Imbalance ratio: size of the largest class divided by size of the smallest.
# max()/min() always give one plain number. Subsetting the table with
# `== max(...)` / `== min(...)` instead returns a named table element, and
# returns several elements when class sizes are tied.
majority_to_minority_ratio <- max(class_counts) / min(class_counts)

# Cross-validation used within the pipeline. It is defined and instantiated
# BEFORE the trafo, because the trafo's threshold is computed from its
# training sets.
cv_folds <- 2
cv <- rsmp("cv", folds = cv_folds)
cv$instantiate(task)

# SMOTE step; dup_size is set to the rounded majority/minority ratio
po_smote <- po("smote", dup_size = round(majority_to_minority_ratio))

# Threshold for the number of nearest neighbours K: the smallest
# minority-class count found in any of the CV training sets. K has to stay
# strictly below this value.
smote_k_thresh <- min(vapply(
  seq_len(cv_folds), # seq_len() is safe where 1:cv_folds is not (cv_folds == 0)
  function(fold) {
    # Request the target values by row id through the task itself, instead of
    # using the row ids as positions into a data frame copy of the task.
    fold_truth <- task$truth(cv$train_set(fold))
    # Size of the smallest class in this training set
    min(table(fold_truth))
  },
  numeric(1)
))

# Random forest (ranger) learner that predicts class probabilities
rf <- lrn("classif.ranger", predict_type = "prob")

# Chain the SMOTE step into the random forest and plot the resulting graph
graph <- po_smote %>>% po("learner", rf, id = "rf")
graph$plot()

# Wrap the whole graph as a single learner that predicts probabilities
rf_smote <- GraphLearner$new(graph, predict_type = "prob")
rf_smote$predict_type <- "prob"

# Parameter set of the graph learner in data table format
ps_table <- as.data.table(rf_smote$param_set)
View(ps_table[, 1:4])

# Define parameter search space for the SMOTE parameters: one integer
# parameter for dup_size and one for K, both ranging from 1 to the rounded
# majority/minority ratio. Every other parameter id yields NULL and is
# dropped below.
param_set <- lapply(
  ps_table$id,
  function(id) {
    # Literal prefix/suffix tests: in a regular expression an unescaped "."
    # matches any character, so patterns such as 'smote.' or '.K' could also
    # match unrelated parameter ids.
    is_smote <- startsWith(id, "smote.")
    is_tuned <- endsWith(id, ".dup_size") || endsWith(id, ".K")
    if (is_smote && is_tuned) {
      ParamInt$new(id, lower = 1, upper = round(majority_to_minority_ratio))
    } else {
      NULL
    }
  }
)
param_set <- Filter(Negate(is.null), param_set)
param_set <- ParamSet$new(param_set)

# Apply transformation function on SMOTE's K while ensuring it never equals or
# exceeds the sample size.
#
# x:         named list of parameter values; the entry whose name ends in ".K"
#            is replaced, all other entries are left untouched.
# param_set: not used inside the function body.
# Returns the (possibly modified) list x.
#
# Reads the global smote_k_thresh, which must be defined before the trafo runs.
param_set$trafo <- function(x, param_set) {
  # "\\.K$" matches only names ending in the literal ".K"; the unescaped
  # pattern '.K' would match any character followed by "K" anywhere in a name.
  index <- grep("\\.K$", names(x))
  if (length(index) > 0) {
    aux <- round(5 ^ x[[index]]) # Try a large value here for the sake of the example
    if (aux < smote_k_thresh) {
      # The transformed K is still below the threshold: keep it
      x[[index]] <- aux
    } else {
      # Otherwise draw a random K from 1 .. smote_k_thresh - 1. sample.int()
      # raises an error when that range is empty, whereas sample(0, 1) would
      # silently return K = 0.
      x[[index]] <- sample.int(smote_k_thresh - 1, 1)
    }
  }
  x
}

# Random search tuner
tuner <- TunerRandomSearch$new()

# Tuning instance: the SMOTE + random forest learner is evaluated on the task
# with the "classif.bbrier" measure, and the terminator is set to 10
# evaluations. The search space `param_set` is passed positionally.
instance <- TuningInstance$new(
  task = task,
  learner = rf_smote,
  resampling = cv,
  measures = msr("classif.bbrier"),
  param_set,
  terminator = term("evals", n_evals = 10),
  store_models = TRUE
)

# Tune pipe learner to find optimal SMOTE parameter values
tuner$optimize(instance)

# Here are the original K values
instance$archive$data

# And here are their transformations
instance$archive$data$opt_x
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/61772147

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档