首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >保存一个具有自定义前向函数的Bert模型并在Hugginface面上运行

保存一个具有自定义前向函数的Bert模型并在Hugginface面上运行
EN

Stack Overflow用户
提问于 2022-06-04 21:38:22
回答 1查看 441关注 0票数 0

我已经创建了自己的BertClassifier模型,从预先训练开始,然后添加由不同层组成的自己的分类头。微调之后,我想使用model.save_pretrained()保存模型,但当我打印它时,从预先训练过的上传它,我没有看到我的分类器头。代码如下。如何将所有的结构保存在我的模型中,并使其完全可以使用AutoModel.from_preatrained('folder_path')访问?谢谢!

代码语言:javascript
复制
class BertClassifier(PreTrainedModel):
    """Bert Model for Classification Tasks."""
    config_class = AutoConfig
    def __init__(self,config, freeze_bert=True): #tuning only the head
        """
         @param    bert: a BertModel object
         @param    classifier: a torch.nn.Module classifier
         @param    freeze_bert (bool): Set `False` to fine-tune the BERT model
        """
        #super(BertClassifier, self).__init__()
        super().__init__(config)

        # Instantiate BERT model
        # Specify hidden size of BERT, hidden size of our classifier, and number of labels
        self.D_in = 1024 #hidden size of Bert
        self.H = 512
        self.D_out = 2
 
        # Instantiate the classifier head with some one-layer feed-forward classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.D_in, 512),
            nn.Tanh(),
            nn.Linear(512, self.D_out),
            nn.Tanh()
        )
 


    def forward(self, input_ids, attention_mask):


         # Feed input to BERT
        outputs = self.bert(input_ids=input_ids,
                             attention_mask=attention_mask)
         
         # Extract the last hidden state of the token `[CLS]` for classification task
        last_hidden_state_cls = outputs[0][:, 0, :]
 
         # Feed input to classifier to compute logits
        logits = self.classifier(last_hidden_state_cls)
 
        return logits
代码语言:javascript
复制
configuration=AutoConfig.from_pretrained('Rostlab/prot_bert_bfd')
model = BertClassifier(config=configuration,freeze_bert=False)

微调后保存模型

代码语言:javascript
复制
model.save_pretrained('path')

加载微调模型

代码语言:javascript
复制
model = AutoModel.from_pretrained('path') 

加载后打印模型显示,作为最后一个层,下面是下面的内容,并且缺少了我的2个线性层:

代码语言:javascript
复制
 (output): BertOutput(
          (dense): Linear(in_features=4096, out_features=1024, bias=True)
          (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (adapters): ModuleDict()
          (adapter_fusion_layer): ModuleDict()
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=1024, out_features=1024, bias=True)
    (activation): Tanh()
  )
  (prefix_tuning): PrefixTuningPool(
    (prefix_tunings): ModuleDict()
  )
)
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-06-08 18:04:05

也许您的config_class类中的BertClassifier属性有问题。根据文档,您需要创建一个额外的配置类,它继承表单PretrainedConfig,并使用自定义模型的名称初始化model_type属性。

BertClassifier's config_class必须与自定义配置类类型保持一致。之后,您可以通过以下调用注册您的配置和模型:

代码语言:javascript
复制
AutoConfig.register('CustomModelName', CustomModelConfigClass)
AutoModel.register(CustomModelConfigClass, CustomModelClass)

并使用AutoModel.from_pretrained('YourCustomModelName')加载已完成的模型

基于代码的不完整示例如下所示:

代码语言:javascript
复制
class BertClassifierConfig(PretrainedConfig):
    model_type="BertClassifier"


class BertClassifier(PreTrainedModel):
    config_class = BertClassifierConfig
    # ...


configuration = BertClassifierConfig()
bert_classifier = BertClassifier(configuration)

# do your finetuning and save your custom model
bert_classifier.save_pretrained("CustomModels/BertClassifier")

# register your config and your model
AutoConfig.register("BertClassifier", BertClassifierConfig)
AutoModel.register(BertClassifierConfig, BertClassifier)

# load your model with AutoModel
bert_classifier_model = AutoModel.from_pretrained("CustomModels/BertClassifier")

打印模型输出应与此类似:

代码语言:javascript
复制
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Tanh()
    (2): Linear(in_features=512, out_features=2, bias=True)
    (3): Tanh()
    (4): Linear(in_features=2, out_features=512, bias=True)
    (5): Tanh()
  )

希望这能有所帮助。

https://huggingface.co/docs/transformers/custom_models#registering-a-model-with-custom-code-to-the-auto-classes

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/72503309

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档