首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >trax tl.Relu和tl.ShiftRight层嵌套在串行组合器中

trax tl.Relu和tl.ShiftRight层嵌套在串行组合器中
EN

Stack Overflow用户
提问于 2021-06-29 18:46:46
回答 1查看 83关注 0票数 3

我正在尝试构建一个注意力模型,但默认情况下,Relu和ShiftRight层嵌套在串行组合器中。这进一步给我带来了训练中的错误。

代码语言:javascript
复制
layer_block = tl.Serial(
    tl.Relu(),
    tl.LayerNorm(), )

x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]]).astype(np.float32) 

layer_block.init(shapes.signature(x)) y = layer_block(x)

print(f'layer_block: {layer_block}')

输出

代码语言:javascript
复制
layer_block: Serial[
  Serial[
    Relu
  ]
  LayerNorm
]

预期输出

代码语言:javascript
复制
layer_block: Serial[
  Relu
  LayerNorm
]

tl.ShiftRight()也出现了同样的问题

上面的代码取自官方文档Example 5

提前感谢

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2021-07-05 04:16:55

我找不到上述问题的确切解决方案,但您可以使用tl.Fn()创建一个自定义函数,并在其中添加ReluShiftRight函数代码。

代码语言:javascript
复制
def _zero_pad(x, pad, axis):
    """Helper for jnp.pad with 0s for single-axis case."""
    pad_widths = [(0, 0)] * len(x.shape)
    pad_widths[axis] = pad  # Padding on axis.
    
    return jnp.pad(x, pad_widths, mode='constant')


def f(x):
    if mode == 'predict':
        return x
    padded = _zero_pad(x, (n_positions, 0), 1)
    return padded[:, :-n_positions]

# set ShiftRight parameters as global 
n_positions = 1
mode='train'

layer_block = tl.Serial(
    tl.Fn('Relu', lambda x: jnp.where(x <= 0, jnp.zeros_like(x), x)),
    tl.LayerNorm(),
    tl.Fn(f'ShiftRight({n_positions})', f)
)


x = np.array([[-2, -1, 0, 1, 2],
              [-20, -10, 0, 10, 20]]).astype(np.float32)
layer_block.init(shapes.signature(x))
y = layer_block(x)


print(f'layer_block: {layer_block}')

输出

代码语言:javascript
复制
layer_block: Serial[
  Relu
  LayerNorm
  ShiftRight(1)
]
票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/68177221

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档