首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >无法通过优化学习tf.contrib.distributions.MultivariateNormalDiag的参数

无法通过优化学习tf.contrib.distributions.MultivariateNormalDiag的参数
EN

Stack Overflow用户
提问于 2017-07-15 02:10:53
回答 1查看 928关注 0票数 1

工作示例:

代码语言:javascript
复制
import numpy as np
import tensorflow as tf

## construct data
np.random.seed(723888)
N,P = 50,3 # number and dimensionality of observations
Xbase = np.random.multivariate_normal(mean=np.zeros((P,)), cov=np.eye(P), size=N)

## construct model
X      = tf.placeholder(dtype=tf.float32, shape=(None, P), name='X')
mu     = tf.Variable(np.random.normal(loc=0.0, scale=0.1, size=(P,)), dtype=tf.float32, name='mu')
xDist  = tf.contrib.distributions.MultivariateNormalDiag(loc=mu, scale_diag=tf.ones(shape=(P,), dtype=tf.float32), name='xDist')
xProbs = xDist.prob(X, name='xProbs')

## prepare optimizer
eta       = 1e-3 # learning rate
loss      = -tf.reduce_mean(tf.log(xProbs), name='loss')
optimizer = tf.train.AdamOptimizer(learning_rate=eta).minimize(loss)

## launch session
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    sess.run(optimizer, feed_dict={X: Xbase})

我想在tensorflow中对多变量高斯分布的参数进行优化,如上面的示例所示。我可以成功地运行像sess.run(loss, feed_dict={X: Xbase})这样的命令,所以我已经正确地实现了这个发行版。当我尝试运行优化操作时,我得到一条奇怪的错误消息:

代码语言:javascript
复制
InvalidArgumentError: -1 is not between 0 and 3
     [[Node: gradients_1/xDist_7/xProbs/Prod_grad/InvertPermutation = InvertPermutation[T=DT_INT32, _device="/job:localhost/replica:0/task:0/cpu:0"](gradients_1/xDist_7/xProbs/Prod_grad/concat)]]

Caused by op 'gradients_1/xDist_7/xProbs/Prod_grad/InvertPermutation'

这是我不明白的。

如果我使用tf.contrib.distributions.MultivariateNormalFullCovariance而不是tf.contrib.distributions.MultivariateNormalDiag,我会得到同样的错误信息。如果优化的变量是scale_diag而不是loc,我就不会得到这个错误。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2017-07-15 05:11:15

我仍然在寻找失败的原因,但对于短期修复,进行以下更改是否有效?

代码语言:javascript
复制
xLogProbs = xDist.log_prob(X, name='xLogProbs')
loss      = -tf.reduce_mean(xLogProbs, name='loss')

注意:这实际上比tf.log(xProbs)更可取,因为它在数值上的精确度永远不会降低--有时精度要高得多。(所有tf.Distributions都是如此。)

票数 0
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/45109305

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档