首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >把所有的张量放在一个装置上

把所有的张量放在一个装置上
EN

Stack Overflow用户
提问于 2022-08-15 15:33:00
回答 1查看 30关注 0票数 1

我在我的模型中使用ViViT。虽然我将输入和整个模型移到cuda,但火车过程在位置嵌入线上显示了一个错误:

代码语言:javascript
复制
class ViViTBackbone(nn.Module):
    """ Model-3 backbone of ViViT """

    def __init__(self, t, h, w, patch_t, patch_h, patch_w, num_classes, dim, depth, heads, mlp_dim, dim_head=3,
                 channels=3, mode='tubelet', emb_dropout=0., dropout=0., model=3):
        super().__init__()

        assert t % patch_t == 0 and h % patch_h == 0 and w % patch_w == 0, "Video dimensions should be divisible by " \
                                                                           "tubelet size "

        self.T = t
        self.H = h
        self.W = w
        self.channels = channels
        self.t = patch_t
        self.h = patch_h
        self.w = patch_w
        self.mode = mode

        self.nt = self.T // self.t
        self.nh = self.H // self.h
        self.nw = self.W // self.w

        tubelet_dim = self.t * self.h * self.w * channels

        self.to_tubelet_embedding = nn.Sequential(
            Rearrange('b c (t pt) (h ph) (w pw) -> b t (h w) (pt ph pw c)', pt=self.t, ph=self.h, pw=self.w),
            nn.Linear(tubelet_dim, dim)
        )

        # repeat same spatial position encoding temporally
        self.pos_embedding = nn.Parameter(torch.randn(1, 1, self.nh * self.nw, dim)).repeat(1, self.nt, 1, 1)

        self.dropout = nn.Dropout(emb_dropout)

        if model == 3:
            self.transformer = FSATransformerEncoder(dim, depth, heads, dim_head, mlp_dim,
                                                     self.nt, self.nh, self.nw, dropout)
        elif model == 4:
            assert heads % 2 == 0, "Number of heads should be even"
            self.transformer = FDATransformerEncoder(dim, depth, heads, dim_head, mlp_dim,
                                                     self.nt, self.nh, self.nw, dropout)

        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        """ x is a video: (b, C, T, H, W) """


        tokens = self.to_tubelet_embedding(x)

        tokens += self.pos_embedding   #The error is because of this line
        tokens = self.dropout(tokens)

        x = self.transformer(tokens)
        return x

这是一个错误:

我根据模型类中的以下方法创建ViViT:

代码语言:javascript
复制
self.vivit_FSA_F_8 = ViViTBackbone(t=8, h=16, w=24,   patch_t=1,   patch_h=16,   patch_w=24,  num_classes=10,  dim=128,
                            depth=6,  heads=10,  mlp_dim=8,   model=3)

我怎么才能解决呢?

EN

回答 1

Stack Overflow用户

发布于 2022-08-15 16:18:50

有多种方法:而不是创建如下的参数:

代码语言:javascript
复制
self.T = t

做:

代码语言:javascript
复制
self.T = nn.Parameter(t)

然后model.to(设备)也会将所有参数推送到正确的设备上。

另一种方法是在创建张量时使用设备参数。

代码语言:javascript
复制
some_tensor = torch.tensor(1.0,device=self.device)

代码语言:javascript
复制
some_tensor = torch.ones([3,4],device=self.device)
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73363204

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档