首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何保存检查点让变压器gpt2继续培训?

如何保存检查点让变压器gpt2继续培训?
EN

Stack Overflow用户
提问于 2022-02-22 04:36:01
回答 1查看 294关注 0票数 0

我正在重新培训GPT2语言模型,并关注这个博客:

https://towardsdatascience.com/train-gpt-2-in-your-own-language-fc6ad4d60171

在这里,他们已经在GPT2上训练了一个网络,而我也在尝试重新创建一个网络。但是,我的数据集太大了(250 my ),所以我想继续每隔一段时间进行培训。换句话说,我想检查模型的训练。如果有任何帮助,或一段代码,我可以实现的检查点和继续培训,这将对我有很大帮助。谢谢。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-02-22 19:10:58

代码语言:javascript
复制
training_args = TrainingArguments(
    output_dir=model_checkpoint,
    # other hyper-params
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_set,
    eval_dataset=dev_set,
    tokenizer=tokenizer
)

trainer.train()
# Save the model to model_dir
trainer.save_model()

def prepare_model(tokenizer, model_name_path):
    model = AutoModelForCausalLM.from_pretrained(model_name_path)
    model.resize_token_embeddings(len(tokenizer))
    return model

# Assume tokenizer is defined, You can simply pass the saved model directory path.
model = prepare_model(tokenizer, model_checkpoint)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71215965

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档