首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Haiku & Jax权值初始化

Haiku & Jax权值初始化
EN

Stack Overflow用户
提问于 2022-02-21 08:53:24
回答 1查看 243关注 0票数 0

在Pytorch中,可以使用以下代码初始化一个层:

代码语言:javascript
复制
def init_layer(in_features, out_features):
 x = nn.Linear(in_features, out_features)
 limit = 1.0 / math.sqrt(in_features)
 x.weight = nn.Parameter(
    data=torch.distributions.uniform.Uniform(-limit, limit).sample(x.weight.shape), requires_grad=True
)
 return x

如何使用Jax & Haiku做同样的事情?

谢谢!

EN

回答 1

Stack Overflow用户

发布于 2022-11-11 21:02:16

您可以在每个层执行此操作,例如在残馀网块上:

代码语言:javascript
复制
class Residual(hk.Module):
    """The Residual block of ResNet."""
    def __init__(self, hidden_dim, use_1x1conv=False, strides=1,
                     init = hk.initializers.RandomNormal()):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.stride = strides
        self.init = init
        
        if use_1x1conv:
            self.proj = hk.Conv2D(hidden_dim, 1,
                        stride=strides, w_init=self.init, b_init=self.init)
        else:
            self.proj = None
            
    def __call__(self, x, is_training=True):
        
        y = hk.Conv2D(self.hidden_dim, 3, padding=(1,1), stride=self.stride,
                         with_bias=False, w_init=self.init, b_init=self.init)(x)
        y = hk.BatchNorm(True, True, 0.9)(y, is_training)
        y = jax.nn.gelu(y)
        
        y = hk.Conv2D(self.hidden_dim, 3, padding=(1,1), with_bias=False,
                         w_init=self.init, b_init=self.init)(y)
        y = hk.BatchNorm(True, True, 0.9)(y, is_training)
        
        
        if self.proj:
            x = self.proj(x)
        
        return jax.nn.gelu(x + y)
票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71203500

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档