首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >理解torch.nn.Flatten

理解torch.nn.Flatten
EN

Stack Overflow用户
提问于 2021-05-09 16:40:05
回答 1查看 3.5K关注 0票数 3

我理解扁平移除了除一个维度之外的所有维度。例如,我理解扁平()

代码语言:javascript
复制
> t = torch.ones(4, 3)
> t
tensor([[1., 1., 1.],
    [1., 1., 1.],
    [1., 1., 1.],
    [1., 1., 1.]])

> flatten(t)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

然而,我没有得到Flatten,特别是我没有从医生得到这个片段的意义

代码语言:javascript
复制
>>> input = torch.randn(32, 1, 5, 5)
>>> m = nn.Sequential(
>>>     nn.Conv2d(1, 32, 5, 1, 1),
>>>     nn.Flatten()
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([32, 288])

我觉得输出应该有[160]大小,因为32*5=160

Q1.所以它输出的是[32,288]大小

Q2. I也不了解文档中给出的shape信息的含义:

Q3.及其参数的含义:

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-05-09 17:28:28

这是默认行为的不同之处。默认情况下,torch.flatten会平展所有维度,而默认情况下,torch.nn.Flatten会将从第二维度(索引1)开始的所有维度都压平。

您可以在start_dimend_dim参数的默认值中看到这种行为。start_dim参数表示要平坦的第一个维度(零索引),而end_dim参数表示要平坦的最后一个维度。因此,当start_dim=1torch.nn.Flatten的默认值时,第一个维度(索引0)不是扁平的,而是包含在start_dim=0中,这是torch.flatten的缺省值。

这种差异背后的原因可能是因为torch.nn.Flatten打算与torch.nn.Sequential一起使用,通常对一批输入执行一系列操作,其中每个输入都独立于其他输入。例如,如果您有一批图像并调用了torch.nn.Flatten,那么典型的用例是将每个映像分别平放,而不是将整个批处理平平。

如果您确实希望使用torch.nn.Flatten来平平所有维度,您可以简单地将对象创建为torch.nn.Flatten(start_dim=0)

最后,文档中的形状信息仅涵盖张量的形状将如何受到影响,这说明第一个(索引0)维仍然保持不变。所以,如果你有一个形状(N, *dims)的输入张量,其中*dims是一个任意的维序列,输出张量将具有形状(N, product of *dims),因为除了批处理维外,所有维都是平坦的。例如,形状(3,10,10)的输入将具有形状(3, 10 x 10) = (3, 100)的输出。

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/67460123

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档