首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >熊猫数据中多标签分类数据集的迭代分割

熊猫数据中多标签分类数据集的迭代分割
EN

Stack Overflow用户
提问于 2022-03-23 09:47:24
回答 1查看 344关注 0票数 1

我有dataset,它包含带有字符串值的文本列和值为1或0的多列(分类或否)。我想使用skmultilearn将数据分割成均匀分布,但我得到了以下错误:

代码语言:javascript
复制
KeyError: 'key of type tuple not found and not a MultiIndex'

这是我的密码:

代码语言:javascript
复制
import pandas as pd
from skmultilearn.model_selection import iterative_train_test_split


y = pd.read_csv("dataset.csv")
x = y.pop("text")

x_train, x_test, y_train, y_test = iterative_train_test_split(x, y, test_size=0.1)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-03-23 15:30:09

下面是对我有用的东西(这是98/1/1分离):

代码语言:javascript
复制
import os
import pandas as pd
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit


def main():
    # load dataset
    y = pd.read_csv("dataset.csv")
    x = y.pop("text")

    # save tag names to reuse them later for creating pandas DataFrames
    tag_names = y.columns

    # Data has to be in ndarray format
    y = y.to_numpy()
    x = x.to_numpy()

    # split to train / test
    msss = MultilabelStratifiedShuffleSplit(n_splits=2, test_size=0.02, random_state=42)
    for train_index, test_index in msss.split(x, y):
        x_train, x_test_temp = x[train_index], x[test_index]
        y_train, y_test_temp = y[train_index], y[test_index]

    # make some memory space
    del x
    del y

    # split to test / validation
    msss = MultilabelStratifiedShuffleSplit(n_splits=2, test_size=0.5, random_state=42)
    for test_index, val_index in msss.split(x_test_temp, y_test_temp):
        x_test, x_val = x_test_temp[test_index], x_test_temp[val_index]
        y_test, y_val = y_test_temp[test_index], y_test_temp[val_index]

    # train dataset
    df_train = pd.DataFrame(data=y_train, columns=tag_names)
    df_train.insert(0, "snippet", x_train)

    # validation dataset
    df_val = pd.DataFrame(data=y_val, columns=tag_names)
    df_val.insert(0, "snippet", x_val)

    # test dataset
    df_test = pd.DataFrame(data=y_test, columns=tag_names)
    df_test.insert(0, "snippet", x_test)


if __name__ == "__main__":
    main()
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71585013

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档