首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >预训练火炬模型中的替换层

预训练火炬模型中的替换层
EN

Stack Overflow用户
提问于 2022-02-11 16:41:14
回答 1查看 484关注 0票数 1

我想替换3D Resnet的线性层,它可以从pytorch集线器下载。

我可以使用以下代码获得线性层的名称:

代码语言:javascript
复制
for name, layer in model.named_modules():
    if isinstance(layer, torch.nn.Linear):
        print(name, layer)

分块5.线性(in_features=2048,out_features=400,bias=True)

我不能简单地使用model.blocks.5.proj = nn.Linear(2048, 10),因为.5.会抛出一个语法错误。相反,我尝试迭代这些模块并替换线性层:

代码语言:javascript
复制
for name, layer in model.named_modules():
    if isinstance(layer, torch.nn.Linear):
        model._modules[name] = torch.nn.Linear(2048, 10)

出于某种原因,这也不起作用。相反,它只是创建了一个同名的额外线性层:

分块5.proj线性(in_features=2048,out_features=400,bias=True)

有人能帮我一下吗?

EN

回答 1

Stack Overflow用户

发布于 2022-05-29 03:04:30

打印层的整数指示blocks是一个nn.Sequential模块。您可以使用常规数组索引访问nn.Sequential模块中的特定层。

试一试如下:

代码语言:javascript
复制
blocks[5].proj = torch.nn.Linear(2048, 10)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71083563

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档