首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >boolean_mask逆流?

boolean_mask逆流?
EN

Stack Overflow用户
提问于 2019-10-25 19:27:01
回答 2查看 1.2K关注 0票数 3

我有一个张量,根据boolean_mask,我分裂了

代码语言:javascript
复制
with tf.Session() as sess:
    boolean_mask = tf.constant([True, False, True, False])
    foo = tf.constant([[1,2],[3,4],[5,6],[7,8]])   
    true_foo = tf.boolean_mask(foo, boolean_mask, axis=0)
    false_foo = tf.boolean_mask(foo, tf.logical_not(boolean_mask), axis=0)
    print(sess.run((true_foo, false_foo)))

产出:

代码语言:javascript
复制
(array([[1, 2],
        [5, 6]], dtype=int32), 
 array([[3, 4],
        [7, 8]], dtype=int32))

我对true_foofalse_foo做了一些操作,然后按原来的顺序将它们重新组合在一起

代码语言:javascript
复制
    true_bar = 2*true_foo
    false_bar = 3*false_foo
    bar = tf.boolean_mask_inverse(boolean_mask, true_bar, false_bar)
    print(sess.run(bar))

应产出:

代码语言:javascript
复制
array([[ 2, 4],
       [ 9,12],
       [10,12],
       [21,24]], dtype=int32)
EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2019-10-25 20:27:36

类似于您自己的解决方案,但使用了tf.scatter_nd

代码语言:javascript
复制
true_mask = tf.cast(tf.where(boolean_mask), tf.int32)
false_mask = tf.cast(tf.where(~boolean_mask), tf.int32)
t_foo = tf.scatter_nd(true_mask, true_bar, shape=tf.shape(foo))
f_foo = tf.scatter_nd(false_mask, false_bar, shape=tf.shape(foo))
res = t_foo + f_foo
# array([[ 2,  4],
#        [ 9, 12],
#        [10, 12],
#        [21, 24]], dtype=int32)

基本上,您可以将true_barfalse_bar分散到两个不同的张量中,并将它们相加在一起。

票数 1
EN

Stack Overflow用户

发布于 2019-10-25 19:47:47

这就是我目前正在做的事情,但这似乎是不必要的复杂:

代码语言:javascript
复制
def boolean_mask_inverse(boolean_mask, true_bar, false_bar):
    stacked_bar = tf.concat((true_bar, false_bar), axis=0)
    index_mapping = tf.where(boolean_mask)
    true_index_mapping = tf.where_v2(boolean_mask)[:,0]
    false_index_mapping = tf.where_v2(tf.logical_not(boolean_mask))[:,0]
    stacked_index_mapping = tf.concat((true_index_mapping, false_index_mapping), axis=0)
    basic_indices = tf.range(tf.shape(stacked_index_mapping)[0])
    inverse_index_mapping = tf.gather(basic_indices, stacked_index_mapping)
    return tf.gather(stacked_bar, inverse_index_mapping)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/58564613

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档