torch.flatten()和torch.nn.Flatten()有什么区别?
发布于 2021-02-01 14:14:25
在PyTorch中有三种形式的扁平化
torch.Tensor.flatten直接应用于张量:x.flatten()。torch.flatten应用于:torch.flatten(x)。nn.Module) nn.Flatten()。一般用于模型定义中。这三者都是相同的,并且共享相同的实现,唯一的区别是nn.Flatten默认将start_dim设置为1,以避免将第一个轴(通常是批处理轴)扁平化。而另外两个张量则从axis=0到axis=-1 --即整个张量--如果不给出参数的话。
发布于 2021-02-01 20:34:45
您可以认为torch.flatten()的工作只是简单地对张量进行扁平操作,而不附加任何附加条件。你给一个张量,它变平,并返回它。到此为止了。
相反,nn.Flatten()要复杂得多(也就是说,它是一个神经网络层)。由于是面向对象的,所以它继承了nn.Module,尽管它是用来扁平张量的方法。您可以更多地将其看作是torch.flatten()上的语法糖。
重要差异:一个值得注意的区别是,如果输入至少为1D或更大,则torch.flatten()总是返回一维张量,而nn.Flatten()总是返回2D张量,条件是输入至少为2D或更大(以一维张量作为输入,则会抛出IndexError)。
比较:
torch.flatten()是一个API,而nn.Flatten()是一个神经网络层。torch.flatten()是python函数,而nn.Flatten()是python类。torch.flatten()可以在野外使用(例如,用于简单张量操作),而nn.Flatten()则被期望作为层之一在nn.Sequential()块中使用。torch.flatten()没有关于计算图的信息,除非它被卡在其他具有图形感知的块中( tensor.requires_grad标志设置为True),而nn.Flatten()总是被自动梯度跟踪。torch.flatten()不能接受和处理(例如线性/卷积)层作为输入,而nn.Flatten()主要用于处理这些神经网络层。torch.flatten()和nn.Flatten()都将视图返回到输入张量。因此,对结果的任何修改也会影响输入张量。(见下面的代码)代码演示
# input tensors to work with
In [109]: t1 = torch.arange(12).reshape(3, -1)
In [110]: t2 = torch.arange(12, 24).reshape(3, -1)
In [111]: t3 = torch.arange(12, 36).reshape(3, 2, -1) # 3D tensor用torch.flatten()进行压扁
In [113]: t1flat = torch.flatten(t1)
In [114]: t1flat
Out[114]: tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])
# modification to the flattened tensor
In [115]: t1flat[-1] = -1
# input tensor is also modified; thus flattening is a view.
In [116]: t1
Out[116]:
tensor([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, -1]])用nn.Flatten()进行压扁
In [123]: nnfl = nn.Flatten()
In [124]: t3flat = nnfl(t3)
# note that the result is 2D, as opposed to 1D with torch.flatten
In [125]: t3flat
Out[125]:
tensor([[12, 13, 14, 15, 16, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26, 27],
[28, 29, 30, 31, 32, 33, 34, 35]])
# modification to the result
In [126]: t3flat[-1, -1] = -1
# input tensor also modified. Thus, flattened result is a view.
In [127]: t3
Out[127]:
tensor([[[12, 13, 14, 15],
[16, 17, 18, 19]],
[[20, 21, 22, 23],
[24, 25, 26, 27]],
[[28, 29, 30, 31],
[32, 33, 34, -1]]])torch.flatten()是nn.Flatten()及其弟兄们 nn.Unflatten()的前身,因为它从一开始就存在。然后,出现了一个合法的nn.Flatten(),因为这是几乎所有ConvNets的共同要求(就在softmax之前或其他地方)。因此,后来在PR #22245中添加了它。
最近也有用于模型手术的在ResNets中。
https://stackoverflow.com/questions/65993494
复制相似问题