首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >JAX是否单独改变了JAX函数的性能?

JAX是否单独改变了JAX函数的性能?
EN

Stack Overflow用户
提问于 2021-06-25 14:16:59
回答 1查看 176关注 0票数 2

我正在学习使用JAX,我对jitvmap的使用有一些疑问,因为我无法通过阅读文档来解决这些问题。

  1. 分别对几个函数进行jit,然后对使用它们的函数进行jit,这有什么区别吗?例如,如果我有函数foo()bar()以及一个函数

@jax.jit def fooBar(x):返回foo(x) + bar(x)

如果foo()bar()已经被抛弃了,有什么区别吗?

  1. ,在我把它放进vmap之后,我应该设置一个函数吗?在上面的例子中,我应该做jax.vmap(fooBar)

还是jax.jit(jax.vmal(fooBar))

EN

回答 1

Stack Overflow用户

发布于 2021-06-25 15:19:47

在代码执行的性能方面,单独跳转函数和在外部函数处跳转一次没有区别(功能上有一个微妙的区别:jit--编译内部函数将把内容封装在xla_call原语中,但对于最终的编译和执行几乎没有什么区别)。

另一方面,当使用vmap时,不存在隐式编译。vmap(f)将在急切的模式下执行,而jit(vmap(f))将被及时编译,通常会导致更快的执行。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68132513

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档