首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >支持XLA的动态切片

支持XLA的动态切片
EN

Stack Overflow用户
提问于 2020-05-29 19:04:47
回答 1查看 195关注 0票数 0

在XLA编译的函数中,有没有办法根据随机数生成器对张量进行动态切片?例如:

代码语言:javascript
复制
@tf.function(experimental_compile=True)
def random_slice(input, max_slice_size):
    offset = tf.squeeze(tf.random.uniform([1], minval=0, maxval=input.shape[0]-max_slice_size, dtype=tf.int32))
    sz = tf.squeeze(tf.random.uniform([1], minval=1, maxval=max_slice_size, dtype=tf.int32))

    indices = tf.range(offset, offset+sz)  # Non-XLA-able due to non-static bounds

    return tf.gather(input, indices)

x = tf.ones([50, 50])
y = random_slice(x, 4)

这段代码无法编译,因为XLA要求tf.range的参数在编译时是已知的。是否有推荐的解决方法?

EN

回答 1

Stack Overflow用户

发布于 2020-05-30 05:31:16

这里的基本问题是,XLA需要静态地知道程序中所有Tensor的形状。在这种情况下,它抱怨tf.range,因为在给定随机输入的情况下,它的输出是未知的。相反,您可能能够生成一个屏蔽版本(将不需要的元素清零,使用类似tensor_scatter_nd_update的东西),并在下游使用该屏蔽版本(很难确切地说出如何使用,因为没有看到关于如何使用y的更多上下文)。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/62084464

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档