我有这个问题。我有两个张量,一个形状(batch_size=128,height=48,宽度= 48,depth=1),应该包含索引(从0到32x32-1)和另一个形状(batch_size=128,height=32,width = 32,depth=1),其中包含我应该映射的值。在这一秒中,这个批处理中的每个矩阵都包含它自己的值。
例如,我想映射第三个“索引矩阵”和第三个“映射矩阵”,考虑到批处理范围从0到32x32的每个项目中的索引。应对批处理中的所有项目应用相同的过程。既然这些东西应该在丢失函数中完成,而且我看到我们在那里使用批处理,那么我该如何完成这个任务呢?我认为tf.gather可能会有帮助,因为我已经使用了,但在一个简单的情况下(比如常量数组),但是我不知道如何在这个复杂的情况下使用它。
编辑:
let's suppose I have:
[
[
[1,2,0,3],
[4,2,4,0],
[1,3,3,1],
[1,2,4,8]
],
[
[3,2,0,0],
[4,5,4,2],
[7,6,3,1],
[1,5,4,8]
]
] that is a (2,4,4,1) and a tensor
[
[
[0.3,0.4,0.6],
[0.9,0.2,0.5],
[0.1,0.2,0.1]
] ,
[
[0.1,0.4,0.5],
[0.8,0.1,0.6],
[0.2,0.4,0.3]
]
] that is a (2,3,3,1).
The first contains the indexes of the second.
I would like an output:
[
[
[0.4,0.6,0.3,0.9],
[0.2,0.6,0.2,0.3],
[0.4,0.9,0.9,0.4],
[0.4,0.6,0.2,0.1],
],
[
[0.8,0.5,0.1,0.1],
[0.1,0.6,0.1,0.5],
[0.4,0.2,0.8,0.4],
[0.4,0.6,0.1,0.3]
]
]因此,索引应该引用到批处理的单个项。我也应该为这个转换提供一个导数吗?
发布于 2017-05-15 15:39:45
如果我正确理解了你的问题,你会想用
output = tf.gather_nd(tensor2, indices)如果indices是形状(batch_size, 48, 48, 3)的矩阵,那么
indices[sample][i][j] = [i, row, col]其中(row, col)是要在tensor2中获取的值的坐标。它们是tensor1中给出的内容的翻译,编码为两个数字,而不是1:
(row, col) = (tensor1[i, j] / 32, tensor1[i, j] % 32)要动态创建indices,应该这样做:
batch_size = tf.shape(tensor1)[0]
i_mat = tf.transpose(tf.reshape(tf.tile(tf.range(batch_size), [48*48]),
[48, 48, batch_size]))
# i_mat should be such that i_matrix[i, j, k, l]=i
mat_32 = tf.fill(value=tf.constant(32, dtype=tf.int32), dims=[batch_size, 48, 48])
row_mat = tf.floor_div(tensor1, mat_32)
col_mat = tf.mod(tensor1, mat_32)
indices = tf.stack([i_mat, row_mat, col_mat], axis=-1)
output = tf.gather_nd(tensor2, indices)编辑2
上面的代码发生了一些变化。
上面的代码认为您的输入张量实际上是形状(batch_size, 48, 48)和(batch_size, 32, 32),而不是(batch_size, 48, 48, 1)和(batch_size, 32, 32, 1)。要纠正这种情况,例如使用
tensor1=tf.squeeze(tensor1, axis=-1)
tensor2=tf.squeeze(tensor2, axis=-1)在我上面的代码之前,
output = tf.expand_dims(tf.gather_nd(tensor2, indices), axis=-1)
tensor1= tf.expand_dims(tensor1, axis=-1)
tensor2= tf.expand_dims(tensor2, axis=-1)在最后
https://stackoverflow.com/questions/43981134
复制相似问题