首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在neural_tangent.stax框架中指定输入维度?

如何在neural_tangent.stax框架中指定输入维度?
EN

Stack Overflow用户
提问于 2020-08-16 16:35:18
回答 1查看 100关注 0票数 0

我有一个定义模型结构的代码

代码语言:javascript
复制
from neural_tangents import stax
from neural_tangents.stax import Dense
from jax import jit

def model(
    W_std,
    b_std,
    width,
    depth,
    activation,
    parameterization
):
    """Construct fully connected NN model and infinite width NTK & NNGP kernel
       function.

    Args:
        W_std (float): Weight standard deviation.
        b_std (float): Bias standard deviation.
        width (int): Hidden layer width.
        depth (int): Number of hidden layers.
        activation (string): Activation function string, 'erf' or 'relu'.
        parameterization (string): Parameterization string, 'ntk' or 'standard'.

    Returns:
        `(init_fn, apply_fn, kernel_fn)`
    """
    act = activation_fn(activation)

    layers_list = [Dense(width, W_std, b_std, parameterization=parameterization)]

    def layer_block():
        return stax.serial(act(), Dense(width, W_std, b_std, parameterization=parameterization))

    for _ in range(depth-1):
        layers_list += [layer_block()]

    layers_list += [act(), Dense(1, W_std, b_std, parameterization=parameterization)]

    # print (f"---- layer list is {layers_list} ------")

    init_fn, apply_fn, kernel_fn = stax.serial(*layers_list)

    apply_fn = jit(apply_fn)

    return init_fn, apply_fn, kernel_fn

我看不到在哪里可以建立输入的维度。默认情况下它是1,但我需要调整此结构以适应更高维度的输入。Dense中的width参数仅指定输出尺寸。如何更改输入尺寸?代码来自here

EN

回答 1

Stack Overflow用户

发布于 2020-08-18 01:57:21

关键是Dense不需要输入维度。在init_fn函数中指定:

代码语言:javascript
复制
init_fn, apply_fn, kernel_fn = model(
         W_std,
         b_std,
         width,
         depth,
         activation,
         parameterization
     )
_, init_params = init_fn(key, input.shape)
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/63434805

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档