首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >用Split (k,n)将数据拆分成5个子集&而不是用sample()

用Split (k,n)将数据拆分成5个子集&而不是用sample()
EN

Stack Overflow用户
提问于 2022-03-15 11:41:31
回答 1查看 234关注 0票数 0

我想拆分、train、测试,但是在R中使用()函数,而不是和sample() 。

我有58行和28列在我的数据集(一个csv文件),我想做一个10倍或5倍的简历在这个数据集。

我该如何为这个任务写下代码呢?

我试过:

代码语言:javascript
复制
set.seed(1)
smp_size=choose(58,5, name_dataset) # which is totally wrong but ... 
# I haven't figured out yet how to take 5 subsets from 58 observations
# each time I do a 5/10 -fold  CV

train_ind=sample(seq_len(nrow(name_dataset)),size=smp_size) # I think sample here is wrong too
train=name_dataset[train_ind,]
test=name_dataset[-train_ind,]
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-03-15 12:40:29

我不知道你说的每一个5-子集的组合是什么意思。这似乎是一种难以置信的巨大可能性。我假设您的意思是您需要包含数据集中所有样本的5个数据集的子集。我可能会做这样的事。我们首先生成一个组向量,即k的个数和数据集的长度。然后,我们随机地对这些组进行采样,并将数据集按这些分组拆分。

代码语言:javascript
复制
library(tidyverse)

set.seed(3465)
test_data <- tibble(A = runif(58),
                    B = runif(58))


k_split <- function(dat,k, seed = 1){
  set.seed(seed)
  grp <- rep(1:k, length.out = nrow(dat))
  dat |>
    mutate(grp = sample(grp, nrow(dat), replace = F)) |>
    group_split(grp)|>
    map(\(d) select(d, -grp))
}

k_split(test_data, 5)
#> [[1]]
#> # A tibble: 12 x 2
#>        A      B
#>    <dbl>  <dbl>
#>  1 0.476 0.468 
#>  2 0.636 0.639 
#>  3 0.334 0.0269
#>  4 0.668 0.220 
#>  5 0.398 0.919 
#>  6 0.343 0.748 
#>  7 0.799 0.526 
#>  8 0.710 0.759 
#>  9 0.737 0.927 
#> 10 0.819 0.441 
#> 11 0.852 0.656 
#> 12 0.416 0.541 
#> 
#> [[2]]
#> # A tibble: 12 x 2
#>         A      B
#>     <dbl>  <dbl>
#>  1 0.0107 0.905 
#>  2 0.109  0.539 
#>  3 0.715  0.778 
#>  4 0.523  0.416 
#>  5 0.609  0.357 
#>  6 0.152  0.0972
#>  7 0.919  0.450 
#>  8 0.866  0.510 
#>  9 0.0347 0.0890
#> 10 0.862  0.465 
#> 11 0.364  0.765 
#> 12 0.789  0.601 
#> 
#> [[3]]
#> # A tibble: 12 x 2
#>         A      B
#>     <dbl>  <dbl>
#>  1 0.580  0.228 
#>  2 0.201  0.0418
#>  3 0.0359 0.417 
#>  4 0.521  0.758 
#>  5 0.534  0.974 
#>  6 0.580  0.563 
#>  7 0.844  0.781 
#>  8 0.756  0.271 
#>  9 0.211  0.533 
#> 10 0.851  0.764 
#> 11 0.885  0.150 
#> 12 0.262  0.371 
#> 
#> [[4]]
#> # A tibble: 11 x 2
#>         A     B
#>     <dbl> <dbl>
#>  1 0.556  0.313
#>  2 0.353  0.821
#>  3 0.0959 0.861
#>  4 0.759  0.261
#>  5 0.207  0.772
#>  6 0.668  0.527
#>  7 0.150  0.788
#>  8 0.0939 0.257
#>  9 0.0913 0.817
#> 10 0.294  0.790
#> 11 0.0224 0.253
#> 
#> [[5]]
#> # A tibble: 11 x 2
#>          A      B
#>      <dbl>  <dbl>
#>  1 0.0893  0.665 
#>  2 0.966   0.142 
#>  3 0.672   0.0849
#>  4 0.641   0.155 
#>  5 0.490   0.187 
#>  6 0.00394 0.295 
#>  7 0.126   0.813 
#>  8 0.202   0.474 
#>  9 0.0740  0.107 
#> 10 0.412   0.709 
#> 11 0.509   0.253
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71481639

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档