首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tf.scatter_update或tf.scatter_nd_update可以用来更新张量的列切片吗?

tf.scatter_update或tf.scatter_nd_update可以用来更新张量的列切片吗?
EN

Stack Overflow用户
提问于 2018-10-18 10:43:25
回答 2查看 1.9K关注 0票数 1

我希望实现一个函数,它接受一个变量作为输入,修改它的一些行或列,并将它们替换回原来的变量中。我可以使用tf.gather和tf.scatter_update为行片实现它,但是对于列片无法实现,因为显然tf.scatter_update只更新行片,并且没有轴特性。我不是tensorflow的专家,所以我可能错过了什么。有人能帮忙吗?

代码语言:javascript
复制
def matrix_reg(t, percent_t, beta):
    
    ''' Takes a variable tensor t as input and regularizes some of its rows.
    The number of rows to be regularized are specified by the percent_t. Returns the original tensor by updating its rows indexed by row_ind.
    
    Arguments:
        t -- input tensor
        percent_t -- percentage of the total rows
        beta -- the regularization factor
    Output:
        the regularized tensor
        '''
    row_ind = np.random.choice(int(t.shape[0]), int(percent_t*int(t.shape[0])), replace = False)
    t_ = tf.gather(t,row_ind)
    t_reg = (1+beta)*t_-beta*(tf.matmul(tf.matmul(t_,tf.transpose(t_)),t_))
    return tf.scatter_update(t, row_ind, t_reg)
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2018-10-18 11:16:37

下面是如何更新行或列的一个小演示。其思想是指定变量的行和列索引,希望更新中的每个元素都在这些变量中结束。使用tf.meshgrid很容易做到这一点。

代码语言:javascript
复制
import tensorflow as tf

var = tf.get_variable('var', [4, 3], tf.float32, initializer=tf.zeros_initializer())
updates = tf.placeholder(tf.float32, [None, None])
indices = tf.placeholder(tf.int32, [None])
# Update rows
var_update_rows = tf.scatter_update(var, indices, updates)
# Update columns
col_indices_nd = tf.stack(tf.meshgrid(tf.range(tf.shape(var)[0]), indices, indexing='ij'), axis=-1)
var_update_cols = tf.scatter_nd_update(var, col_indices_nd, updates)
init = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init)
    print('Rows updated:')
    print(sess.run(var_update_rows, feed_dict={updates: [[1, 2, 3], [4, 5, 6]], indices: [3, 1]}))
    print('Columns updated:')
    print(sess.run(var_update_cols, feed_dict={updates: [[1, 5], [2, 6], [3, 7], [4, 8]], indices: [0, 2]}))

输出:

代码语言:javascript
复制
Rows updated:
[[0. 0. 0.]
 [4. 5. 6.]
 [0. 0. 0.]
 [1. 2. 3.]]
Columns updated:
[[1. 0. 5.]
 [2. 5. 6.]
 [3. 0. 7.]
 [4. 2. 8.]]
票数 2
EN

Stack Overflow用户

发布于 2020-04-17 12:46:04

请参阅Tensorflow2文档中的tf.Variable

__getitem__( var,slice_spec ) 创建一个给定变量的片助手对象。 这允许从变量当前内容的一部分创建次张量.有关切片的详细示例,请参见tf.Tensor.getitem。 这个函数还允许分配给切片范围。这与Python中的__setitem__功能类似。但是,语法是不同的,因此用户可以捕获分配操作,以便分组或传递到sess.run()。例如, ..。

下面是一个最低限度的工作示例:

代码语言:javascript
复制
import tensorflow as tf
import numpy as np
var = tf.Variable(np.random.rand(3,3,3))
print(var)
# update the last column of the three (3x3) matrices to random integer values
# note that the update values needs to have the same shape
# as broadcasting is not supported as of TF2
var[:,:,2].assign(np.random.randint(10,size=(3,3)))
print(var)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/52872239

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档