我正在学习使用JAX,我对jit和vmap的使用有一些疑问,因为我无法通过阅读文档来解决这些问题。
jit,然后对使用它们的函数进行jit,这有什么区别吗?例如,如果我有函数foo()和bar()以及一个函数@jax.jit def fooBar(x):返回foo(x) + bar(x)
如果foo()和bar()已经被抛弃了,有什么区别吗?
,
vmap之后,我应该设置一个函数吗?在上面的例子中,我应该做jax.vmap(fooBar)还是jax.jit(jax.vmal(fooBar))
发布于 2021-06-25 15:19:47
在代码执行的性能方面,单独跳转函数和在外部函数处跳转一次没有区别(功能上有一个微妙的区别:jit--编译内部函数将把内容封装在xla_call原语中,但对于最终的编译和执行几乎没有什么区别)。
另一方面,当使用vmap时,不存在隐式编译。vmap(f)将在急切的模式下执行,而jit(vmap(f))将被及时编译,通常会导致更快的执行。
https://stackoverflow.com/questions/68132513
复制相似问题