我正在阅读设备上的JAX文档,其中写的是:
与jax.devices()类似,但只返回给定进程的本地设备。
在jax.devices()中,它被写成:
返回给定后端的所有设备的列表。
我不知道这些本地和非本地设备到底是什么。你能详细说明一下两者之间的区别吗?
发布于 2022-08-23 16:03:56
这在JAX的在多主机和多进程环境中使用JAX文档中进行了讨论。
进程的本地设备是进程可以直接寻址和启动计算的设备。例如,在集群上,每个主机只能在直接连接的GPU上启动计算。在Cloud上,每个主机只能在直接连接到该主机的8个TPU核心上启动计算(有关更多细节,请参阅云TPU体系结构文档)。您可以通过
jax.local_devices()看到进程的本地设备。 全局设备是跨所有进程的设备,只要每个进程在其本地设备上启动计算,计算就可以跨越进程并通过设备之间的直接通信链路执行集体操作。您可以通过jax.devices()看到所有可用的全局设备。进程的本地设备总是全局设备的子集。
https://stackoverflow.com/questions/73458553
复制相似问题