当我试图用torch.hub的DINO骨干对线性层进行序列化时,我得到了运行时错误。
import torch
import torch.distributed as dist
class LinearClassifier(torch.nn.Module):
def __init__(self, dim, num_labels=1000):
super(LinearClassifier, self).__init__()
self.num_labels = num_labels
self.linear = torch.nn.Linear(dim, num_labels)
self.linear.weight.data.normal_(mean=0.0, std=0.01)
self.linear.bias.data.zero_()
def forward(self, x):
# flatten
x = x.view(x.size(0), -1)
# linear layer
return self.linear(x)
dist.init_process_group('gloo', init_method='file:///tmp/somefile', rank=0, world_size=1)
# load backbone
model = torch.hub.load('facebookresearch/dino:main', 'dino_vits8')
#Setup linear layer
linear_classifier = LinearClassifier(1536, 1000)
linear_classifier = linear_classifier.cuda()
linear_classifier = torch.nn.parallel.DistributedDataParallel(linear_classifier)
state_dict = torch.hub.load_state_dict_from_url(url="https://dl.fbaipublicfiles.com/dino/dino_deitsmall8_pretrain/dino_deitsmall8_linearweights.pth")['state_dict']
linear_classifier.load_state_dict(state_dict, strict=True)
#Sequentialise
model = torch.nn.Sequential(model,
linear_classifier)
x = torch.ones((1, 3, 224, 224))
out = model(x)
print("out: " + out)下面是我的顺序化模型的最后几个层的打印:最后一层印刷
发布于 2022-01-14 13:53:17
它类似于model(x)的输出(由model = torch.hub...定义)具有形状1 x 384,但是您的linear_classifier需要一些形状_ x 1536,这就是为什么您将得到这个错误。因此,您只需通过设置
linear_classifier = LinearClassifier(384, 1000)https://stackoverflow.com/questions/70710923
复制相似问题