RoBERTa classification RuntimeError: shape '[-1, 9]' is invalid for input of size 8
Stack Overflow user
Asked 2020-01-10 23:20:22
Answers: 1 · Views: 740 · Followers: 0 · Votes: 0
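    import pandas as pd
    from sklearn.model_selection import train_test_split
    from sklearn.preprocessing import MultiLabelBinarizer
    from simpletransformers.classification import MultiLabelClassificationModel

    m = MultiLabelBinarizer()

    X = pd.read_csv('data/data.csv', sep=None, engine='python')
    X = X.dropna()

    # Binarize the labels into one multi-hot vector per row
    Y_train = m.fit_transform(X['labels'])
    Y_train2 = [list(i) for i in Y_train]

    # Note: text_col (the name of the text column) is not defined in this snippet
    data = pd.DataFrame({'text': pd.Series(X[text_col]), 'labels': Y_train2})
    data = data.dropna()

    train_df, eval_df = train_test_split(data, test_size=0.2)

    numLabels = len(pd.unique(X['labels']))  # count of the labels

    model = MultiLabelClassificationModel('roberta', 'roberta-base', num_labels=numLabels, use_cuda=False)

    model.train_model(pd.DataFrame(train_df))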

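The data structure of my labels column is: [0,1,0,0,0,1,0, 0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,1...]. For each row, the labels column holds a list of labels such as 0,1,0,0,0,1,0,0.

For the text, there is one text (a newspaper article) per row.

(Adapted from this source: https://github.com/ThilinaRajapakse/simpletransformers#minimal-start-for-multilabel-classification)

If I train the model with only 4 entries, training works. But when I try to train it on the whole dataset, I get RuntimeError: shape '[-1, 9]' is invalid for input of size 8: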

 File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/simpletransformers/classification/multi_label_classification_model.py", line 121, in train_model
    return super().train_model(train_df, multi_label=multi_label, eval_df=eval_df, output_dir=output_dir, show_running_loss=show_running_loss, args=args)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/simpletransformers/classification/classification_model.py", line 208, in train_model
    global_step, tr_loss = self.train(train_dataset, output_dir, multi_label=multi_label, show_running_loss=show_running_loss, eval_df=eval_df, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/simpletransformers/classification/classification_model.py", line 306, in train
    outputs = model(**inputs)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/nn/modules/module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "/Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/simpletransformers/custom_models/models.py", line 117, in forward
    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1, self.num_labels))
RuntimeError: shape '[-1, 9]' is invalid for input of size 8

I have no idea where the size of 8 comes from or what to do now, since it works with only a few entries. Can anyone help?
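As a rough illustration of the error message itself (an editorial sketch in plain Python, not PyTorch's actual implementation; `view_rows` is a hypothetical helper): `labels.view(-1, self.num_labels)` can only succeed when the number of elements in the tensor is a multiple of `num_labels`. An input with 8 elements, which would match one label list of length 8, cannot be split into rows of 9.

```python
def view_rows(numel, num_labels):
    """Mimic the row count that tensor.view(-1, num_labels) would infer.

    Raises ValueError when numel is not divisible by num_labels
    (PyTorch raises a RuntimeError in the same situation).
    """
    if numel % num_labels != 0:
        raise ValueError(
            f"shape '[-1, {num_labels}]' is invalid for input of size {numel}"
        )
    return numel // num_labels


# One row of 8 label values against a model built with num_labels=9:
try:
    view_rows(8, 9)
except ValueError as err:
    message = str(err)

# The same row is fine once num_labels matches the vector length.
rows = view_rows(8, 8)
```

Under this reading, the mismatch is between the length of each label list and the `num_labels` value passed to the model.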

1 Answer

Stack Overflow user

Answered 2020-05-03 22:47:18

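0,1,0,0,0,1,0,0 has size 8, but your model expects size 9. That means your numLabels = 9. If you really have 9 classes, each label list in the labels column should look like this: 0,1,0,0,0,1,0,0,0. But I think you just need to pass num_labels as 8.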
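One way to avoid hard-coding the number is to derive it from the label vectors themselves. The sketch below is an assumption-laden illustration, not part of simpletransformers: `infer_num_labels` is a hypothetical helper and the sample rows are made up. Note that `len(pd.unique(X['labels']))` counts distinct values in the raw column, which is not necessarily the width of the binarized vectors; with `MultiLabelBinarizer`, `len(m.classes_)` or `Y_train.shape[1]` should give the same width as this helper.

```python
def infer_num_labels(label_rows):
    """Return the common length of the multi-hot label vectors.

    Raises ValueError if label_rows is empty or the rows differ in length.
    """
    lengths = {len(row) for row in label_rows}
    if len(lengths) != 1:
        raise ValueError(f"expected one label-vector length, got {sorted(lengths)}")
    return lengths.pop()


# Hypothetical stand-in for the 'labels' column after binarization.
labels = [
    [0, 1, 0, 0, 0, 1, 0, 0],
    [1, 1, 0, 0, 1, 0, 0, 0],
]
num_labels = infer_num_labels(labels)  # pass this as num_labels to the model
```

Checking that every row has the same length before training also surfaces malformed rows early, instead of partway through an epoch.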

Votes: 1
Original page content provided by Stack Overflow. Translation support provided by the Tencent Cloud Xiaowei IT-domain engine.
Original link:

https://stackoverflow.com/questions/59684472
