我想实现一个构建邻接矩阵的代码,例如:
如果X: 0,1,2,0,1,0,那么,
A[0, 1] = 1
A[1, 2] = 1
A[2, 0] = 1
A[0, 1] = 1
A[1, 0] = 1下面的代码运行良好,但是太慢了!所以,请至少帮助我在批(第一)维度上向量化这段代码:
A = torch.zeros((3, 3, 3), dtype = torch.float)
X = torch.tensor([[0, 1, 2, 0, 1, 0], [1, 0, 0, 2, 1, 1], [0, 0, 2, 2, 1, 1]])
for a, x in zip(A, X):
for i, j in zip(x, x[1:]):
a[i, j] = 1谢谢!:)
发布于 2019-09-03 20:14:34
我非常确定有一种更简单的方法可以做到这一点,但我试图保持在torch函数调用的范围内,以确保任何渐变操作都可以被正确跟踪。
如果这不是反向传播所必需的,我强烈建议你寻找可能利用一些numpy函数的解决方案,因为我认为在这里找到合适的东西有更强的保证。但是,无需多说,这里就是我想出的解决方案。
它实际上将X向量转换为一系列与A中的位置相对应的元组条目。为此,我们需要对齐一些索引(具体地说,第一个维度仅在X中隐式给出,因为X中的第一个列表对应于A[0,:,:],第二个列表对应于A[1,:,:],依此类推。这也可能是您可以开始优化代码的地方,因为我没有找到这样一个矩阵的合理描述,因此必须想出自己的方法来创建它。
# Start by "aligning" your shifted view of X
# Essentially, take the all but the last element,
# and put it on top of all but the first element.
X_shift = torch.stack([X[:,:-1], X[:,1:]], dim=2)
# X_shift.shape: (3,5,2) in your example
# To assign this properly, we need to turn it into a "concatenated" list,
# where each entry corresponds to a 2D tuple in the respective dimension of A.
temp_tuples = X_shift.view(-1,2).transpose(0,1)
# temp_tuples.shape: (2,15) in your example. Below are the values:
tensor([[0, 1, 2, 0, 1, 1, 0, 0, 2, 1, 0, 0, 2, 2, 1],
[1, 2, 0, 1, 0, 0, 0, 2, 1, 1, 0, 2, 2, 1, 1]])
# Now we have to create a matrix do indicate the proper "first dimension index"
fix_dims = torch.repeat_interleave(torch.arange(0,3,1), len(X[0])-1, 0).unsqueeze(dim=0)
# fix_dims.shape: (1,15)
# Long story short, this creates the following vector.
tensor([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2]])
# Note that the unsqueeze is necessary to properly concatenate the two matrices:
access_tuples = tuple(torch.cat([fix_dims, temp_tuples], dim=0))
A[access_tuples] = 1这进一步假设X中的每个维度都有相同数量的元组被更改。如果不是这样,那么您必须手动创建一个fix_dims向量,其中每个增量都重复X[i]次的长度。如果它与您的示例相同,则可以安全地使用建议的解决方案。
发布于 2019-09-03 17:13:56
使X成为元组而不是张量:
A = torch.zeros((3, 3, 3), dtype = torch.float)
X = ([0, 1, 2, 0, 1, 0], [1, 0, 0, 2, 1, 1], [0, 0, 2, 2, 1, 1])
A[X] = 1例如,通过像这样转换它:A[tuple(X)]
https://stackoverflow.com/questions/57763066
复制相似问题