我想替换3D Resnet的线性层,它可以从pytorch集线器下载。
我可以使用以下代码获得线性层的名称:
for name, layer in model.named_modules():
if isinstance(layer, torch.nn.Linear):
print(name, layer)分块5.线性(in_features=2048,out_features=400,bias=True)
我不能简单地使用model.blocks.5.proj = nn.Linear(2048, 10),因为.5.会抛出一个语法错误。相反,我尝试迭代这些模块并替换线性层:
for name, layer in model.named_modules():
if isinstance(layer, torch.nn.Linear):
model._modules[name] = torch.nn.Linear(2048, 10)出于某种原因,这也不起作用。相反,它只是创建了一个同名的额外线性层:
分块5.proj线性(in_features=2048,out_features=400,bias=True)
有人能帮我一下吗?
发布于 2022-05-29 03:04:30
打印层的整数指示blocks是一个nn.Sequential模块。您可以使用常规数组索引访问nn.Sequential模块中的特定层。
试一试如下:
blocks[5].proj = torch.nn.Linear(2048, 10)https://stackoverflow.com/questions/71083563
复制相似问题