首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >TopK算子在vectorized_map中的应用

TopK算子在vectorized_map中的应用
EN

Stack Overflow用户
提问于 2020-06-19 19:55:01
回答 1查看 497关注 0票数 0

我试图对形状为tf.vectorized_map批号、listsize和对批处理中的每一行应用tf.math.top_k操作符进行操作,但没有成功。

例如,数据可以是:

代码语言:javascript
复制
[ [1,2,4,5,6], [9,5,4,2,1] ]

我想在[1,2,4,5,6][9,5,4,2,1]上应用topk。

然而,我成功地使用tf.map_fn做了同样的事情,但是vectorized_map应该运行得更快。我使用tensorflow 1.15

代码语言:javascript
复制
import tensorflow as tf
import numpy as np

# create fake data
x = tf.convert_to_tensor([
    [1,2,4,5,6],
    [9,5,4,2,1],
], dtype=tf.float32)
x = tf.reshape(x, (2, -1))

B = x.shape[0] # batchsize
L = x.shape[1] # list size

print(f"B {B}, L {L}")

sess = tf.Session()

print(f"x tensor: {sess.run(x)}\n")

def fv(_x):
    #_tensor = tf.reshape(_x, (L,))  # doesnt work (1)
    _tensor = tf.reshape(tf.convert_to_tensor([9,5,4,2,1], dtype=tf.float32), (L,)) # work (2)
    #_tensor = tf.convert_to_tensor([9,5,4,2,1], dtype=tf.float32) # work (3)

    print(f"_tensor: {_tensor}")
    values, indices = tf.math.top_k(_tensor, k=3)
    # i just need the indices
    return indices

indices = tf.vectorized_map(
        fv,
        x,
)
print("\nindices ")
print(sess.run(indices))

正如我们所看到的,(2)和(3)运行,所以topk操作符应该是可用的。此外,即使(1)不能工作,我也可以使用_x,例如,返回如下:

代码语言:javascript
复制
def fv(_x):
    return _x * 10

因此,_x是可用的。

因此,当我使用(1)运行代码时,会出现以下错误:

代码语言:javascript
复制
ValueError: No converter defined for TopKV2
name: "loop_body/TopKV2"
op: "TopKV2"
input: "loop_body/Reshape"
input: "loop_body/TopKV2/k"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "sorted"
  value {
    b: true
  }
}

inputs: [WrappedTensor(t=<tf.Tensor 'loop_body/Reshape/pfor/Reshape:0' shape=(2, 5) dtype=float32>, is_stacked=True, is_sparse_stacked=False), WrappedTensor(t=<tf.Tensor 'loop_body/TopKV2/k:0' shape=() dtype=int32>, is_stacked=False, is_sparse_stacked=False)]. 
Either add a converter or set --op_conversion_fallback_to_while_loop=True, which may run slower

Process finished with exit code 1

在这里,我只想得到索引,因为我需要处理向量,以便在K=3的输出中具有类似于K=3的值(如果值在topk 0中)。并给出另一个形状批次的张量,1包含每一行的K值。(我已经成功地用map_fn做了这件事,所以我认为以后不会有问题)。

也许可以在矢量化地图中实现我自己的topk操作符,但我宁愿不这样做。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2020-06-20 21:38:56

最后,我做了这样的事情:这不使用vectorized_map,但这是我想要做的。但是如果有人能让它与vectorized_map一起工作,我会考虑解决方案。:)

代码语言:javascript
复制
def topk(x, k):
    """
    x : shape [B, L]
    k : shape [B, 1]
    return : final_mask of shape [B,L] with final_mask[b,i] = 0 if x[b,i] is in  
    the k[b] biggest values of x[b,:], else final_mask[b,i] = 1
    """
    B = x.shape[0]  # batchsize
    L = x.shape[1]  # list size

    # the indices sorted in descending order
    indices_des = tf.argsort(x, axis=-1, direction='DESCENDING', stable=False, name='sorting_for_topk')

    
    mask = tf.reshape(tf.range(start=0, limit=L, dtype=tf.int32), [1, L])
    mask = tf.repeat(mask, [B], axis=0)
    mask = mask<k

    one_hot = tf.one_hot(indices_des, depth=L) * tf.cast(tf.reshape(mask, [B, L, 1]), tf.float32)
    final_mask = tf.reduce_sum(one_hot, axis=1)
    
    return final_mask
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62477677

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档