首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >使用torch.stack()

使用torch.stack()
EN

Stack Overflow用户
提问于 2021-09-17 08:32:49
回答 2查看 559关注 0票数 1
代码语言:javascript
复制
t1 = torch.tensor([1,2,3])
t2 = torch.tensor([4,5,6])
t3 = torch.tensor([7,8,9])

torch.stack((t1,t2,t3),dim=1)

在实现torch.stack()时,我不能理解如何对不同的dim进行堆叠。在这里,堆叠是针对列进行的,但是我不能理解它是如何完成的。处理2-d或3-D张量会变得更加复杂。

代码语言:javascript
复制
tensor([[1, 4, 7],
        [2, 5, 8],
        [3, 6, 9]])
EN

回答 2

Stack Overflow用户

发布于 2021-09-17 10:12:29

想象一下有n张量。如果我们停留在3D中,它们对应于体积,即矩形长方体。堆叠对应于在另一个维度上合并这些n体积:此处添加了第4个维度来承载n 3D体积。此操作与拼接形成鲜明对比,在拼接中,卷将在现有维度之一上合并。因此,三维张量的连接将导致三维张量。

以下是有限尺寸(最多三维输入)的堆叠操作的可能表示:

您选择执行堆叠的位置定义了堆叠将沿着哪个新维度发生。在上面的示例中,新创建的维度是最后一个,因此“添加维度”的概念更有意义。

在下面的可视化中,我们观察到张量如何在不同的轴上堆叠。这反过来影响得到的张量形状。

  • 对于1D情况,例如,它也可以发生在第一个轴上,见下文:

使用代码:

x_1d = list(torch.empty(3,2)) #3行>>> torch.stack(x_1d,0).shape # axis=0堆叠torch.Size(3,2) >>> torch.stack(x_1d,1).shape # axis=1堆叠torch.Size(2,list

对于二维输入,

  • 也是如此:

使用代码:

x_2d = list(torch.empty(3,2,2)) #3 2x2-squares >>> torch.stack(x_2d,0).shape # axis=0 stacking torch.Size(3,2,2) >>> torch.stack(x_2d,1).shape # axis=1 stacking torch.Size(2,3,2) >>> torch.stack(x_2d,2).shape # axis=2 stacking ( 2,2,

有了这种思维状态,您可以直观地将操作扩展到n维张量。

票数 1
EN

Stack Overflow用户

发布于 2021-09-17 10:17:15

非常简单!在这个例子中,我将使用4个变量。函数torch.stack非常类似于numpy (vstack和hstack)。示例:

代码语言:javascript
复制
t1 = torch.tensor([1,2,3])
t2 = torch.tensor([4,5,6])
t3 = torch.tensor([7,8,9])
t4 = torch.tensor([10,11,12])

如果您尝试此cmd

代码语言:javascript
复制
>> torch.stack((t1,t2,t3,t4),dim=1).size()
>> torch.Size([3, 4])

如果您使用dim=0更改dim=1

代码语言:javascript
复制
>> torch.stack((t1,t2,t3,t4),dim=0).size()
>> torch.Size([4, 3])

在第一种情况下,你有一个3x4维的张量,但在最后一种情况下,你有一个4x3张量。不使用.size()测试这段代码!

代码语言:javascript
复制
>> torch.stack((t1,t2,t3,t4),dim=1)
>> tensor([[ 1,  4,  7, 10],
        [ 2,  5,  8, 11],
        [ 3,  6,  9, 12]])

>> torch.stack((t1,t2,t3,t4),dim=0)
>> tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

祝你编码愉快!

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69220221

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档