If I am using the TensorFlow version of Hugging Face Transformers, how do I freeze the weights of the pretrained encoder so that only the weights of the head layers are optimized?
In the PyTorch implementation, this is done with

for param in model.base_model.parameters():
    param.requires_grad = False

I would like to do the same in the TensorFlow implementation.
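A runnable sketch of this PyTorch pattern, using a tiny stand-in model instead of the real pretrained weights (the `base`/`head` layers are assumptions for illustration, not the Hugging Face model):

```python
import torch
from torch import nn

# Minimal stand-in for encoder + head; nn.Sequential here is only an
# illustration, not the real Hugging Face model.
base = nn.Linear(8, 8)   # plays the role of model.base_model
head = nn.Linear(8, 2)   # plays the role of the classification head
model = nn.Sequential(base, head)

# Freeze the "encoder": no gradients are computed for its parameters.
for param in base.parameters():
    param.requires_grad = False

# Optionally hand only the still-trainable parameters to the optimizer.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)

frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(frozen, trainable)  # 72 18
```

Filtering on `requires_grad` when building the optimizer keeps frozen parameters out of the update step entirely.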
Posted on 2020-07-07 01:47:09
Found a way: freeze the base model before compiling.
model = TFBertForSequenceClassification.from_pretrained("bert-base-uncased")
model.layers[0].trainable = False
model.compile(...)

Posted on 2021-02-09 04:44:30
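A self-contained illustration of this freeze-before-compile pattern, using a tiny Keras model as a stand-in (the layer names "encoder" and "head" are assumptions for this sketch, not the real BERT layers):

```python
import tensorflow as tf

# Tiny stand-in model: layers[0] plays the role of the pretrained encoder.
# In the real case, the model would come from
# TFBertForSequenceClassification.from_pretrained(...).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(4, name="encoder"),
    tf.keras.layers.Dense(2, name="head"),
])

model.layers[0].trainable = False  # freeze before compiling
model.compile(optimizer="adam", loss="mse")

# Only the head's kernel and bias remain trainable.
print(len(model.trainable_weights))  # 2
print(model.layers[0].trainable)     # False
```

Note that Keras only picks up the value of `trainable` at compile time, so freeze the layer before calling `compile()` (or recompile after changing it).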
After digging through this thread, I think the code below won't hurt for TF2, even though it may be redundant in some cases.
model = TFBertModel.from_pretrained('./bert-base-uncase')
for layer in model.layers:
    layer.trainable = False
    for w in layer.weights:
        w._trainable = False

Posted on 2021-06-29 05:18:14
model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
for _layer in model.layers:
    if _layer.name == 'distilbert':
        print(f"Freezing model layer {_layer.name}")
        _layer.trainable = False
    print(_layer.name)
    print(_layer.trainable)
---
Freezing model layer distilbert
distilbert
False <----------------
pre_classifier
True
classifier
True
dropout_99
True
Model: "tf_distil_bert_for_sequence_classification_4"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
distilbert (TFDistilBertMain multiple 66362880
_________________________________________________________________
pre_classifier (Dense) multiple 590592
_________________________________________________________________
classifier (Dense) multiple 1538
_________________________________________________________________
dropout_99 (Dropout) multiple 0
=================================================================
Total params: 66,955,010
Trainable params: 592,130
Non-trainable params: 66,362,880 <----- frozen
Model: "tf_distil_bert_for_sequence_classification_2"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
distilbert (TFDistilBertMain multiple 66362880
_________________________________________________________________
pre_classifier (Dense) multiple 590592
_________________________________________________________________
classifier (Dense) multiple 1538
_________________________________________________________________
dropout_59 (Dropout) multiple 0
=================================================================
Total params: 66,955,010
Trainable params: 66,955,010
Non-trainable params: 0

Change TFDistilBertForSequenceClassification to TFBertForSequenceClassification accordingly. To do that, first run model.summary() to verify the base layer's name; for TFDistilBertForSequenceClassification it is distilbert.
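The freeze-by-name loop above can be tried on a small stand-in model without downloading any pretrained weights (the layer names "distilbert" and "classifier" and the layer sizes here are assumptions for illustration; verify the real names with model.summary()):

```python
import tensorflow as tf

# Stand-in for a HF sequence-classification model: a base layer named
# "distilbert" plus a head named "classifier". Both names and sizes are
# made up for this sketch.
inputs = tf.keras.Input(shape=(8,))
x = tf.keras.layers.Dense(8, name="distilbert")(inputs)   # pretend base
outputs = tf.keras.layers.Dense(2, name="classifier")(x)  # pretend head
model = tf.keras.Model(inputs, outputs)

# Freeze the base layer by name, as in the answer above.
for _layer in model.layers:
    if _layer.name == "distilbert":
        _layer.trainable = False

for _layer in model.layers:
    print(_layer.name, _layer.trainable)
```

After this loop, model.summary() should report the base layer's parameters under "Non-trainable params", matching the first summary shown above.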
https://stackoverflow.com/questions/62671668