首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >基于偏移量向量的tensor3元素位置偏移

基于偏移量向量的tensor3元素位置偏移
EN

Stack Overflow用户
提问于 2016-08-15 03:19:34
回答 1查看 144关注 0票数 0

我有一个tensor3 (即三维数组) x

代码语言:javascript
复制
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]

以及y向量(即一维数组),我们将其称为“偏移量”向量,因为它指定了所需的偏移量:

代码语言:javascript
复制
[2, 1]

我希望基于向量x调整y元素的位置,以便输出如下(在第二维度上执行移位):

代码语言:javascript
复制
[[[ a  b  c  d]
  [ e  f  g  h]
  [ 0  1  2  3]]

 [[ i  j  k  l]
  [12 13 14 15]
  [16 17 18 19]]]

ab,…,l可能是任意数字。

例如,一个有效的输出可以是:

代码语言:javascript
复制
[[[ 0  0  0  0]
  [ 0  0  0  0]
  [ 0  1  2  3]]

 [[ 0  0  0  0]
  [12 13 14 15]
  [16 17 18 19]]]

另一个有效的输出可以是:

代码语言:javascript
复制
[[[ 4  5  6  7]
  [ 8  9 10 11]
  [ 0  1  2  3]]

 [[20 21 22 23]
  [12 13 14 15]
  [16 17 18 19]]]

我知道函数theano.tensor.roll(x, shift, axis=None),但是shift只能接受一个标量作为输入,即它移动所有具有相同偏移量的元素。

例如,代码:

代码语言:javascript
复制
import theano.tensor
from theano import shared
import numpy as np

x = shared(np.arange(24).reshape((2,3,4)))
print('theano.tensor.roll(x, 2, axis=1).eval(): \n{0}'.
      format(theano.tensor.roll(x, 2, axis=1).eval()))

产出:

代码语言:javascript
复制
theano.tensor.roll(x, 2, axis=1).eval():
[[[ 4  5  6  7]
  [ 8  9 10 11]
  [ 0  1  2  3]]

 [[16 17 18 19]
  [20 21 22 23]
  [12 13 14 15]]]

这不是我想要的。

如何根据偏移量向量移动tensor3元素的位置?(注意,在本例中提供的代码中,为了方便起见,tensor3是一个共享变量,但在我的实际代码中,它将是一个符号变量)

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2016-08-16 16:37:32

为此,我找不到任何专用函数,所以我只使用了theano.scan

代码语言:javascript
复制
import theano
import theano.tensor

from theano import shared
import numpy as np

y = shared(np.array([2,1]))
x = shared(np.arange(24).reshape((2,3,4)))
print('x.eval():\n{0}\n'.format(x.eval()))

def shift_and_reverse_row(matrix, y):    
    '''
    Shift and reverse the matrix in the direction of the first dimension (i.e., rows)
    matrix: matrix 
    y: scalar
    '''
    new_matrix = theano.tensor.zeros_like(matrix)
    new_matrix = theano.tensor.set_subtensor(new_matrix[:y,:], matrix[y-1::-1,:])
    return new_matrix

new_x, updates = theano.scan(shift_and_reverse_row, outputs_info=None,
                             sequences=[x, y[::-1]] )
new_x = new_x[:, ::-1, :]
print('new_x.eval(): \n{0}'.format(new_x.eval()))

产出:

代码语言:javascript
复制
x.eval():
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]]

new_x.eval():
[[[ 0  0  0  0]
  [ 0  0  0  0]
  [ 0  1  2  3]]

 [[ 0  0  0  0]
  [12 13 14 15]
  [16 17 18 19]]]
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/38948862

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档