我在我的模型中使用ViViT。虽然我将输入和整个模型移到cuda,但火车过程在位置嵌入线上显示了一个错误:
class ViViTBackbone(nn.Module):
""" Model-3 backbone of ViViT """
def __init__(self, t, h, w, patch_t, patch_h, patch_w, num_classes, dim, depth, heads, mlp_dim, dim_head=3,
channels=3, mode='tubelet', emb_dropout=0., dropout=0., model=3):
super().__init__()
assert t % patch_t == 0 and h % patch_h == 0 and w % patch_w == 0, "Video dimensions should be divisible by " \
"tubelet size "
self.T = t
self.H = h
self.W = w
self.channels = channels
self.t = patch_t
self.h = patch_h
self.w = patch_w
self.mode = mode
self.nt = self.T // self.t
self.nh = self.H // self.h
self.nw = self.W // self.w
tubelet_dim = self.t * self.h * self.w * channels
self.to_tubelet_embedding = nn.Sequential(
Rearrange('b c (t pt) (h ph) (w pw) -> b t (h w) (pt ph pw c)', pt=self.t, ph=self.h, pw=self.w),
nn.Linear(tubelet_dim, dim)
)
# repeat same spatial position encoding temporally
self.pos_embedding = nn.Parameter(torch.randn(1, 1, self.nh * self.nw, dim)).repeat(1, self.nt, 1, 1)
self.dropout = nn.Dropout(emb_dropout)
if model == 3:
self.transformer = FSATransformerEncoder(dim, depth, heads, dim_head, mlp_dim,
self.nt, self.nh, self.nw, dropout)
elif model == 4:
assert heads % 2 == 0, "Number of heads should be even"
self.transformer = FDATransformerEncoder(dim, depth, heads, dim_head, mlp_dim,
self.nt, self.nh, self.nw, dropout)
self.to_latent = nn.Identity()
self.mlp_head = nn.Sequential(
nn.LayerNorm(dim),
nn.Linear(dim, num_classes)
)
def forward(self, x):
""" x is a video: (b, C, T, H, W) """
tokens = self.to_tubelet_embedding(x)
tokens += self.pos_embedding #The error is because of this line
tokens = self.dropout(tokens)
x = self.transformer(tokens)
return x这是一个错误:

我根据模型类中的以下方法创建ViViT:
self.vivit_FSA_F_8 = ViViTBackbone(t=8, h=16, w=24, patch_t=1, patch_h=16, patch_w=24, num_classes=10, dim=128,
depth=6, heads=10, mlp_dim=8, model=3)我怎么才能解决呢?
发布于 2022-08-15 16:18:50
有多种方法:而不是创建如下的参数:
self.T = t做:
self.T = nn.Parameter(t)然后model.to(设备)也会将所有参数推送到正确的设备上。
另一种方法是在创建张量时使用设备参数。
some_tensor = torch.tensor(1.0,device=self.device)或
some_tensor = torch.ones([3,4],device=self.device)https://stackoverflow.com/questions/73363204
复制相似问题