首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >将一个层的权重从一个Huggingface模型复制到另一个

将一个层的权重从一个Huggingface模型复制到另一个
EN

Stack Overflow用户
提问于 2021-05-25 13:41:01
回答 1查看 1.5K关注 0票数 1

我有一个预先训练过的模型,我装载如下:

代码语言:javascript
复制
from transformers import BertForSequenceClassification, AdamW, BertConfig, BertModel
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
    num_labels = 2, # The number of output labels--2 for binary classification.
                    # You can increase this for multi-class tasks.   
    output_attentions = False, # Whether the model returns attentions weights.
    output_hidden_states = False, # Whether the model returns all hidden-states.
)

我想要创建一个具有相同架构和随机初始权重的新模型,但嵌入层除外:

代码语言:javascript
复制
==== Embedding Layer ====

bert.embeddings.word_embeddings.weight                  (30522, 768)
bert.embeddings.position_embeddings.weight                (512, 768)
bert.embeddings.token_type_embeddings.weight                (2, 768)
bert.embeddings.LayerNorm.weight                              (768,)
bert.embeddings.LayerNorm.bias                                (768,)

似乎我可以这样做来创建一个具有相同架构的新模型,但是所有的权重都是随机的:

代码语言:javascript
复制
configuration   = model.config
untrained_model = BertForSequenceClassification(configuration)

那么,如何将model的嵌入层权重复制到新的untrained_model中?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-05-25 17:44:37

重量和偏倚只是张量,你可以简单地用_复制它们。

代码语言:javascript
复制
from transformers import BertForSequenceClassification, BertConfig
jetfire = BertForSequenceClassification.from_pretrained('bert-base-cased')
config = BertConfig.from_pretrained('bert-base-cased')

optimus = BertForSequenceClassification(config)

parts = ['bert.embeddings.word_embeddings.weight'
,'bert.embeddings.position_embeddings.weight'              
,'bert.embeddings.token_type_embeddings.weight'    
,'bert.embeddings.LayerNorm.weight'
,'bert.embeddings.LayerNorm.bias']

def joltElectrify (jetfire, optimus, parts):
  target = dict(optimus.named_parameters())
  source = dict(jetfire.named_parameters())

  for part in parts:
    target[part].data.copy_(source[part].data)  

joltElectrify(jetfire, optimus, parts)
票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67689219

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档