I have some code that defines a model architecture:
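```python
from jax import jit
from neural_tangents import stax
from neural_tangents.stax import Dense


def activation_fn(activation):
    """Assumed helper (not shown in the original code): map an activation
    string to a stax layer constructor."""
    if activation == 'erf':
        return stax.Erf
    if activation == 'relu':
        return stax.Relu
    raise ValueError(f"Unsupported activation: {activation!r}")


def model(
    W_std,
    b_std,
    width,
    depth,
    activation,
    parameterization
):
    """Construct fully connected NN model and infinite width NTK & NNGP kernel
    function.

    Args:
        W_std (float): Weight standard deviation.
        b_std (float): Bias standard deviation.
        width (int): Hidden layer width.
        depth (int): Number of hidden layers.
        activation (string): Activation function string, 'erf' or 'relu'.
        parameterization (string): Parameterization string, 'ntk' or 'standard'.

    Returns:
        `(init_fn, apply_fn, kernel_fn)`
    """
    act = activation_fn(activation)
    # First hidden layer: only the output width is passed to Dense.
    layers_list = [Dense(width, W_std, b_std, parameterization=parameterization)]

    def layer_block():
        # One (activation -> Dense) block per additional hidden layer.
        return stax.serial(act(), Dense(width, W_std, b_std, parameterization=parameterization))

    for _ in range(depth - 1):
        layers_list += [layer_block()]
    # Readout layer with a single output unit.
    layers_list += [act(), Dense(1, W_std, b_std, parameterization=parameterization)]
    # print (f"---- layer list is {layers_list} ------")
    init_fn, apply_fn, kernel_fn = stax.serial(*layers_list)
    apply_fn = jit(apply_fn)
    return init_fn, apply_fn, kernel_fn
```

I can't see where the input dimension is set. By default it is 1, but I need to adapt this architecture to higher-dimensional inputs. The `width` parameter of `Dense` only specifies the output size. How do I change the input size? The code is from here.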
Posted on 2020-08-18 01:57:21
The key point is that `Dense` does not need the input dimension. You specify it when you call `init_fn`:
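```python
from jax import random

# W_std, b_std, width, depth, activation and parameterization as above.
init_fn, apply_fn, kernel_fn = model(
    W_std,
    b_std,
    width,
    depth,
    activation,
    parameterization
)
key = random.PRNGKey(0)
# x_train: your data, shape (n_samples, input_dim). The input dimension
# is read from the last axis of this shape.
_, init_params = init_fn(key, x_train.shape)
```

https://stackoverflow.com/questions/63434805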