我从这里下载了代码https://github.com/SpaceNetChallenge/SpaceNet_SAR_Buildings_Solutions,特别是使用模型1,我下载了相应的权重,并创建了以下文件来加载模型和测试。首先,我将main.py中的Unet部分复制到一个单独的文件umodel.py和测试文件中,如下所示
import torch
exec(open("./umodel.py").read())
network_data = torch.load('snapshot_fold_8_best')
print(network_data.keys())
import sys
sys.path.append("geffnet")
class Namespace:
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
args = Namespace(extra_num = 1,
dec_ch = [32, 64, 128, 256, 1024],
stride=32,
net='b5',
bot1x1=True,
glob=True,
bn=True,
aspp=True,
ocr=True,
aux_scale=True)
def load_state_dict(model, state_dict):
missing_keys = []
# from UnetOS.umodel import Unet
exec(open("./umodel.py").read())
try:
from torch.hub import load_state_dict_from_url
except ImportError:
from torch.utils.model_zoo import load_url as load_state_dict_from_url
# from UnetOS.umodel import *
model = Unet(extra_num = args.extra_num, dec_ch = args.dec_ch, stride=args.stride, net=args.net, bot1x1 = args.bot1x1, glob=args.glob, bn=args.bn, aspp=args.aspp,
ocr=args.ocr, aux = args.aux_scale > 0).cuda()
load_state_dict(model, network_data)我的问题是,为什么exec(open("./umodel.py").read())工作得很好,但是每当我试图导入from umodel import Unet时,它都有错误。
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
~\AppData\Local\Temp/ipykernel_10492/1282530406.py in <module>
9 # ah
10 model = Unet(extra_num = args.extra_num, dec_ch = args.dec_ch, stride=args.stride, net=args.net, bot1x1 = args.bot1x1, glob=args.glob, bn=args.bn, aspp=args.aspp,
---> 11 ocr=args.ocr, aux = args.aux_scale > 0).cuda()
12 #model = Unet()
13 #print(network_data.key())
D:\hines\Pretrained\1-zbigniewwojna\UnetOS\umodel.py in __init__(self, extra_num, dec_ch, stride, net, bot1x1, glob, bn, aspp, ocr, aux)
238 ['ir_r4_k5_s2_e6_c192_se0.25'],
239 ['ir_r1_k3_s1_e6_c320_se0.25']]
--> 240 enc = GenEfficientNet(in_chans=3, block_args=decode_arch_def(arch_def, depth_multiplier),
241 num_features=round_channels(1280, channel_multiplier, 8, None), stem_size=32,
242 channel_multiplier=channel_multiplier, act_layer=resolve_act_layer({}, 'swish'),
NameError: name 'decode_arch_def' is not defined
发布于 2022-02-10 01:59:47
从错误消息来看,decode_arch_def似乎不可用,并且查看您的导入,这必须来自from geffnet.efficientnet_builder import * (它做了https://github.com/rwightman/gen-efficientnet-pytorch/blob/master/geffnet/efficientnet_builder.py)
您的exec必须工作,因为它遵循类似的导入,这使decode_arch_def在作用域中-- exec()在当前作用域中执行代码,因此它将工作,因为在该范围中decode_arch_def已经定义。
但是,在导入时,导入的模块本身不具备您需要的作用域。您应该将所需的import语句添加到要导入的文件中,以便将它们带入作用域,并且应该可以工作。
例如,使用包含以下内容的mod.py:
def my_func():
print(datetime.now())这样做是可行的:
from datetime import datetime
exec(open("./mod.py").read())
my_func()但这并不意味着:
from datetime import datetime
import mod
mod.my_func()要做到这一点,mod.py必须是:
from datetime import datetime
def my_func():
print(datetime.now())而且在主程序中不需要导入datetime,因为没有引用它。您的代码也有类似的问题--您需要确定Unet类的所有依赖项并导入它们。
https://stackoverflow.com/questions/71057392
复制相似问题