I changed
if self.l2_norm:
    norm = torch.norm(masked_embedding, p=2, dim=1) + 1e-10
    masked_embedding = masked_embedding / norm.expand_as(masked_embedding)
to
if self.l2_norm:
    masked_embedding = torch.nn.functional.normalize(masked_embedding, p=2.0, dim=2, eps=1e-10, out=None)
Now I get this new error (I was getting a different error before, which is why I had to change it to this):
(fashcomp) [jalal@goku fashion-compatibility]$ python main.py --name test_baseline --learned --l2_embed --datadir ../../../data/fashion/
/scratch3/venv/fashcomp/lib/python3.8/site-packages/torchvision/transforms/transforms.py:310: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
warnings.warn("The use of the transforms.Scale transform is deprecated, " +
+ Number of params: 3191808
<class 'torch.utils.data.dataloader.DataLoader'>
/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at /pytorch/c10/core/TensorImpl.h:1156.)
return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Traceback (most recent call last):
  File "main.py", line 324, in <module>
    main()
  File "main.py", line 167, in main
    train(train_loader, tnet, criterion, optimizer, epoch)
  File "main.py", line 202, in train
    acc, loss_triplet, loss_mask, loss_embed, loss_vse, loss_sim_t, loss_sim_i = tnet(anchor, far, close)
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/scratch3/research/code/fashion/fashion-compatibility/tripletnet.py", line 146, in forward
    acc, loss_triplet, loss_sim_i, loss_mask, loss_embed, general_x, general_y, general_z = self.image_forward(x, y, z)
  File "/scratch3/research/code/fashion/fashion-compatibility/tripletnet.py", line 74, in image_forward
    embedded_x, masknorm_norm_x, embed_norm_x, general_x = self.embeddingnet(x.images, c)
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/scratch3/research/code/fashion/fashion-compatibility/type_specific_network.py", line 147, in forward
    masked_embedding = torch.nn.functional.normalize(masked_embedding, p=2.0, dim=2, eps=1e-10, out=None)
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/functional.py", line 4428, in normalize
    denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input)
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/_tensor.py", line 417, in norm
    return torch.norm(self, p, dim, keepdim, dtype=dtype)
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/functional.py", line 1356, in norm
    return _VF.norm(input, p, _dim, keepdim=keepdim)  # type: ignore[attr-defined]
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
This code was originally run with Python 2 and a much older PyTorch version, dating back about 3 years. I am running it with native Python 3.8 and PyTorch 1.9 (GPU) on CentOS 7.
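The IndexError can be reproduced in isolation. This is a minimal sketch, assuming masked_embedding is a 2-D tensor of shape (batch, embedding_dim), which is what the allowed range [-2, 1] in the message implies; the tensor and its sizes are made up for illustration:

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for masked_embedding: 4 samples, 64-dim embeddings.
masked_embedding = torch.randn(4, 64)

# A tensor with ndim dimensions accepts dim values in [-ndim, ndim - 1].
ndim = masked_embedding.dim()
print(f"valid dims: [{-ndim}, {ndim - 1}]")

try:
    # dim=2 would be a third axis, which a 2-D tensor does not have.
    F.normalize(masked_embedding, p=2.0, dim=2, eps=1e-10)
except IndexError as err:
    print("IndexError:", err)

# dim=1 (equivalently dim=-1) is the embedding axis and is accepted.
normalized = F.normalize(masked_embedding, p=2.0, dim=1, eps=1e-10)
print(normalized.shape)
```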
$ pip freeze
absl-py==0.13.0
argon2-cffi==20.1.0
attrs==21.2.0
backcall==0.2.0
bleach==4.1.0
cachetools==4.2.2
certifi==2021.5.30
cffi==1.14.6
charset-normalizer==2.0.4
cycler==0.10.0
debugpy==1.4.1
decorator==5.0.9
defusedxml==0.7.1
entrypoints==0.3
google-auth==1.35.0
google-auth-oauthlib==0.4.5
grpcio==1.39.0
h5py==3.3.0
idna==3.2
importlib==1.0.4
ipykernel==6.2.0
ipython==7.26.0
ipython-genutils==0.2.0
ipywidgets==7.6.3
jedi==0.18.0
Jinja2==3.0.1
joblib==1.0.1
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==7.0.1
jupyter-console==6.4.0
jupyter-core==4.7.1
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.0
kiwisolver==1.3.1
Markdown==3.3.4
MarkupSafe==2.0.1
matplotlib==3.4.3
matplotlib-inline==0.1.2
mistune==0.8.4
nbclient==0.5.4
nbconvert==6.1.0
nbformat==5.1.3
nest-asyncio==1.5.1
notebook==6.4.3
numpy==1.21.2
oauthlib==3.1.1
packaging==21.0
pandas==1.3.2
pandocfilters==1.4.3
parso==0.8.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.3.1
prometheus-client==0.11.0
prompt-toolkit==3.0.20
protobuf==3.17.3
ptyprocess==0.7.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
Pygments==2.10.0
pyparsing==2.4.7
pyrsistent==0.18.0
python-dateutil==2.8.2
pytz==2021.1
pyzmq==22.2.1
qtconsole==5.1.1
QtPy==1.10.0
requests==2.26.0
requests-oauthlib==1.3.0
rsa==4.7.2
scikit-learn==0.24.2
scipy==1.7.1
Send2Trash==1.8.0
six==1.16.0
sklearn==0.0
tensorboard==2.6.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
terminado==0.11.1
testpath==0.5.0
threadpoolctl==2.2.0
torch==1.9.0
torch-tb-profiler==0.2.1
torchaudio==0.9.0
torchvision==0.10.0
tornado==6.1
traitlets==5.0.5
typing-extensions==3.10.0.0
urllib3==1.26.6
wcwidth==0.2.5
webencodings==0.5.1
Werkzeug==2.0.1
widgetsnbextension==3.5.1
The GitHub issue and code can be found here.
Answered on 2021-09-01 21:40:52
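To switch to F.normalize, you need to make sure you apply it on dim=1. Here masked_embedding is a 2-D tensor (the error message only allows dims in [-2, 1]), so dim=2 does not exist: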
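if self.l2_norm:
    # F is torch.nn.functional (import torch.nn.functional as F)
    masked_embedding = F.normalize(masked_embedding, p=2.0, dim=1, eps=1e-10)
If you prefer the other alternative with torch.norm or torch.Tensor.norm, you can use the keepdim=True option. It keeps the reduced dimension, so the norm has shape (batch, 1) and broadcasts against masked_embedding in the division without needing expand_as: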
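if self.l2_norm:
    norm = masked_embedding.norm(p=2, dim=1, keepdim=True) + 1e-10
    # Divide out of place: norm() saved masked_embedding for the backward pass,
    # so an in-place `masked_embedding /= norm` would make autograd raise a RuntimeError.
    masked_embedding = masked_embedding / norm
https://stackoverflow.com/questions/69020802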
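A small sketch to check that the F.normalize version and the keepdim=True version agree. It assumes, as above, a 2-D (batch, embedding_dim) tensor with made-up sizes, and the helper names normalize_functional and normalize_keepdim are hypothetical. It also illustrates two side points: without keepdim=True the norm has shape (batch,), which expand_as generally cannot stretch to (batch, embedding_dim); and dividing in place after calling norm() breaks the backward pass, because norm() saved its input for the gradient computation.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical stand-in for masked_embedding: (batch, embedding_dim) = (4, 64).
base = torch.randn(4, 64, requires_grad=True)


def normalize_functional(x):
    # Option 1: F.normalize along dim=1, the embedding axis.
    return F.normalize(x, p=2.0, dim=1, eps=1e-10)


def normalize_keepdim(x):
    # Option 2: keepdim=True yields a (batch, 1) norm that broadcasts over dim=1.
    norm = x.norm(p=2, dim=1, keepdim=True) + 1e-10
    return x / norm  # out of place, so the tensor saved by norm() stays intact


a = normalize_functional(base)
b = normalize_keepdim(base)
print(torch.allclose(a, b, atol=1e-6))  # both options should agree row by row
print(b.norm(p=2, dim=1))  # each row should have (approximately) unit length

# Without keepdim the norm is shape (4,); expand_as aligns it with the last
# axis (size 64), so the expansion fails with a RuntimeError.
x = base.detach()
flat_norm = torch.norm(x, p=2, dim=1) + 1e-10
try:
    x / flat_norm.expand_as(x)
except RuntimeError as err:
    print("expand_as failed:", err)

# In-place division after norm() modifies a tensor autograd saved for backward.
y = base * 1.0  # non-leaf copy that still requires grad
y_norm = y.norm(p=2, dim=1, keepdim=True) + 1e-10
y /= y_norm
try:
    y.sum().backward()
except RuntimeError as err:
    print("in-place division broke backward:", err)
```

Either option leaves the rest of forward unchanged. F.normalize is the shorter of the two; it clamps the norm at eps (clamp_min(eps), as the traceback shows) rather than adding eps to it, which makes no practical difference at eps=1e-10.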