首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >用flax.nn.Module实现神经网络

用flax.nn.Module实现神经网络
EN

Stack Overflow用户
提问于 2022-03-15 00:13:55
回答 1查看 377关注 0票数 0

我正在尝试用flax.nn.Module实现一个基本的RNN单元。实现RNN单元的公式非常简单:

a_t =W* h_{t-1} +U* x_t +b h_t = tanh(a_t) o_t =V* h_t +c

其中h_t是时间t处的更新状态,x_t是输入,o_t是输出,Tanh是我们的激活函数。

我的代码使用flax.nn.Module

代码语言:javascript
复制
class ElmanCell(nn.Module):
  @nn.compact
  def __call__(self, h, x):
    nextState = jnp.tanh(jnp.dot(W, h) * jnp.dot(U, x) + b)
    return nextState

我不知道hoe实现参数W、U和b,它们应该是nn.Module的属性吗?

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-03-19 22:15:11

试一试如下:

代码语言:javascript
复制
class RNNCell(nn.Module):
  @nn.compact
  def __call__(self, state, x):
    # Wh @ h + Wx @ x + b can be efficiently computed
    # by concatenating the vectors and then having a single dense layer
    x = np.concatenate([state, x])
    new_state = np.tanh(nn.Dense(state.shape[0])(x))
    return new_state

这样就可以了解参数。请参阅https://schmit.github.io/jax/2021/06/20/jax-language-model-rnn.html

票数 1
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71475589

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档