在numpy中,我可以在以下内容中进行索引:
a = np.random.randn(2,2,3)
b = np.eye(2,2).astype(np.uint8)
c = np.eye(2,2).astype(np.uint8)
print(a)
print("diff")
print(a[b,c,:]),其中ab,c,:是2*2的张量。
[[[-1.01338087 0.70149058 0.55268617]
[ 2.56941124 1.12720312 -0.07219555]]
[[-0.04084548 0.17018995 2.14229567]
[-0.68017558 -0.91788125 1.1719151 ]]]
diff
[[[-0.68017558 -0.91788125 1.1719151 ]
[-1.01338087 0.70149058 0.55268617]]
[[-1.01338087 0.70149058 0.55268617]
[-0.68017558 -0.91788125 1.1719151 ]]]但是在Pytorch中,我不能像a[b,c,:]那样用同样的方式来做索引。谁知道怎么做呢。谢谢~

发布于 2018-11-18 13:24:53
PyTorch中的索引几乎类似于numpy。
a = torch.randn(2, 2, 3)
b = torch.eye(2, 2, dtype=torch.long)
c = torch.eye(2, 2, dtype=torch.long)
print(a)
print(a[b, c, :])tensor([[[ 1.2471, 1.6571, -2.0504],
[-1.7502, 0.5747, -0.3451]],
[[-0.4389, 0.4482, 0.7294],
[-1.3051, 0.6606, -0.6960]]])
tensor([[[-1.3051, 0.6606, -0.6960],
[ 1.2471, 1.6571, -2.0504]],
[[ 1.2471, 1.6571, -2.0504],
[-1.3051, 0.6606, -0.6960]]])https://stackoverflow.com/questions/53360596
复制相似问题