在 PyTorch 中,可以使用以下代码初始化一个层:
def init_layer(in_features, out_features):
    """Build an ``nn.Linear`` layer whose weight matrix is re-sampled from
    U(-limit, limit), where limit = 1 / sqrt(in_features).

    Args:
        in_features: Size of each input sample (fan-in).
        out_features: Size of each output sample.

    Returns:
        nn.Linear: The layer with uniformly re-initialized weights.
        The bias keeps PyTorch's default initialization.
    """
    layer = nn.Linear(in_features, out_features)
    # Classic fan-in bound: 1 / sqrt(in_features).
    limit = 1.0 / math.sqrt(in_features)
    layer.weight = nn.Parameter(
        data=torch.distributions.uniform.Uniform(-limit, limit).sample(layer.weight.shape),
        requires_grad=True,
    )
    return layer

# How can I do the same thing with JAX & Haiku?
谢谢!
发布于 2022-11-11 21:02:16
您可以在每个层上执行此操作,例如在残差网络块(Residual block)上:
class Residual(hk.Module):
    """The Residual block of ResNet: two 3x3 convs with BatchNorm and GELU,
    plus an optional 1x1 projection on the shortcut path."""

    def __init__(self, hidden_dim, use_1x1conv=False, strides=1,
                 init=hk.initializers.RandomNormal()):
        """Create the block.

        Args:
            hidden_dim: Number of output channels of both 3x3 convolutions.
            use_1x1conv: If True, project the shortcut with a 1x1 conv so it
                matches the main path's channel count / spatial resolution.
            strides: Stride of the first 3x3 conv (and of the projection).
            init: Weight initializer. NOTE: the default is evaluated once at
                definition time and shared across instances; harmless here
                because Haiku initializers are stateless.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.stride = strides
        self.init = init
        if use_1x1conv:
            self.proj = hk.Conv2D(hidden_dim, 1, stride=strides,
                                  w_init=self.init, b_init=self.init)
        else:
            self.proj = None

    def __call__(self, x, is_training=True):
        # Main path: conv -> BN -> GELU -> conv -> BN.
        # b_init must NOT be passed when with_bias=False: Haiku's ConvND
        # raises ValueError for that combination, so it is omitted here.
        y = hk.Conv2D(self.hidden_dim, 3, padding=(1, 1), stride=self.stride,
                      with_bias=False, w_init=self.init)(x)
        y = hk.BatchNorm(True, True, 0.9)(y, is_training)
        y = jax.nn.gelu(y)
        y = hk.Conv2D(self.hidden_dim, 3, padding=(1, 1),
                      with_bias=False, w_init=self.init)(y)
        y = hk.BatchNorm(True, True, 0.9)(y, is_training)
        if self.proj:
            x = self.proj(x)
        return jax.nn.gelu(x + y)

# https://stackoverflow.com/questions/71203500
复制相似问题