我理解扁平移除了除一个维度之外的所有维度。例如,我理解扁平()
> t = torch.ones(4, 3)
> t
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]])
> flatten(t)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])然而,我没有得到Flatten,特别是我没有从医生得到这个片段的意义
>>> input = torch.randn(32, 1, 5, 5)
>>> m = nn.Sequential(
>>> nn.Conv2d(1, 32, 5, 1, 1),
>>> nn.Flatten()
>>> )
>>> output = m(input)
>>> output.size()
torch.Size([32, 288])我觉得输出应该有[160]大小,因为32*5=160。
Q1.所以它输出的是[32,288]大小
Q2. I也不了解文档中给出的shape信息的含义:

Q3.及其参数的含义:

发布于 2021-05-09 17:28:28
这是默认行为的不同之处。默认情况下,torch.flatten会平展所有维度,而默认情况下,torch.nn.Flatten会将从第二维度(索引1)开始的所有维度都压平。
您可以在start_dim和end_dim参数的默认值中看到这种行为。start_dim参数表示要平坦的第一个维度(零索引),而end_dim参数表示要平坦的最后一个维度。因此,当start_dim=1是torch.nn.Flatten的默认值时,第一个维度(索引0)不是扁平的,而是包含在start_dim=0中,这是torch.flatten的缺省值。
这种差异背后的原因可能是因为torch.nn.Flatten打算与torch.nn.Sequential一起使用,通常对一批输入执行一系列操作,其中每个输入都独立于其他输入。例如,如果您有一批图像并调用了torch.nn.Flatten,那么典型的用例是将每个映像分别平放,而不是将整个批处理平平。
如果您确实希望使用torch.nn.Flatten来平平所有维度,您可以简单地将对象创建为torch.nn.Flatten(start_dim=0)。
最后,文档中的形状信息仅涵盖张量的形状将如何受到影响,这说明第一个(索引0)维仍然保持不变。所以,如果你有一个形状(N, *dims)的输入张量,其中*dims是一个任意的维序列,输出张量将具有形状(N, product of *dims),因为除了批处理维外,所有维都是平坦的。例如,形状(3,10,10)的输入将具有形状(3, 10 x 10) = (3, 100)的输出。
https://stackoverflow.com/questions/67460123
复制相似问题