例如,您设置了一个具有params的模块。但是,如果你想在亏损的情况下正规化一些东西,模式是什么?
import jax.numpy as jnp
import jax
def loss(params, x, y):
l = jnp.sum((y - mlp.apply(params, x)) ** 2)
w = hk.get_params(params, 'w') # does not work like this
l += jnp.sum(w ** w)
return l示例中缺少一些模式。
发布于 2021-09-03 17:36:29
params本质上是一个只读字典,因此您可以通过将其视为字典来获取参数的值:
print(params['w'])如果您想要更新参数,您不能就地更新,但必须首先将其转换为可变字典:
params_mutable = hk.data_structures.to_mutable_dict(params)
params_mutable['w'] = 3.14
params_new = hk.data_structures.to_immutable_dict(params_mutable)https://stackoverflow.com/questions/69031947
复制相似问题