
return _VF.norm(input, p, _dim, keepdim=keepdim) IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

Stack Overflow user
Asked 2021-09-01 21:34:42
Answers: 1 · Views: 203 · Followers: 0 · Votes: 0

I changed this code:

if self.l2_norm:
    norm = torch.norm(masked_embedding, p=2, dim=1) + 1e-10
    masked_embedding = masked_embedding / norm.expand_as(masked_embedding)

to this:

if self.l2_norm:

    masked_embedding = torch.nn.functional.normalize(masked_embedding, p=2.0, dim=2, eps=1e-10, out=None)

Now I get this new error (I was getting a different error before, which is why I changed the code as shown above):

(fashcomp) [jalal@goku fashion-compatibility]$ python main.py --name test_baseline --learned --l2_embed --datadir ../../../data/fashion/
/scratch3/venv/fashcomp/lib/python3.8/site-packages/torchvision/transforms/transforms.py:310: UserWarning: The use of the transforms.Scale transform is deprecated, please use transforms.Resize instead.
  warnings.warn("The use of the transforms.Scale transform is deprecated, " +
  + Number of params: 3191808
<class 'torch.utils.data.dataloader.DataLoader'>
/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/functional.py:718: UserWarning: Named tensors and all their associated APIs are an experimental feature and subject to change. Please do not use them for anything important until they are released as stable. (Triggered internally at  /pytorch/c10/core/TensorImpl.h:1156.)
  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)
Traceback (most recent call last):
  File "main.py", line 324, in <module>
    main()    
  File "main.py", line 167, in main
    train(train_loader, tnet, criterion, optimizer, epoch)
  File "main.py", line 202, in train
    acc, loss_triplet, loss_mask, loss_embed, loss_vse, loss_sim_t, loss_sim_i = tnet(anchor, far, close)
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/scratch3/research/code/fashion/fashion-compatibility/tripletnet.py", line 146, in forward
    acc, loss_triplet, loss_sim_i, loss_mask, loss_embed, general_x, general_y, general_z = self.image_forward(x, y, z)
  File "/scratch3/research/code/fashion/fashion-compatibility/tripletnet.py", line 74, in image_forward
    embedded_x, masknorm_norm_x, embed_norm_x, general_x = self.embeddingnet(x.images, c)
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/scratch3/research/code/fashion/fashion-compatibility/type_specific_network.py", line 147, in forward
    masked_embedding = torch.nn.functional.normalize(masked_embedding, p=2.0, dim=2, eps=1e-10, out=None)
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/nn/functional.py", line 4428, in normalize
    denom = input.norm(p, dim, keepdim=True).clamp_min(eps).expand_as(input)
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/_tensor.py", line 417, in norm
    return torch.norm(self, p, dim, keepdim, dtype=dtype)
  File "/scratch3/venv/fashcomp/lib/python3.8/site-packages/torch/functional.py", line 1356, in norm
    return _VF.norm(input, p, _dim, keepdim=keepdim)  # type: ignore[attr-defined]
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)

This code was originally written for Python 2 and a much older version of PyTorch, from about three years ago. I am running it on CentOS 7 with native Python 3.8 and PyTorch 1.9 (GPU).

$ pip freeze
absl-py==0.13.0
argon2-cffi==20.1.0
attrs==21.2.0
backcall==0.2.0
bleach==4.1.0
cachetools==4.2.2
certifi==2021.5.30
cffi==1.14.6
charset-normalizer==2.0.4
cycler==0.10.0
debugpy==1.4.1
decorator==5.0.9
defusedxml==0.7.1
entrypoints==0.3
google-auth==1.35.0
google-auth-oauthlib==0.4.5
grpcio==1.39.0
h5py==3.3.0
idna==3.2
importlib==1.0.4
ipykernel==6.2.0
ipython==7.26.0
ipython-genutils==0.2.0
ipywidgets==7.6.3
jedi==0.18.0
Jinja2==3.0.1
joblib==1.0.1
jsonschema==3.2.0
jupyter==1.0.0
jupyter-client==7.0.1
jupyter-console==6.4.0
jupyter-core==4.7.1
jupyterlab-pygments==0.1.2
jupyterlab-widgets==1.0.0
kiwisolver==1.3.1
Markdown==3.3.4
MarkupSafe==2.0.1
matplotlib==3.4.3
matplotlib-inline==0.1.2
mistune==0.8.4
nbclient==0.5.4
nbconvert==6.1.0
nbformat==5.1.3
nest-asyncio==1.5.1
notebook==6.4.3
numpy==1.21.2
oauthlib==3.1.1
packaging==21.0
pandas==1.3.2
pandocfilters==1.4.3
parso==0.8.2
pexpect==4.8.0
pickleshare==0.7.5
Pillow==8.3.1
prometheus-client==0.11.0
prompt-toolkit==3.0.20
protobuf==3.17.3
ptyprocess==0.7.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pycparser==2.20
Pygments==2.10.0
pyparsing==2.4.7
pyrsistent==0.18.0
python-dateutil==2.8.2
pytz==2021.1
pyzmq==22.2.1
qtconsole==5.1.1
QtPy==1.10.0
requests==2.26.0
requests-oauthlib==1.3.0
rsa==4.7.2
scikit-learn==0.24.2
scipy==1.7.1
Send2Trash==1.8.0
six==1.16.0
sklearn==0.0
tensorboard==2.6.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.0
terminado==0.11.1
testpath==0.5.0
threadpoolctl==2.2.0
torch==1.9.0
torch-tb-profiler==0.2.1
torchaudio==0.9.0
torchvision==0.10.0
tornado==6.1
traitlets==5.0.5
typing-extensions==3.10.0.0
urllib3==1.26.6
wcwidth==0.2.5
webencodings==0.5.1
Werkzeug==2.0.1
widgetsnbextension==3.5.1

The GitHub issue and the code can be found here.


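1 Answer

Stack Overflow user

Accepted answer

Answered 2021-09-01 21:40:52

To switch to F.normalize, you need to make sure you apply it on dim=1. The error message says the valid range is [-2, 1], which means masked_embedding is a 2-D tensor, so dim=2 does not exist: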

if self.l2_norm:
    masked_embedding = F.normalize(masked_embedding, p=2.0, dim=1, eps=1e-10)
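
As a rough illustration of the point above, here is a minimal sketch that reproduces the out-of-range error and then checks the dim=1 fix. The tensor shape (4, 8) is an assumption chosen only for the demo; the real shape of masked_embedding in the asker's network is not given in the question.

```python
import torch
import torch.nn.functional as F

# Hypothetical stand-in for masked_embedding: 4 embeddings of size 8.
masked_embedding = torch.randn(4, 8)

# A 2-D tensor only has dims 0 and 1 (or -2 and -1), so dim=2 is rejected.
try:
    F.normalize(masked_embedding, p=2.0, dim=2, eps=1e-10)
    dim2_error = None
except IndexError as exc:
    dim2_error = str(exc)

# dim=1 normalizes each row (each embedding) to unit L2 length.
normalized = F.normalize(masked_embedding, p=2.0, dim=1, eps=1e-10)
row_norms = normalized.norm(p=2, dim=1)

print(dim2_error)
print(row_norms)
```

Printing masked_embedding.shape just before the normalize call in the real code is a quick way to confirm which dims are available.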

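Alternatively, if you would rather keep an explicit norm computation with torch.norm or torch.Tensor.norm, you can pass keepdim=True. It keeps the reduced dimension, so the norm broadcasts against the original tensor, which helps when normalizing in place: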

if self.l2_norm:
    norm = masked_embedding.norm(p=2, dim=1, keepdim=True) + 1e-10
    masked_embedding /= norm
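
To see why keepdim=True matters here, the following sketch compares the shapes with and without it and checks the manual division against F.normalize. The (4, 8) shape is again an assumption for illustration. The sketch uses out-of-place division on purpose: an in-place `/=` may conflict with autograd when the tensor is still needed for the backward pass, so treat the in-place form with some care during training.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Hypothetical (batch, embedding_dim) tensor standing in for masked_embedding.
masked_embedding = torch.randn(4, 8)

# Without keepdim the norm has shape (4,), which does not line up with (4, 8).
norm_flat = masked_embedding.norm(p=2, dim=1)
# With keepdim=True the norm has shape (4, 1) and broadcasts across each row.
norm_kept = masked_embedding.norm(p=2, dim=1, keepdim=True)

# Out-of-place version of the manual normalization.
manual = masked_embedding / (norm_kept + 1e-10)
reference = F.normalize(masked_embedding, p=2.0, dim=1, eps=1e-10)

print(norm_flat.shape, norm_kept.shape)
print(torch.allclose(manual, reference, atol=1e-6))
```

The two approaches differ slightly in how they guard against division by zero (adding the epsilon versus clamping the norm to at least eps), which should not matter for embeddings with non-trivial norms.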
Votes: 1
Original page content provided by Stack Overflow.
Original link:

https://stackoverflow.com/questions/69020802
