我有一个JAX布尔数组,并希望打印一个与_True_s之和相结合的语句:
import jax
import jax.numpy as jnp
from jax.experimental.host_callback import id_print
@jax.jit
def overlaps_jax():
mask_cp = jnp.array([True, False, False, True, False, True, False, True, True])
id_print(jnp.sum(mask_cp))
overlaps_jax()在_True_s中有5个mask_cp;我想将其打印为:
With jax accelerator
There are 5 true bools因为这个函数是弹跳,所以我试着用id_print打印它,但是我没有。id_print(jnp.sum(mask_cp))会打印5,但是不能在字符串中使用它。我试过以下几种方法:
id_print(jnp.sum(mask_cp))
# print:
# 5
id_print("\nWith jax accelerator\nThere are " + jnp.sum(mask_cp) + " true bools\n")
# error:
# TypeError: can only concatenate str (not "DynamicJaxprTracer") to str
print("\nWith jax accelerator\nThere are {} true bools\n".format(jnp.sum(mask_cp)))
# print:
# With jax accelerator
# There are Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)> true bools如何在这段代码中打印这样的语句?
发布于 2022-03-21 21:19:49
请注意,id_print是实验性的,其API和功能可能会发生变化。尽管如此,我不认为id_print有能力添加这样的文本,但是您可以通过一个更通用的host_callback.call来实现它
import jax
import jax.numpy as jnp
from jax.experimental.host_callback import call
@jax.jit
def overlaps_jax():
mask_cp = jnp.array([True, False, False, True, False, True, False, True, True])
call(lambda x: print(f"There are {x} true bools"), jnp.sum(mask_cp))
overlaps_jax()输出是
There are 5 true boolshttps://stackoverflow.com/questions/71548823
复制相似问题