首页
学习
活动
专区
圈层
工具
发布
社区首页 >专栏 >JAX核心设计解析:函数式编程让代码更可控

JAX核心设计解析:函数式编程让代码更可控

作者头像
deephub
发布2025-12-24 13:35:27
发布2025-12-24 13:35:27
1350
举报
文章被收录于专栏:DeepHub IMBADeepHub IMBA

很多人刚接触JAX都会有点懵——参数为啥要单独传?随机数还要自己管key?这跟PyTorch的画风完全不一样啊。

其实根本原因就一个:JAX是函数式编程而不是面向对象那套,想明白这点很多设计就都说得通了。

先说个核心区别

PyTorch里,模型是个对象,权重藏在里面,训练的时候自己更新自己。这是典型的面向对象思路,状态封装在对象内部。

JAX的思路完全反过来。模型定义是模型定义,参数是参数,两边分得清清楚楚。函数本身不持有任何状态,每次调用都把参数从外面传进去。

这么做的好处?JAX可以把你的函数当纯数学表达式来处理。求导、编译、并行,想怎么折腾都行,因为函数里没有藏着掖着的东西,行为完全可预测。

代码对比一下就明白了

PyTorch这么写:

代码语言:javascript
复制
 importtorch  
importtorch.nnasnn  

classModel(nn.Module):  
    def__init__(self):  
        super().__init__()  
        self.linear=nn.Linear(10, 1)  

    defforward(self, x):  
        returnself.linear(x)  

model=Model()  
x=torch.randn(5, 10)  
 output=model(x)

权重在self.linear里,模型自己管自己。

JAX配Flax是这样:

代码语言:javascript
复制
 importjax  
importjax.numpyasjnp  
fromflaximportlinenasnn  

classModel(nn.Module):  
    @nn.compact  
    def__call__(self, x):  
        returnnn.Dense(1)(x)  

model=Model()  

key=jax.random.PRNGKey(0)  
dummy=jnp.ones((1, 10))  
params=model.init(key, dummy)['params']  

x=jnp.ones((5, 10))  
 output=model.apply({'params': params}, x)

参数要先init出来,用的时候再apply进去。麻烦是麻烦了点,但参数流向一目了然,想做什么骚操作都很方便。

随机数那个key是怎么回事

这个确实是JAX最让新手头疼的地方。不能直接random.normal()完事,非得带个key:

代码语言:javascript
复制
 key=jax.random.PRNGKey(42)  
 x=jax.random.normal(key, (3,))

原因还是那个——函数式编程不允许隐藏状态。

普通框架的随机数生成器内部维护一个种子状态,每次调用偷偷改一下。JAX不干这事。你得显式给它一个key,它用完就扔,下次想生成随机数再给个新的。

好处是随机性完全可控可复现。jit编译、多卡训练、梯度计算,不管代码怎么变换,只要key一样结果就一样。调试的时候不会遇到那种"明明代码没改怎么结果不一样了"的玄学问题。

key不能复用,用之前要split

还有个规矩:同一个key只能用一次。要生成多个随机数,得先split:

代码语言:javascript
复制
 key=jax.random.PRNGKey(0)  
 
 key, subkey=jax.random.split(key)  
 a=jax.random.normal(subkey)  
 
 key, subkey=jax.random.split(key)  
 b=jax.random.uniform(subkey)

每次split出来的subkey都是独立的随机源。这套机制在分布式场景下特别香,不同机器拿不同的key,随机性既独立又可追溯。

合在一起看个完整例子

代码语言:javascript
复制
 defforward(params, x):  
    w, b=params  
    returnw*x+b  

definit_params(key):  
    key_w, key_b=jax.random.split(key)  
    w=jax.random.normal(key_w)  
    b=jax.random.normal(key_b)  
    returnw, b  

key=jax.random.PRNGKey(0)  
params=init_params(key)  

x=jnp.array(2.0)  
 output=forward(params, x)

forward是纯函数,输入决定输出,没有副作用。随机性在init_params里一次性处理完。参数独立存放,想存哪存哪。

这种代码JAX处理起来特别顺手——jit编译、自动微分、vmap批处理、多卡并行,都是开箱即用。

什么场景下JAX更合适

说实话JAX学习曲线是陡了点。但有些场景下它的优势很明显:做研究需要魔改模型结构的时候;物理仿真对数值精度和可复现性要求高的时候;大规模分布式训练不想被隐藏状态坑的时候;想自己撸optimizer或者自定义layer的时候。

适应了这套显式风格之后其实挺舒服的。参数在哪、随机数哪来的、函数干了啥,全都摆在明面上。没有黑魔法,debug的时候心里有底。

作者: Ali Nawaz

本文参与 腾讯云自媒体同步曝光计划,分享自微信公众号。
原始发表:2025-12-07,如有侵权请联系 cloudcommunity@tencent.com 删除

本文分享自 DeepHub IMBA 微信公众号,前往查看

如有侵权,请联系 cloudcommunity@tencent.com 删除。

本文参与 腾讯云自媒体同步曝光计划  ,欢迎热爱写作的你一起参与!

评论
登录后参与评论
0 条评论
热度
最新
推荐阅读
目录
  • 先说个核心区别
  • 代码对比一下就明白了
  • 随机数那个key是怎么回事
  • key不能复用,用之前要split
  • 合在一起看个完整例子
  • 什么场景下JAX更合适
领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档