首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >为什么numpy.random.dirichlet()不接受多维数组?

为什么numpy.random.dirichlet()不接受多维数组?
EN

Stack Overflow用户
提问于 2013-04-10 09:31:59
回答 2查看 3.1K关注 0票数 4

numpy page上,他们给出了以下示例

代码语言:javascript
复制
s = np.random.dirichlet((10, 5, 3), 20)

这一切都很好;但是,如果您想从alphas的2D数组中生成随机样本,该怎么办?

代码语言:javascript
复制
alphas = np.random.randint(10, size=(20, 3))

如果你尝试

np.random.dirichlet(alphas)

np.random.dirichlet([x for x in alphas]),或者

np.random.dirichlet((x for x in alphas))

它会产生一个ValueError: object too deep for desired array。唯一有效的方法是:

代码语言:javascript
复制
y = np.empty(alphas.shape)
for i in xrange(np.alen(alphas)):
    y[i] = np.random.dirichlet(alphas[i])
    print y

对于我的代码结构,...which远非理想。为什么会这样,还有人能想出一种更“麻木”的方式来做这件事吗?

提前谢谢。

EN

回答 2

Stack Overflow用户

回答已采纳

发布于 2013-04-10 13:01:56

编写np.random.dirichlet是为了为单个Dirichlet分布生成样本。该代码是根据Gamma分布实现的,该实现可以用作向量化代码的基础,以从不同的分布中生成样本。在下面的代码中,dirichlet_sample采用一个形状为(n,k)的数组alphas,其中每一行都是狄利克雷分布的alpha向量。它返回一个同样具有形状(n,k)的数组,每一行都是来自alphas的相应分布的样本。作为脚本运行时,它使用dirichlet_samplenp.random.dirichlet生成样本,以验证它们是否生成相同的样本(最大为正常的浮点差异)。

代码语言:javascript
复制
import numpy as np


def dirichlet_sample(alphas):
    """
    Generate samples from an array of alpha distributions.
    """
    r = np.random.standard_gamma(alphas)
    return r / r.sum(-1, keepdims=True)


if __name__ == "__main__":
    alphas = 2 ** np.random.randint(0, 4, size=(6, 3))

    np.random.seed(1234)
    d1 = dirichlet_sample(alphas)
    print "dirichlet_sample:"
    print d1

    np.random.seed(1234)
    d2 = np.empty(alphas.shape)
    for k in range(len(alphas)):
        d2[k] = np.random.dirichlet(alphas[k])
    print "np.random.dirichlet:"
    print d2

    # Compare d1 and d2:
    err = np.abs(d1 - d2).max()
    print "max difference:", err

示例运行:

代码语言:javascript
复制
dirichlet_sample:
[[ 0.38980834  0.4043844   0.20580726]
 [ 0.14076375  0.26906604  0.59017021]
 [ 0.64223074  0.26099934  0.09676991]
 [ 0.21880145  0.33775249  0.44344606]
 [ 0.39879859  0.40984454  0.19135688]
 [ 0.73976425  0.21467288  0.04556287]]
np.random.dirichlet:
[[ 0.38980834  0.4043844   0.20580726]
 [ 0.14076375  0.26906604  0.59017021]
 [ 0.64223074  0.26099934  0.09676991]
 [ 0.21880145  0.33775249  0.44344606]
 [ 0.39879859  0.40984454  0.19135688]
 [ 0.73976425  0.21467288  0.04556287]]
max difference: 5.55111512313e-17
票数 4
EN

Stack Overflow用户

发布于 2013-04-10 09:57:55

我想你要找的是

代码语言:javascript
复制
y = np.array([np.random.dirichlet(x) for x in alphas])

为了你的列表理解。否则,您只是简单地传递一个python列表或元组。我认为numpy.random.dirichlet不接受您的Alpha值列表的原因是因为它没有设置为-它已经接受了一个数组,根据文档,它希望该数组的维数为k

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/15915446

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档