首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何抑制LSTMStateTuple在流场中的梯度

如何抑制LSTMStateTuple在流场中的梯度
EN

Stack Overflow用户
提问于 2017-04-19 03:24:22
回答 1查看 322关注 0票数 1

我正在运行一个用于语言建模的基本lstm代码。但我不想做BPTT。我想做一些像tf.stop_gradient(state)这样的事情

代码语言:javascript
复制
with tf.variable_scope("RNN"):
  for time_step in range(N):
    if time_step > 0: tf.get_variable_scope().reuse_variables()
    (cell_output, state) = cell(inputs[:, time_step, :], state)

然而,stateLSTMStateTuple,所以我尝试:

代码语言:javascript
复制
for lli in range(len(state)):
    print(state[lli].c, state[lli].h)
    state[lli].c = tf.stop_gradient(state[lli].c)
    state[lli].h = tf.stop_gradient(state[lli].h)

但是我得到了一个AttributeError: can't set attribute错误:

代码语言:javascript
复制
File "/home/liyu-iri/IRRNNL/word-rnn/ptb/models/decoupling.py", line 182, in __init__
state[lli].c = tf.stop_gradient(state[lli].c)
AttributeError: can't set attribute

我也尝试使用tf.assign,但state[lli].c不是变量。

所以,我想知道怎样才能阻止LSTMStateTuple的梯度?或者,我怎么能阻止BPTT?我只想做单帧的BP。

非常感谢!

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-04-19 18:49:24

我认为这是一个纯粹的python问题: LSTMStateTuple只是一个collections.namedtuple,python不允许您在那里分配元素(就像在其他元组中一样)。解决方案是创建一个全新的,例如,在stopped_state = LSTMStateTuple(tf.stop_gradient(old_tuple.c), tf.stop_gradient(old_tuple.h))中,然后使用这个(或其中的一个列表)作为状态。如果您坚持替换现有的元组,我认为namedtuple有一个_replace方法,参见这里,就像在old_tuple._replace(c=tf.stop_gradient(...))中一样。希望这能帮上忙!

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/43485775

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档