首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何构建一个与autograd兼容的Pytorch模块,它可以像图像一样调整张量的大小?

如何构建一个与autograd兼容的Pytorch模块,它可以像图像一样调整张量的大小?
EN

Stack Overflow用户
提问于 2018-05-18 17:57:13
回答 2查看 1.7K关注 0票数 9

我想知道我是否可以在Pytorch中构建一个图像调整模块,它接受3*H*W的torch.tensor作为输入,并返回一个张量作为调整后的图像。

我知道可以将张量转换为PIL图像并使用torchvision,但我也希望将梯度从调整大小的图像反向传播到原始图像,下面的示例将返回此类错误(在Windows10上的PyTorch 0.4.0中):

代码语言:javascript
复制
import numpy as np
from torchvision import transforms

t2i = transforms.ToPILImage()
i2t = transforms.ToTensor()

trans = transforms.Compose(
    t2i, transforms.Resize(size=200), i2t]
)

test = np.random.normal(size=[3, 300, 300])
test = torch.tensor(test, requires_grad=True)
resized = trans(test)
resized.backward()

print(test.grad)

Traceback (most recent call last):
  File "D:/Projects/Python/PyTorch/test.py", line 41, in <module>
    main()
  File "D:/Projects/Python/PyTorch/test.py", line 33, in main
    resized = trans(test)
  File "D:\Anaconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 42, in __call__
    img = t(img)
  File "D:\Anaconda3\envs\pytorch\lib\site-packages\torchvision\transforms\transforms.py", line 103, in __call__
    return F.to_pil_image(pic, self.mode)
  File "D:\Anaconda3\envs\pytorch\lib\site-packages\torchvision\transforms\functional.py", line 102, in to_pil_image
    npimg = np.transpose(pic.numpy(), (1, 2, 0))
RuntimeError: Can't call numpy() on Variable that requires grad. Use var.detach().numpy() instead.

似乎我不能在不首先从autograd中分离张量的情况下“调整”它的大小,但是分离它会阻止我计算梯度。

有没有一种方法可以构建一个与torchvision.transforms.Resize做同样事情的torch函数/模块,并且是autograd兼容的?任何帮助都是非常感谢的!

EN

回答 2

Stack Overflow用户

发布于 2018-05-26 23:46:54

torch.nn.functional.upsample为我工作,爸爸!

票数 3
EN

Stack Overflow用户

发布于 2020-08-22 21:00:49

我刚刚弄明白了如何在实现自定义损失函数时保留梯度。

诀窍是将你的结果附加到虚拟渐变

代码语言:javascript
复制
def custom_loss(tensor1, tensor2):
    # convert tensors to PIL image, doing calculation, we have output = 0.123
    grad = (tensor1 + tensor2).sum()
    loss = grad - grad + output
    return loss
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50408673

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档