首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >用于ViT实现的补丁编码器

用于ViT实现的补丁编码器
EN

Stack Overflow用户
提问于 2022-07-24 02:46:54
回答 1查看 94关注 0票数 0

我正在学习这个链接的视觉变压器。我无法理解实现步骤2.3:修补程序编码器,即:

代码语言:javascript
复制
    def __init__(self, num_patches, projection_dim):
        super(PatchEncoder, self).__init__()
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )
 
    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

任何人请帮助我理解这个功能到底是做什么的。

EN

回答 1

Stack Overflow用户

发布于 2022-08-12 13:59:26

PatchEncoder将一个扁平的补丁作为输入,线性地将其投影到所需的输入维度(即projection_dim),并向每个补丁添加位置嵌入。

我们将以cifar10为例。

原始图像None、32、32、3被data_augmentation调整为None、72、72、3。

Patches中,调整大小的图像除以补丁大小6,因此扁平的补丁为无、12、12、108。

最后,进行整形,因此PatchEncoder的输入将为None,144,108。

PatchEncoder中,首先应用稠密层并将其投影到projection_dim维数,结果是0、144、64。

position_embedding (定义为layers.Embedding)是将tf.range生成的补丁的序列号转换为一个projection_dim维向量的过程,该向量被添加到投影补丁中,最终输出的encoded为None、144、64。

注意:没有一个表示任意批次大小。

我希望这有助于理解。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73095449

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档