首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >在Tensorflow中如何替换二维张量中的行

在Tensorflow中如何替换二维张量中的行
EN

Stack Overflow用户
提问于 2020-03-05 10:46:46
回答 1查看 445关注 0票数 2

这里是我问题的一个例子。

代码语言:javascript
复制
origin_embeddings = tf.constant([[1.1,2.2,3.3], 
                                 [4.4, 5.5, 6.6], 
                                 [7.7, 8.8, 9.9],
                                 [10.10, 11.11, 12.12]]) # 4 * 3

indice_updated_embeddings= tf.constant([[1.0, 2.0, 3.0], 
                                        [4.0, 5.0, 6.0]]) #  2 * 3

indices = tf.constant([1,3], dtype=tf.int32)

new_embeddings = tf.constant([[1.1, 2.2, 3.3],
                          [1.0, 2.0, 3.0], # rows=1, indice_updated_embeddings[0]
                          [7.7, 8.8, 9.9],
                          [4.0, 5.0, 6.0] # rows=3, indice_updated_embeddings[1]
                         ])

创建一个新的嵌入矩阵,由原始嵌入的某些行(而不是索引中的)和更新的嵌入的一些行(在索引中)组成。

在Numpy中,很容易实现以下功能:

代码语言:javascript
复制
origin_embeddings[indices] = indice_updated_embeddings # in-place
new_embeddings = np.put(origin_embeddings, indices, indice_updated_embeddings) # create a new matrix

我想要像“numpy”中的“put”函数。

我尝试过tf.where方法,但这并不容易。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-03-05 10:56:07

如果您没有使用旧版本的TensorFlow,则可以使用tf.tensor_scatter_nd_update进行此操作。

代码语言:javascript
复制
import tensorflow as tf

origin_embeddings = tf.constant([[ 1.1 ,  2.2 ,  3.3 ], 
                                 [ 4.4 ,  5.5 ,  6.6 ], 
                                 [ 7.7 ,  8.8 ,  9.9 ],
                                 [10.10, 11.11, 12.12]]) # 4 * 3
indice_updated_embeddings= tf.constant([[1.0, 2.0, 3.0], 
                                        [4.0, 5.0, 6.0]]) #  2 * 3
indices = tf.constant([1, 3], dtype=tf.int32)

new_embeddings = tf.tensor_scatter_nd_update(origin_embeddings,
                                             tf.expand_dims(indices, 1),
                                             indice_updated_embeddings)
tf.print(new_embeddings)
# [[1.1 2.2 3.3]
#  [1 2 3]
#  [7.7 8.8 9.9]
#  [4 5 6]]

tf.tensor_scatter_nd_update不可用的旧版本中,您可以这样做:

代码语言:javascript
复制
import tensorflow as tf

origin_embeddings = tf.constant([[1.1,2.2,3.3], 
                                 [4.4, 5.5, 6.6], 
                                 [7.7, 8.8, 9.9],
                                 [10.10, 11.11, 12.12]]) # 4 * 3
indice_updated_embeddings= tf.constant([[1.0, 2.0, 3.0], 
                                        [4.0, 5.0, 6.0]]) #  2 * 3
indices = tf.constant([1, 3], dtype=tf.int32)

indices2 = tf.expand_dims(indices, 1)
s = tf.shape(origin_embeddings)
mask = tf.scatter_nd(indices2, tf.ones_like(indices2, tf.bool), [s[0], 1])
updates = tf.scatter_nd(indices2, indice_updated_embeddings, s)
new_embeddings = tf.where(mask, updates, origin_embeddings)
tf.print(new_embeddings)
# [[1.1 2.2 3.3]
#  [1 2 3]
#  [7.7 8.8 9.9]
#  [4 5 6]]
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/60543598

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档