首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >关于forget_bias的问题

关于forget_bias的问题
EN

Stack Overflow用户
提问于 2017-11-04 02:55:51
回答 1查看 354关注 0票数 0

我将用BasicLSTMCell编写一个C++,我需要检查它是否正常工作。我使用tf.nn.rnn_cell.BasicLSTMCell来实现具有4个单元的forget_bias,并将forget_bias设置为1,然后使用以下代码检查LSTM的偏差:

代码语言:javascript
复制
//////////////////////////////////////////////////////////////

    with tf.variable_scope("LSTM"):
    Cell=tf.nn.rnn_cell.BasicLSTMCell(4,forget_bias=1,state_is_tuple=True)
Sessin=tf.Session()
state=Cell.zero_state(1,dtype=tf.float32)
with  tf.variable_scope("Ut_def"):
    out,D=tf.nn.dynamic_rnn(
            cell=Cell,inputs=Feed,
            initial_state=state,
            time_major=False)
Sessin.run(tf.global_variables_initializer())
#Saver.save(Sessin,"./123/Var",global_step=1)
out,D=Sessin.run([out,D],feed_dict={Feed:np.arange(8).reshape(1,2,4)})
tf.train.Saver().save(Sessin,"./123/Var",global_step=1)
trainable_vars_dict = {}
for key in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
    trainable_vars_dict[key.name] = Sessin.run(key)
    # Checking the names of the keys
    print(key.name)
lstm_weight_vals = trainable_vars_dict["Ut_def/RNN/BasicLSTMCell/Linear/Matrix:0"]
B=trainable_vars_dict["Ut_def/RNN/BasicLSTMCell/Linear/Bias:0"]
print(B)
/////////////////////////////////////////////////////////////

但我发现,无论我改变了forget_bias,这些偏见都是零。

有人知道这是怎么回事吗?

为了弄清楚lstm是如何工作的,我只是使用从tensorflow中提取的擦除和偏倚来获得相同的结果。当然,它们是不等同的。

代码语言:javascript
复制
w_i, w_C, w_f, w_o = np.split(lstm_weight_vals, 4, axis=1)
w_xi = w_i[:4, :]
w_hi = w_i[4:, :]
w_xC = w_C[:4, :]
w_hC = w_C[4:, :]
w_xf = w_f[:4, :]
w_hf = w_f[4:, :]
w_xo = w_o[4:, :]
w_ho = w_o[4:, :]
Input=tf.range(4,dtype=tf.float32)
Input=tf.reshape(Input,shape=[1,4])
i=tf.sigmoid(tf.matmul(tf.zeros(shape=[1,4]),w_xi)+tf.matmul(Input,w_hi))
o=tf.sigmoid(tf.matmul(tf.zeros(shape=[1,4]),w_xo)+tf.matmul(Input,w_ho))
g=tf.tanh(tf.matmul(tf.zeros(shape=[1,4]),w_xC)+tf.matmul(Input,w_hC))
f=tf.sigmoid(tf.matmul(tf.zeros(shape=[1,4]),w_xf)+tf.matmul(Input,w_hf))
Cstate=tf.zeros(shape=[1,4])*f+i*g
Hstate=tf.tanh(Cstate)*o
Input=Input+4
i=tf.sigmoid(tf.matmul(Cstate,w_xi)+tf.matmul(Input,w_hi))
o=tf.sigmoid(tf.matmul(Cstate,w_xo)+tf.matmul(Input,w_ho))
g=tf.tanh(tf.matmul(Cstate,w_xC)+tf.matmul(Input,w_hC))
f=tf.sigmoid(tf.matmul(Cstate,w_xf)+tf.matmul(Input,w_hf))
Cstate=Cstate*f+i*g

Hstate=tf.tanh(Cstate)*o
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-11-04 10:16:21

我发现了错误的密码。代码应该如下:

代码语言:javascript
复制
i=tf.sigmoid(tf.matmul(Hstate,w_xi)+tf.matmul(Input,w_hi))
o=tf.sigmoid(tf.matmul(Hstate,w_xo)+tf.matmul(Input,w_ho))
g=tf.tanh(tf.matmul(Hstate,w_xC)+tf.matmul(Input,w_hC))
f=tf.sigmoid(tf.matmul(Hstate,w_xf)+tf.matmul(Input,w_hf)+1)

它是Hstate而不是Csatae

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/47106918

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档