首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何从haiku中的params (pytree)中获取参数?(jax框架)

如何从haiku中的params (pytree)中获取参数?(jax框架)
EN

Stack Overflow用户
提问于 2021-09-02 14:10:28
回答 1查看 58关注 0票数 1

例如,您设置了一个具有params的模块。但是,如果你想在亏损的情况下正规化一些东西,模式是什么?

代码语言:javascript
复制
import jax.numpy as jnp
import jax
def loss(params, x, y):
   l = jnp.sum((y - mlp.apply(params, x)) ** 2)
   w = hk.get_params(params, 'w') # does not work like this
   l += jnp.sum(w ** w)
   return l

示例中缺少一些模式。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-09-03 17:36:29

params本质上是一个只读字典,因此您可以通过将其视为字典来获取参数的值:

代码语言:javascript
复制
print(params['w'])

如果您想要更新参数,您不能就地更新,但必须首先将其转换为可变字典:

代码语言:javascript
复制
params_mutable = hk.data_structures.to_mutable_dict(params)
params_mutable['w'] = 3.14
params_new = hk.data_structures.to_immutable_dict(params_mutable)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/69031947

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档