首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何用JAX打印

如何用JAX打印
EN

Stack Overflow用户
提问于 2022-03-20 17:02:57
回答 1查看 6.1K关注 0票数 2

我有一个JAX布尔数组,并希望打印一个与_True_s之和相结合的语句:

代码语言:javascript
复制
import jax
import jax.numpy as jnp
from jax.experimental.host_callback import id_print

@jax.jit
def overlaps_jax():
    mask_cp = jnp.array([True, False, False, True, False, True, False, True, True])
    id_print(jnp.sum(mask_cp))

overlaps_jax()

在_True_s中有5个mask_cp;我想将其打印为:

代码语言:javascript
复制
With jax accelerator
There are 5 true bools

因为这个函数是弹跳,所以我试着用id_print打印它,但是我没有。id_print(jnp.sum(mask_cp))会打印5,但是不能在字符串中使用它。我试过以下几种方法:

代码语言:javascript
复制
id_print(jnp.sum(mask_cp))
# print:
# 5

id_print("\nWith jax accelerator\nThere are " + jnp.sum(mask_cp) + " true bools\n")
# error:
# TypeError: can only concatenate str (not "DynamicJaxprTracer") to str

print("\nWith jax accelerator\nThere are {} true bools\n".format(jnp.sum(mask_cp)))
# print:
# With jax accelerator
# There are Traced<ShapedArray(int64[])>with<DynamicJaxprTrace(level=0/1)> true bools

如何在这段代码中打印这样的语句?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-03-21 21:19:49

请注意,id_print是实验性的,其API和功能可能会发生变化。尽管如此,我不认为id_print有能力添加这样的文本,但是您可以通过一个更通用的host_callback.call来实现它

代码语言:javascript
复制
import jax
import jax.numpy as jnp
from jax.experimental.host_callback import call

@jax.jit
def overlaps_jax():
    mask_cp = jnp.array([True, False, False, True, False, True, False, True, True])
    call(lambda x: print(f"There are {x} true bools"), jnp.sum(mask_cp))

overlaps_jax()

输出是

代码语言:javascript
复制
There are 5 true bools
票数 4
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71548823

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档