在JAX (如jax.lax.scan )中编译累积内存时,处理内存的最佳方法是什么?
下面是一个几何级数示例。诱惑是只根据输入的大小来识别积累,并相应地实现。
import jax.numpy as jnp
import jax.lax as lax
def calc_gp_size(size,x0,a):
scan_fun = lambda carry, i : (a*carry,)*2
xn, x = lax.scan(scan_fun,x0,None,length=size-1)
return jnp.concatenate((x0[None],x))
jax.config.update("jax_enable_x64", True)
size = jnp.array(2**26,dtype='u8')
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')
jax.jit(calc_gp_size)(size,x0,a)然而,尝试使用jax.jit将不出所料地导致ConcretizationTypeError。
正确的方法是在缓冲区已经存在的地方传递一个参数。
def calc_gp_array(array,x0,a):
scan_fun = lambda carry, i : (a*carry,)*2
xn, x = lax.scan(scan_fun,x0,array)
return jnp.concatenate((x0[None],x))
array = jnp.arange(1,2**26,dtype='u8')
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')
jax.jit(calc_gp_array)(array,x0,a)我担心的是,有很多分配的内存没有被使用(或者是吗?)对于此示例,是否有更有效的内存方法,或者是否以某种方式使用了分配的内存?
编辑:合并@jakevdp的注释,将函数视为main (单次调用-包括编译和排除缓存),并对其进行分析
%memit jx.jit(calc_gp_size, static_argnums=0)(size,x0,a).block_until_ready()
# peak memory: 7058.32 MiB, increment: 959.94 MiB
%memit jx.jit(calc_gp_array)(jnp.arange(1,size,dtype='u8'),x0,a).block_until_ready()
peak memory: 7850.83 MiB, increment: 1240.22 MiB
%memit jnp.cumprod(jnp.full(size, a, dtype='f8').at[0].set(x0))
peak memory: 8150.05 MiB, increment: 1539.70 MiB粒度较小的结果将需要对jit代码进行行分析(不确定如何做到这一点)。
依次初始化数组并调用jax.jit似乎节省了内存
%memit array = jnp.arange(1,size,dtype='u8'); jx.jit(calc_gp_array)(array,x0,a).block_until_ready()
# peak memory: 6711.81 MiB, increment: 613.44 MiB
%memit array = jnp.full(size, a, dtype='f8').at[0].set(x0); jnp.cumprod(array)
# peak memory: 7675.15 MiB, increment: 1064.08 MiB发布于 2022-04-28 17:01:02
如果将size参数标记为静态并传递一个可接受的值,则第一个版本将有效:
import jax
import jax.numpy as jnp
import jax.lax as lax
def calc_gp_size(size,x0,a):
scan_fun = lambda carry, i : (a*carry,)*2
xn, x = lax.scan(scan_fun,x0,None,length=size-1)
return jnp.concatenate((x0[None],x))
jax.config.update("jax_enable_x64", True)
size = 2 ** 26
x0, a = jnp.array([1.0,1.0+1.0e-08],dtype='f8')
jax.jit(calc_gp_size, static_argnums=0)(size,x0,a)
# DeviceArray([1. , 1.00000001, 1.00000002, ..., 1.95636587,
# 1.95636589, 1.95636591], dtype=float64)我认为这可能比在第二个例子中预先分配数组的内存效率略高一些,但如果这一点很重要的话,它将是值得进行基准测试的。
另外,如果您在GPU上执行这种操作,您可能会发现内置累加(如jnp.cumprod )更具有性能。我相信这或多或少相当于您基于扫描的功能:
result = jnp.cumprod(jnp.full(size, 1 + 1E-8, dtype='f8').at[0].set(1))https://stackoverflow.com/questions/72043419
复制相似问题