在numpy page上,他们给出了以下示例
s = np.random.dirichlet((10, 5, 3), 20)这一切都很好;但是,如果您想从alphas的2D数组中生成随机样本,该怎么办?
alphas = np.random.randint(10, size=(20, 3))如果你尝试
np.random.dirichlet(alphas),
np.random.dirichlet([x for x in alphas]),或者
np.random.dirichlet((x for x in alphas)),
它会产生一个ValueError: object too deep for desired array。唯一有效的方法是:
y = np.empty(alphas.shape)
for i in xrange(np.alen(alphas)):
y[i] = np.random.dirichlet(alphas[i])
print y对于我的代码结构,...which远非理想。为什么会这样,还有人能想出一种更“麻木”的方式来做这件事吗?
提前谢谢。
发布于 2013-04-10 13:01:56
编写np.random.dirichlet是为了为单个Dirichlet分布生成样本。该代码是根据Gamma分布实现的,该实现可以用作向量化代码的基础,以从不同的分布中生成样本。在下面的代码中,dirichlet_sample采用一个形状为(n,k)的数组alphas,其中每一行都是狄利克雷分布的alpha向量。它返回一个同样具有形状(n,k)的数组,每一行都是来自alphas的相应分布的样本。作为脚本运行时,它使用dirichlet_sample和np.random.dirichlet生成样本,以验证它们是否生成相同的样本(最大为正常的浮点差异)。
import numpy as np
def dirichlet_sample(alphas):
"""
Generate samples from an array of alpha distributions.
"""
r = np.random.standard_gamma(alphas)
return r / r.sum(-1, keepdims=True)
if __name__ == "__main__":
alphas = 2 ** np.random.randint(0, 4, size=(6, 3))
np.random.seed(1234)
d1 = dirichlet_sample(alphas)
print "dirichlet_sample:"
print d1
np.random.seed(1234)
d2 = np.empty(alphas.shape)
for k in range(len(alphas)):
d2[k] = np.random.dirichlet(alphas[k])
print "np.random.dirichlet:"
print d2
# Compare d1 and d2:
err = np.abs(d1 - d2).max()
print "max difference:", err示例运行:
dirichlet_sample:
[[ 0.38980834 0.4043844 0.20580726]
[ 0.14076375 0.26906604 0.59017021]
[ 0.64223074 0.26099934 0.09676991]
[ 0.21880145 0.33775249 0.44344606]
[ 0.39879859 0.40984454 0.19135688]
[ 0.73976425 0.21467288 0.04556287]]
np.random.dirichlet:
[[ 0.38980834 0.4043844 0.20580726]
[ 0.14076375 0.26906604 0.59017021]
[ 0.64223074 0.26099934 0.09676991]
[ 0.21880145 0.33775249 0.44344606]
[ 0.39879859 0.40984454 0.19135688]
[ 0.73976425 0.21467288 0.04556287]]
max difference: 5.55111512313e-17发布于 2013-04-10 09:57:55
我想你要找的是
y = np.array([np.random.dirichlet(x) for x in alphas])为了你的列表理解。否则,您只是简单地传递一个python列表或元组。我认为numpy.random.dirichlet不接受您的Alpha值列表的原因是因为它没有设置为-它已经接受了一个数组,根据文档,它希望该数组的维数为k。
https://stackoverflow.com/questions/15915446
复制相似问题