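I have created my own BertClassifier model, starting from a pretrained model and adding my own classification head made up of several layers. After fine-tuning, I want to save the model with model.save_pretrained(), but when I load it back with from_pretrained and print it, my classifier head is missing. The code is below. How can I save the whole structure of my model and make it fully loadable with AutoModel.from_pretrained('folder_path')? Thanks!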
class BertClassifier(PreTrainedModel):
    """Bert Model for Classification Tasks."""
    config_class = AutoConfig

    def __init__(self, config, freeze_bert=True):  # tuning only the head
        """
        @param    bert: a BertModel object
        @param    classifier: a torch.nn.Module classifier
        @param    freeze_bert (bool): Set `False` to fine-tune the BERT model
        """
        # super(BertClassifier, self).__init__()
        super().__init__(config)

        # Instantiate BERT model
        # Specify hidden size of BERT, hidden size of our classifier, and number of labels
        self.D_in = 1024  # hidden size of Bert
        self.H = 512
        self.D_out = 2

        # Instantiate the classifier head with some one-layer feed-forward classifier
        self.classifier = nn.Sequential(
            nn.Linear(self.D_in, 512),
            nn.Tanh(),
            nn.Linear(512, self.D_out),
            nn.Tanh()
        )

    def forward(self, input_ids, attention_mask):
        # Feed input to BERT
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)

        # Extract the last hidden state of the token `[CLS]` for classification task
        last_hidden_state_cls = outputs[0][:, 0, :]

        # Feed input to classifier to compute logits
        logits = self.classifier(last_hidden_state_cls)

        return logits

configuration = AutoConfig.from_pretrained('Rostlab/prot_bert_bfd')
model = BertClassifier(config=configuration, freeze_bert=False)

Saving the model after fine-tuning:
model.save_pretrained('path')

Loading the fine-tuned model:
model = AutoModel.from_pretrained('path')

Printing the model after loading shows the following as the last layers; my two linear layers are missing:
        (output): BertOutput(
          (dense): Linear(in_features=4096, out_features=1024, bias=True)
          (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.0, inplace=False)
          (adapters): ModuleDict()
          (adapter_fusion_layer): ModuleDict()
        )
      )
    )
  )
  (pooler): BertPooler(
    (dense): Linear(in_features=1024, out_features=1024, bias=True)
    (activation): Tanh()
  )
  (prefix_tuning): PrefixTuningPool(
    (prefix_tunings): ModuleDict()
  )
)

Posted on 2022-06-08 18:04:05
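Maybe something is wrong with the config_class attribute in your BertClassifier class. According to the documentation, you need to create an additional config class that inherits from PretrainedConfig and sets its model_type attribute to the name of your custom model.
The config_class of BertClassifier has to match this custom config class. After that, you can register your config and your model with the following calls:
AutoConfig.register('CustomModelName', CustomModelConfigClass)
AutoModel.register(CustomModelConfigClass, CustomModelClass)
and load your fine-tuned model with AutoModel.from_pretrained('YourCustomModelName').
An incomplete example based on your code could look like this: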
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel

class BertClassifierConfig(PretrainedConfig):
    # model_type must equal the name passed to AutoConfig.register below
    model_type = "BertClassifier"

class BertClassifier(PreTrainedModel):
    config_class = BertClassifierConfig
    # ...

configuration = BertClassifierConfig()
bert_classifier = BertClassifier(configuration)

# do your finetuning and save your custom model
bert_classifier.save_pretrained("CustomModels/BertClassifier")

# register your config and your model
AutoConfig.register("BertClassifier", BertClassifierConfig)
AutoModel.register(BertClassifierConfig, BertClassifier)

# load your model with AutoModel
bert_classifier_model = AutoModel.from_pretrained("CustomModels/BertClassifier")

Printing the model should give output similar to this:
    (pooler): BertPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
    )
  )
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=512, bias=True)
    (1): Tanh()
    (2): Linear(in_features=512, out_features=2, bias=True)
    (3): Tanh()
    (4): Linear(in_features=2, out_features=512, bias=True)
    (5): Tanh()
  )

Hope this helps.
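
For completeness, here is a minimal end-to-end sketch of the register-then-reload approach described in this answer. It is not the asker's exact model. To run without downloading anything, it builds a tiny, randomly initialized BertModel instead of Rostlab/prot_bert_bfd. The model type name "bert_classifier", the config fields head_hidden_size and head_out_size, and all the sizes are assumptions chosen for illustration. The likely reason the head disappeared in the question is that the saved config still declared model_type "bert", so AutoModel rebuilt a plain BertModel. A dedicated config class avoids that.

```python
import tempfile

import torch
from torch import nn
from transformers import (
    AutoConfig,
    AutoModel,
    BertConfig,
    BertModel,
    PretrainedConfig,
    PreTrainedModel,
)


class BertClassifierConfig(PretrainedConfig):
    # Hypothetical name; it must match the string given to AutoConfig.register.
    model_type = "bert_classifier"

    def __init__(self, vocab_size=30, hidden_size=32, num_hidden_layers=1,
                 num_attention_heads=2, intermediate_size=64,
                 head_hidden_size=16, head_out_size=2, **kwargs):
        # Tiny illustrative sizes (assumptions), so the sketch runs quickly.
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.head_hidden_size = head_hidden_size
        self.head_out_size = head_out_size
        super().__init__(**kwargs)


class BertClassifier(PreTrainedModel):
    config_class = BertClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        # Build the encoder from the sizes stored in the custom config.
        self.bert = BertModel(BertConfig(
            vocab_size=config.vocab_size,
            hidden_size=config.hidden_size,
            num_hidden_layers=config.num_hidden_layers,
            num_attention_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
        ))
        # The head is a registered submodule, so its weights are saved too.
        self.classifier = nn.Sequential(
            nn.Linear(config.hidden_size, config.head_hidden_size),
            nn.Tanh(),
            nn.Linear(config.head_hidden_size, config.head_out_size),
        )

    def forward(self, input_ids, attention_mask=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_state = outputs[0][:, 0, :]  # hidden state of the first token
        return self.classifier(cls_state)


# Registration must run in every process that loads the model through AutoModel.
AutoConfig.register("bert_classifier", BertClassifierConfig)
AutoModel.register(BertClassifierConfig, BertClassifier)

model = BertClassifier(BertClassifierConfig()).eval()
input_ids = torch.tensor([[1, 5, 7, 2]])

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)
    reloaded = AutoModel.from_pretrained(tmp_dir).eval()
    with torch.no_grad():
        same_logits = torch.allclose(model(input_ids), reloaded(input_ids), atol=1e-6)

print(type(reloaded).__name__)
print(hasattr(reloaded, "classifier"))
```

To start from real pretrained weights, one option would be to assign BertModel.from_pretrained(...) to model.bert once before fine-tuning, with the config sizes set to match that checkpoint. After save_pretrained, the encoder and the head are then restored together from the saved folder.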
https://stackoverflow.com/questions/72503309