我有一个张量,包含五个2x2矩阵-形状(1,5,2,2),张量包含5个元素-形状(5)。我想把每个2x2矩阵(前一个张量)乘以相应的值(在后一个张量中)。结果张量应为形状(1,5,2,2)。怎么做?
在运行此代码时获取以下错误
a = torch.rand(1,5,2,2)
print(a.shape)
b = torch.rand(5)
print(b.shape)
mul = a*b
RuntimeError: The size of tensor a (2) must match the size of tensor b (5) at non-singleton dimension 3发布于 2020-05-02 07:10:39
您可以使用a * b或torch.mul(a, b) ,但是必须在乘法前后使用permute(),以便具有兼容的形状:
import torch
a = torch.ones(1,5,2,2)
b = torch.rand(5)
a.shape # torch.Size([1, 5, 2, 2])
b.shape # torch.Size([5])
c = (a.permute(0,2,3,1) * b).permute(0,3,1,2)
c.shape # torch.Size([1, 5, 2, 2])
# OR #
c = torch.mul(a.permute(0,2,3,1), b).permute(0,3,1,2)
c.shape # torch.Size([1, 5, 2, 2])permute()函数按其参数的顺序变换维数。即,a.permute(0,2,3,1)的形状为torch.Size(1,2,2,5),适合于矩阵乘法的b (torch.Size(5)),因为a的最后维数等于b的第一维。在我们完成乘法后,我们再次使用permute()将其转到.理想形状的torch.Size(1,5,2,2)通过置换(0,3,1,2)。
您可以在permute()中阅读文档的相关内容。但它适用于当前形状为1、5、2、2乘以0到3的参数,并在插入参数时进行置换,这对于a.permute(0,2,3,1)来说意味着它将保留第一个维度,因为第一个参数是0,第二个参数将移动到第四个维度,因为索引1是第四个参数。由于2和3指数位于第二和第三位,所以第三和第三维将移动到第二和第三维。例如,当谈论第四维度时,它作为一个参数的表示形式是3(而不是4)。
编辑如果您想要按元素对形状为32,5,5,2,2,2和32,5的张量进行乘法(例如,使每个2x2矩阵乘以相应的值),则可以通过permute(2,3,0,1)将维数重新排列为2,2,32,5,然后通过a * b执行乘法,然后再通过permute(2,3,0,1)返回原来的形状。这里的关键是,第一个矩阵的最后一个n维数,需要与第二个矩阵的第一个n维数对齐。在我们的例子中,n=2。
希望这能有所帮助。
https://stackoverflow.com/questions/61555342
复制相似问题