首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >拥抱面孔训练器:模型初始化错误

拥抱面孔训练器:模型初始化错误
EN

Stack Overflow用户
提问于 2021-03-27 14:41:47
回答 1查看 401关注 0票数 0

我收到以下错误:

代码语言:javascript
复制
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Traceback (most recent call last):
  File "./run_hyperparameter_search.py", line 74, in <module>
    trainer = Trainer(
  File "/ext3/miniconda3/lib/python3.8/site-packages/transformers/trainer.py", line 273, in __init__
    model = self.call_model_init()
  File "/ext3/miniconda3/lib/python3.8/site-packages/transformers/trainer.py", line 737, in call_model_init
    raise RuntimeError("model_init should have 0 or 1 argument.")
RuntimeError: model_init should have 0 or 1 argument.
~                                                                       

这是我在我的拥抱面孔训练器中所说的:

代码语言:javascript
复制
#Initialising the model
trainer = Trainer(
    args = training_args,
    tokenizer = tokenizer,
    train_dataset = train_data,
    eval_dataset = val_data,
    # maybe there is a () in the init, but not in compute metrics for sure. Will test
    model_init = finetuning_utils.model_init(),
    compute_metrics = finetuning_utils.compute_metrics,
)

问题显然出在model_init中。

下面是finetuning_utils.model_init()包含的内容:

代码语言:javascript
复制
def model_init():
    """Returns an initialized model for use in a Hugging Face Trainer."""
    ## TODO: Return a pretrained RoBERTa model for sequence classification.
    ## See https://huggingface.co/transformers/model_doc/roberta.html#robertaforsequenceclassification.
    model = RobertaForSequenceClassification.from_pretrained("roberta-base")
    #model = model.to('cuda')
    return model

请帮助解决此错误。

EN

回答 1

Stack Overflow用户

发布于 2021-04-17 18:33:57

Huggingface trainer docs看起来,model_init接受了一个可调用的。因此,它不应该实例化参数,而应该接受可调用参数本身,即不带括号:

代码语言:javascript
复制
model_init = finetuning_utils.model_init

或者,您可以删除model_init并使用model参数来达到与finetuning_utils.model_init中包含的代码相同的效果,如下所示:

代码语言:javascript
复制
model = RobertaForSequenceClassification.from_pretrained("roberta-base")
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66828699

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档