如何冻结TFBertForSequenceClassification预训练模型?
Stack Overflow user
Asked on 2020-07-01 07:25:04
4 answers · 2K views · 2 votes

If I am using the TensorFlow version of Hugging Face Transformers, how do I freeze the weights of the pre-trained encoder so that only the weights of the head layers are optimized?

In the PyTorch implementation, this is done with:

```python
for param in model.base_model.parameters():
    param.requires_grad = False
```
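The PyTorch pattern can be checked on a tiny stand-in module (the `base_model` and `head` attributes below are illustrative stand-ins, not the actual Hugging Face layout):

```python
import torch.nn as nn

# Tiny stand-in: a "base" encoder plus a classification "head".
class TinyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = nn.Linear(3, 4)  # stand-in for the pre-trained encoder
        self.head = nn.Linear(4, 2)        # stand-in for the classification head

model = TinyClassifier()

# Freeze the base, exactly as in the snippet above.
for param in model.base_model.parameters():
    param.requires_grad = False

# Only the head's parameters remain trainable: 4*2 weights + 2 biases = 10.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(trainable)  # 10
```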

I would like to do the same in the TensorFlow implementation.


4 Answers

Stack Overflow user

Accepted answer

Answered on 2020-07-07 01:47:09

Found a way: freeze the base model before compiling.

```python
from transformers import TFBertForSequenceClassification

model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")
model.layers[0].trainable = False
model.compile(...)
```
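The effect of freezing before compile can be sketched with a plain Keras model standing in for TFBertForSequenceClassification (the two Dense layers below are illustrative stand-ins for the encoder and the head, not the real BERT layers):

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(4, name="base"),  # stand-in for the pre-trained encoder
    tf.keras.layers.Dense(2, name="head"),  # stand-in for the classification head
])

# Freeze the first layer before compiling, as the answer does with
# model.layers[0] of the Hugging Face model.
model.layers[0].trainable = False
model.compile(optimizer="adam", loss="mse")

# Only the head remains trainable: 4*2 weights + 2 biases = 10 parameters;
# the frozen base holds 3*4 weights + 4 biases = 16 parameters.
trainable = sum(int(tf.size(w)) for w in model.trainable_weights)
frozen = sum(int(tf.size(w)) for w in model.non_trainable_weights)
print(trainable, frozen)  # 10 16
```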
Votes: 0

Stack Overflow user

Answered on 2021-02-09 04:44:30

After digging through this thread, I think the following code won't hurt in TF2, even though it may be redundant in some cases.

```python
from transformers import TFBertModel

model = TFBertModel.from_pretrained('./bert-base-uncase')
for layer in model.layers:
    layer.trainable = False
    for w in layer.weights:
        w._trainable = False
```
Votes: 0

Stack Overflow user

Answered on 2021-06-29 05:18:14

```python
from transformers import TFDistilBertForSequenceClassification

model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
for _layer in model.layers:
    if _layer.name == 'distilbert':
        print(f"Freezing model layer {_layer.name}")
        _layer.trainable = False
    print(_layer.name)
    print(_layer.trainable)
```

Output:

```
Freezing model layer distilbert
distilbert
False      <----------------
pre_classifier
True
classifier
True
dropout_99
True

Model: "tf_distil_bert_for_sequence_classification_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
distilbert (TFDistilBertMain multiple                  66362880  
_________________________________________________________________
pre_classifier (Dense)       multiple                  590592    
_________________________________________________________________
classifier (Dense)           multiple                  1538      
_________________________________________________________________
dropout_99 (Dropout)         multiple                  0         
=================================================================
Total params: 66,955,010
Trainable params: 592,130
Non-trainable params: 66,362,880   <-----
```

Without freezing:

```
Model: "tf_distil_bert_for_sequence_classification_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
distilbert (TFDistilBertMain multiple                  66362880  
_________________________________________________________________
pre_classifier (Dense)       multiple                  590592    
_________________________________________________________________
classifier (Dense)           multiple                  1538      
_________________________________________________________________
dropout_59 (Dropout)         multiple                  0         
=================================================================
Total params: 66,955,010
Trainable params: 66,955,010
Non-trainable params: 0
```

Change TFDistilBertForSequenceClassification to TFBertForSequenceClassification accordingly. To do this, first run model.summary() to verify the base layer's name; for TFDistilBertForSequenceClassification it is distilbert.
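The freeze-by-name pattern generalizes to any Keras model. A minimal sketch, assuming a stand-in layer named "bert" that mimics the base-layer name model.summary() would show for TFBertForSequenceClassification:

```python
import tensorflow as tf

# Stand-in model whose first layer carries the base-model name that
# model.summary() would reveal for the real Hugging Face model.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(3,)),
    tf.keras.layers.Dense(4, name="bert"),
    tf.keras.layers.Dense(2, name="classifier"),
])

for layer in model.layers:
    if layer.name == "bert":  # base-layer name verified via model.summary()
        layer.trainable = False

frozen_names = [l.name for l in model.layers if not l.trainable]
print(frozen_names)  # ['bert']
```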

Votes: 0
Original page content provided by Stack Overflow; translation supported by Tencent Cloud's IT-domain engine.
Original link: https://stackoverflow.com/questions/62671668
