首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何在tensorflow_probability层上计算渐变?

如何在tensorflow_probability层上计算渐变?
EN

Stack Overflow用户
提问于 2021-03-12 09:26:08
回答 1查看 69关注 0票数 1

我想用tf.GradientTape()来计算tensorflow_probability层的渐变。这是相当简单的使用正常的,例如,密集层

代码语言:javascript
复制
inp = tf.random.normal((2,5))
layer = tf.keras.layers.Dense(10)

with tf.GradientTape() as tape:
    out = layer(inp)
    loss = tf.reduce_mean(1-out)
grads = tape.gradient(loss, layer.trainable_variables)
print(grads)
[<tf.Tensor: shape=(5, 10), dtype=float32, numpy=
 array([[ 0.04086879,  0.04086879, -0.02974391,  0.04086879,  0.04086879,
          0.04086879, -0.02974391,  0.04086879, -0.02974391, -0.07061271],
        [ 0.01167339,  0.01167339, -0.02681615,  0.01167339,  0.01167339,
          0.01167339, -0.02681615,  0.01167339, -0.02681615, -0.03848954],
        [ 0.00476769,  0.00476769, -0.00492069,  0.00476769,  0.00476769,
          0.00476769, -0.00492069,  0.00476769, -0.00492069, -0.00968838],
        [-0.00462376, -0.00462376,  0.05914849, -0.00462376, -0.00462376,
         -0.00462376,  0.05914849, -0.00462376,  0.05914849,  0.06377225],
        [-0.11682947, -0.11682947, -0.06357963, -0.11682947, -0.11682947,
         -0.11682947, -0.06357963, -0.11682947, -0.06357963,  0.05324984]],
       dtype=float32)>,
 <tf.Tensor: shape=(10,), dtype=float32, numpy=
 array([-0.05, -0.05, -0.1 , -0.05, -0.05, -0.05, -0.1 , -0.05, -0.1 ,
        -0.05], dtype=float32)>]

但是,如果我使用DenseReparameterization完成此操作,则grads不会注册任何内容。

代码语言:javascript
复制
inp = tf.random.normal((2,5))
layer = tfp.layers.DenseReparameterization(10)

with tf.GradientTape() as tape:
    out = layer(inp)
    loss = tf.reduce_mean(1-out)
grads = tape.gradient(loss, layer.trainable_variables)
print(grads)
[None, None, None]

谁能告诉我如何解决这个问题,以便梯度是胶带和注册?

EN

回答 1

Stack Overflow用户

发布于 2021-03-12 11:41:28

啊哈,就是这样!我使用的是tf v2.1.0。显然,这在tensorflow_probability上不能很好地工作。我会尽快升级的。谢谢你gobrewers14。

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/66592908

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档