首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在没有ski-kit学习的情况下为K-Fold交叉验证创建训练集?

如何在没有ski-kit学习的情况下为K-Fold交叉验证创建训练集?
EN

Stack Overflow用户
提问于 2020-03-09 11:03:07
回答 1查看 545关注 0票数 2

我有一个包含95行9列的数据集,并希望进行5次交叉验证。在训练中,前8列(特征)用于预测第九列。我的测试集是正确的,但是我的x训练集的大小是(4, 19 ,9),而它应该只有8列,我的y训练集是(4,9),而它应该有19行。我对子数组的索引不正确吗?

代码语言:javascript
复制
kdata = data[0:95,:] # Need total rows to be divisible by 5, so ignore last 2 rows 
np.random.shuffle(kdata) # Shuffle all rows
folds = np.array_split(kdata, k) # each fold is 19 rows x 9 columns

for i in range (k-1):
    xtest = folds[i][:,0:7] # Set ith fold to be test
    ytest = folds[i][:,8]
    new_folds = np.delete(folds,i,0)
    xtrain = new_folds[:][:][0:7] # training set is all folds, all rows x 8 cols
    ytrain = new_folds[:][:][8]   # training y is all folds, all rows x 1 col
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-03-09 11:44:20

欢迎来到Stack Overflow。

一旦你创建了一个新的文件夹,你需要使用np.row_stack()逐行堆叠它们。

另外,我认为你对数组的切片是错误的,在Python或Numpy中,切片行为是[inclusive:exclusive],因此,当你指定切片为[0:7]时,你只取了7列,而不是你想要的8个特征列。

类似地,如果您在for循环中指定了5折,则应该是range(k),它会给出[0,1,2,3,4],而不是range(k-1),它只会给出[0,1,2,3]

修改后的代码如下:

代码语言:javascript
复制
folds = np.array_split(kdata, k) # each fold is 19 rows x 9 columns
np.random.shuffle(kdata) # Shuffle all rows
folds = np.array_split(kdata, k)

for i in range (k):
    xtest = folds[i][:,:8] # Set ith fold to be test
    ytest = folds[i][:,8]
    new_folds = np.row_stack(np.delete(folds,i,0))
    xtrain = new_folds[:, :8]
    ytrain = new_folds[:,8]

    # some print functions to help you debug
    print(f'Fold {i}')
    print(f'xtest shape  : {xtest.shape}')
    print(f'ytest shape  : {ytest.shape}')
    print(f'xtrain shape : {xtrain.shape}')
    print(f'ytrain shape : {ytrain.shape}\n')

它将为您打印出折叠和所需的训练和测试集形状:

代码语言:javascript
复制
Fold 0
xtest shape  : (19, 8)
ytest shape  : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)

Fold 1
xtest shape  : (19, 8)
ytest shape  : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)

Fold 2
xtest shape  : (19, 8)
ytest shape  : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)

Fold 3
xtest shape  : (19, 8)
ytest shape  : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)

Fold 4
xtest shape  : (19, 8)
ytest shape  : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60594242

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档