首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >torch.flatten()和nn.Flatten()之间的区别

torch.flatten()和nn.Flatten()之间的区别
EN

Stack Overflow用户
提问于 2021-02-01 13:23:08
回答 2查看 10.8K关注 0票数 10

torch.flatten()torch.nn.Flatten()有什么区别?

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2021-02-01 14:14:25

在PyTorch中有三种形式的扁平化

  • 作为张量方法(oop样式),torch.Tensor.flatten直接应用于张量:x.flatten()
  • 作为一个函数(函数形式),torch.flatten应用于:torch.flatten(x)
  • 作为一个模块(层nn.Module) nn.Flatten()。一般用于模型定义中。

这三者都是相同的,并且共享相同的实现,唯一的区别是nn.Flatten默认将start_dim设置为1,以避免将第一个轴(通常是批处理轴)扁平化。而另外两个张量则从axis=0axis=-1 --即整个张量--如果不给出参数的话。

票数 13
EN

Stack Overflow用户

发布于 2021-02-01 20:34:45

您可以认为torch.flatten()的工作只是简单地对张量进行扁平操作,而不附加任何附加条件。你给一个张量,它变平,并返回它。到此为止了。

相反,nn.Flatten()要复杂得多(也就是说,它是一个神经网络层)。由于是面向对象的,所以它继承了nn.Module,尽管它是用来扁平张量的方法。您可以更多地将其看作是torch.flatten()上的语法糖。

重要差异:一个值得注意的区别是,如果输入至少为1D或更大,则torch.flatten()总是返回一维张量,而nn.Flatten()总是返回2D张量,条件是输入至少为2D或更大(以一维张量作为输入,则会抛出IndexError)。

比较:

  • torch.flatten()是一个API,而nn.Flatten()是一个神经网络层。
  • torch.flatten()是python函数,而nn.Flatten()是python类。
  • 由于以上所述,有很多方法和属性
  • torch.flatten()可以在野外使用(例如,用于简单张量操作),而nn.Flatten()则被期望作为层之一在nn.Sequential()块中使用。
  • torch.flatten()没有关于计算图的信息,除非它被卡在其他具有图形感知的块中( tensor.requires_grad标志设置为True),而nn.Flatten()总是被自动梯度跟踪。
  • torch.flatten()不能接受和处理(例如线性/卷积)层作为输入,而nn.Flatten()主要用于处理这些神经网络层。
  • torch.flatten()nn.Flatten()都将视图返回到输入张量。因此,对结果的任何修改也会影响输入张量。(见下面的代码)

代码演示

代码语言:javascript
复制
# input tensors to work with
In [109]: t1 = torch.arange(12).reshape(3, -1)
In [110]: t2 = torch.arange(12, 24).reshape(3, -1)
In [111]: t3 = torch.arange(12, 36).reshape(3, 2, -1)   # 3D tensor

torch.flatten()进行压扁

代码语言:javascript
复制
In [113]: t1flat = torch.flatten(t1)

In [114]: t1flat
Out[114]: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

# modification to the flattened tensor    
In [115]: t1flat[-1] = -1

# input tensor is also modified; thus flattening is a view.
In [116]: t1
Out[116]: 
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, -1]])

nn.Flatten()进行压扁

代码语言:javascript
复制
In [123]: nnfl = nn.Flatten()
In [124]: t3flat = nnfl(t3)

# note that the result is 2D, as opposed to 1D with torch.flatten
In [125]: t3flat
Out[125]: 
tensor([[12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27],
        [28, 29, 30, 31, 32, 33, 34, 35]])

# modification to the result
In [126]: t3flat[-1, -1] = -1

# input tensor also modified. Thus, flattened result is a view.
In [127]: t3
Out[127]: 
tensor([[[12, 13, 14, 15],
         [16, 17, 18, 19]],

        [[20, 21, 22, 23],
         [24, 25, 26, 27]],

        [[28, 29, 30, 31],
         [32, 33, 34, -1]]])

torch.flatten()nn.Flatten()及其弟兄们 nn.Unflatten()的前身,因为它从一开始就存在。然后,出现了一个合法的nn.Flatten(),因为这是几乎所有ConvNets的共同要求(就在softmax之前或其他地方)。因此,后来在PR #22245中添加了它。

最近也有用于模型手术的在ResNets中

票数 3
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/65993494

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档