首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >tensorflow中的stop_gradient

tensorflow中的stop_gradient
EN

Stack Overflow用户
提问于 2018-05-08 04:15:37
回答 1查看 1.9K关注 0票数 2

我想知道tf.stop_gradient是只停止给定op的梯度计算,还是停止更新其输入tf.variable?我有以下问题-在MNIST的前向路径计算期间,我想对权重执行一组操作(比方说W到W*),然后使用输入进行matmul。但是,我想从反向路径中排除这些操作。我只需要在反向传播的训练过程中计算dE/dW。我写的代码阻止了W的更新。你能告诉我为什么吗?如果这些是变量,我理解我应该将它们的trainable属性设置为false,但这些是对权重的操作。如果stop_gradient不能用于此目的,那么我如何构建两个图,一个用于前向路径,另一个用于反向传播?

代码语言:javascript
复制
def build_layer(inputs, fmap, nscope,layer_size1,layer_size2, faulty_training):  
  with tf.name_scope(nscope): 
    if (faulty_training):
      ## trainable weight
      weights_i = tf.Variable(tf.truncated_normal([layer_size1, layer_size2],stddev=1.0 / math.sqrt(float(layer_size1))),name='weights_i')
      ## Operations on weight whose gradient should not be computed during backpropagation
      weights_fx_t = tf.multiply(268435456.0,weights_i)
      weight_fx_t = tf.stop_gradient(weights_fx_t)
      weights_fx = tf.cast(weights_fx_t,tf.int32)
      weight_fx = tf.stop_gradient(weights_fx)
      weights_fx_fault = tf.bitwise.bitwise_xor(weights_fx,fmap)
      weight_fx_fault = tf.stop_gradient(weights_fx_fault)
      weights_fl = tf.cast(weights_fx_fault, tf.float32)
      weight_fl = tf.stop_gradient(weights_fl)
      weights = tf.stop_gradient(tf.multiply((1.0/268435456.0),weights_fl))
      ##### end transformation
    else:
      weights = tf.Variable(tf.truncated_normal([layer_size1, layer_size2],stddev=1.0 / math.sqrt(float(layer_size1))),name='weights')


    biases = tf.Variable(tf.zeros([layer_size2]), name='biases')
    hidden = tf.nn.relu(tf.matmul(inputs, weights) + biases)
    return weights,hidden

我正在使用tensorflow梯度下降优化器进行训练。

代码语言:javascript
复制
optimizer = tf.train.GradientDescentOptimizer(learning_rate) 
global_step = tf.Variable(0, name='global_step', trainable=False) 
train_op = optimizer.minimize(loss, global_step=global_step)
EN

回答 1

Stack Overflow用户

发布于 2018-05-08 05:29:27

停止梯度将阻止反向传播继续通过图形中的该节点。你的代码没有任何从weights_i到损失的路径,除了通过weights_fx_t的路径,在那里梯度停止。这就是导致weights_i在训练期间不能更新的原因。你不需要在每一步之后都放上stop_gradient。只使用一次就会停止反向传播。

如果stop_gradient不能完成您想要的操作,那么您可以通过执行tf.gradients来获得渐变,并且可以使用tf.assign编写您自己的更新操作。这将允许您随心所欲地更改渐变。

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/50221783

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档