首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么`nn.Unfold`不给我和`.unfold`一样的结果?

为什么`nn.Unfold`不给我和`.unfold`一样的结果?
EN

Stack Overflow用户
提问于 2022-08-08 10:04:11
回答 1查看 191关注 0票数 1

为什么这会给我两个完全不同的答案?如何获得与使用方法2的方法1相同的结果?

代码语言:javascript
复制
import torch
from torch import nn

kernel_size = 7
stride = 1

# approach 1
data = torch.rand(4, 64, 174, 120)
data1 = data.unfold(3, kernel_size * 2 + 1, stride)
print(data1.shape)

# approach 2
data = torch.rand(4, 64, 174, 120)
unfold = nn.Unfold(3, kernel_size * 2 + 1, stride)
data2 = unfold(data)
print(data2.shape)

输出:

代码语言:javascript
复制
torch.Size([4, 64, 174, 106, 15])
torch.Size([4, 576, 13432])

编辑

我试过你的方法了。形状是相同的,但内容不是。知道为什么吗?

代码语言:javascript
复制
import torch
from torch import nn

kernel_size = 7
stride = 1

# approach 1
data = torch.rand(4, 64, 174, 120)
data1 = data.unfold(3, kernel_size * 2 + 1, stride)
print(data1.shape)

# approach 2
data = torch.rand(4, 64, 174, 120)
b, c, h, w = data.shape
unfold = nn.Unfold(kernel_size=(1, 2*kernel_size + 1), dilation=1, stride=1, padding=0)
data2 = unfold(data.reshape(-1, 1, 1, w)).permute(0, 2, 1).reshape(b, c, h, -1, 2*kernel_size + 1)
print(data2.shape)

print(torch.equal(data1, data2))

输出:

代码语言:javascript
复制
torch.Size([4, 64, 174, 106, 15])
torch.Size([4, 64, 174, 106, 15])
False
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-08-08 11:04:09

torch.unfold沿着某个维度展开。在您的示例中,它获取了dim 120的4x64x174样本,并提取了所有重叠的15个窗口,其结果是形状为4x64x174x106x15的data1

相反,nn.Unfold工作在bxcx...张量和提取空间补丁。在您的示例中,nn.Unfold得到了kernel_size=3dilation=kernel_size*2+1padding=1。因此,它提取了13432个64个通道的3x3斑块(3_3_64=576),得到了形状为3_3_64=576的data2

要从torch.unfold中获得相同的nn.Unfold输出,您需要对其进行整形和修改:

代码语言:javascript
复制
b, c, h, w = data.shape
unfold = nn.Unfold(kernel_size=(1, 2*kernel_size + 1), dilation=1, stride=1, padding=0)
data2 = unfold(data.reshape(-1, 1, 1, w)).permute(0, 2, 1).reshape(b, c, h, -1, 2*kernel_size + 1)

请仔细阅读nn.Unfold的文档,因为它的工作方式与torch.unfold完全不同。有关nn.Unfoldnn.Fold的更多信息,请参见this thread

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/73276139

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档