首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何从Tensorflow的Huggingface中修改基本ViT体系结构

如何从Tensorflow的Huggingface中修改基本ViT体系结构
EN

Stack Overflow用户
提问于 2022-03-15 12:56:36
回答 1查看 426关注 0票数 2

我是新的拥抱脸,并希望采用相同的变压器架构做在ViT的图像分类到我的领域。因此,我需要改变输入形状和所做的增强。

来自拥抱脸的片段:

代码语言:javascript
复制
from transformers import ViTFeatureExtractor, TFViTForImageClassification
import tensorflow as tf
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
model = TFViTForImageClassification.from_pretrained("google/vit-base-patch16-224")

inputs = feature_extractor(images=image, return_tensors="tf")
outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 1000 ImageNet classes
predicted_class_idx = tf.math.argmax(logits, axis=-1)[0]
print("Predicted class:", model.config.id2label[int(predicted_class_idx)])

当我做mode.summary()

我得到以下结果:

代码语言:javascript
复制
Model: "tf_vi_t_for_image_classification_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 vit (TFViTMainLayer)        multiple                  85798656  
                                                                 
 classifier (Dense)          multiple                  769000    
                                                                 
=================================================================
Total params: 86,567,656
Trainable params: 86,567,656
Non-trainable params: 0

如图所示,封装了ViT基础的层,是否有一种方法可以打开这些层以允许我修改特定的层?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-03-15 13:20:10

在您的例子中,我建议查看源代码这里并跟踪被调用的类。例如,要获取Embeddings类的层,可以运行:

代码语言:javascript
复制
print(model.layers[0].embeddings.patch_embeddings.projection)
print(model.layers[0].embeddings.dropout)
代码语言:javascript
复制
<keras.layers.convolutional.Conv2D object at 0x7fea6264c6d0>
<keras.layers.core.dropout.Dropout object at 0x7fea62d65110>

或者,如果您想获得第一个Attention块的层,请尝试:

代码语言:javascript
复制
print(model.layers[0].encoder.layer[0].attention.self_attention.query)
print(model.layers[0].encoder.layer[0].attention.self_attention.key)
print(model.layers[0].encoder.layer[0].attention.self_attention.value)
print(model.layers[0].encoder.layer[0].attention.self_attention.dropout)
print(model.layers[0].encoder.layer[0].attention.dense_output.dense)
print(model.layers[0].encoder.layer[0].attention.dense_output.dropout)
代码语言:javascript
复制
<keras.layers.convolutional.Conv2D object at 0x7fea6264c6d0>
<keras.layers.core.dropout.Dropout object at 0x7fea62d65110>
<keras.layers.core.dense.Dense object at 0x7fea62ec7f90>
<keras.layers.core.dense.Dense object at 0x7fea62ec7b50>
<keras.layers.core.dense.Dense object at 0x7fea62ec79d0>
<keras.layers.core.dropout.Dropout object at 0x7fea62cf5c90>
<keras.layers.core.dense.Dense object at 0x7fea62cf5250>
<keras.layers.core.dropout.Dropout object at 0x7fea62cf5410>

诸若此类。

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71482661

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档