首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >正态分布上N个似然数的最快计算

正态分布上N个似然数的最快计算
EN

Code Review用户
提问于 2014-11-13 11:22:50
回答 1查看 4.4K关注 0票数 4

吉布斯取样器的上下文中,我分析了我的代码,我的主要瓶颈是:

我需要计算N个点的可能性,假设它们是从N个正态分布(具有不同的均值,但方差相同)得出的。

以下是计算它的两种方法:

代码语言:javascript
复制
import numpy as np
from scipy.stats import multivariate_normal
from scipy.stats import norm

# Toy data
y = np.random.uniform(low=-1, high=1, size=100) # data points
loc = np.zeros(len(y)) # means

# Two alternatives
%timeit multivariate_normal.logpdf(y, mean=loc, cov=1)
%timeit sum(norm.logpdf(y, loc=loc, scale=1))
  • 第一种方法是:使用最近实现的multivariate_normal。构造等效的N-dimensional高斯,并计算N-dimensional y的(Log)概率。1000圈,最佳每环3: 1.33毫秒
  • 第二部分:使用传统的norm函数。计算每个点y的单个(对数)概率,然后对结果进行求和。10000圈,最佳3: 130 S每圈

由于这是吉布斯取样器的一部分,我需要重复这个计算大约10.000次,因此我需要它尽可能快。

我该如何改进呢?

(来自python或调用Cython、R或其他任何东西)

EN

回答 1

Code Review用户

回答已采纳

发布于 2015-08-03 13:04:57

您应该使用行分析器工具来检查代码中最慢的部分。这听起来好像是针对您自己的代码做的,但是您可以继续分析和分析NumPy和SciPy在计算感兴趣的数量时使用的源代码。[Line profiler](https://pypi.python.org/pypi/line_profiler/)模块是我最喜欢的。

代码语言:javascript
复制
import numpy as np
from scipy.stats import multivariate_normal
from scipy.stats import norm
%lprun -f norm.logpdf norm.logpdf(x=np.random.random(1000000), \
                                  loc=np.random.random(1000000), \
                                  scale = np.random.random())
代码语言:javascript
复制
Timer unit: 1e-06 s

Total time: 0.14831 s
File: /opt/local/Library/Frameworks/Python.framework/Versions/2.7/lib/python2.7/site-packages/scipy/stats/_distn_infrastructure.py
Function: logpdf at line 1578

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
  1578                                               def logpdf(self, x, *args, **kwds):
  1579                                                   """
  1580                                                   Log of the probability density function at x of the given RV.
  1581                                           
  1582                                                   This uses a more numerically accurate calculation if available.
  1583                                           
  1584                                                   Parameters
  1585                                                   ----------
  1586                                                   x : array_like
  1587                                                       quantiles
  1588                                                   arg1, arg2, arg3,... : array_like
  1589                                                       The shape parameter(s) for the distribution (see docstring of the
  1590                                                       instance object for more information)
  1591                                                   loc : array_like, optional
  1592                                                       location parameter (default=0)
  1593                                                   scale : array_like, optional
  1594                                                       scale parameter (default=1)
  1595                                           
  1596                                                   Returns
  1597                                                   -------
  1598                                                   logpdf : array_like
  1599                                                       Log of the probability density function evaluated at x
  1600                                           
  1601                                                   """
  1602         1           14     14.0      0.0          args, loc, scale = self._parse_args(*args, **kwds)
  1603         1           23     23.0      0.0          x, loc, scale = map(asarray, (x, loc, scale))
  1604         1            2      2.0      0.0          args = tuple(map(asarray, args))
  1605         1        13706  13706.0      9.2          x = asarray((x-loc)*1.0/scale)
  1606         1           33     33.0      0.0          cond0 = self._argcheck(*args) & (scale > 0)
  1607         1         5331   5331.0      3.6          cond1 = (scale > 0) & (x >= self.a) & (x <= self.b)
  1608         1         5625   5625.0      3.8          cond = cond0 & cond1
  1609         1           84     84.0      0.1          output = empty(shape(cond), 'd')
  1610         1         6029   6029.0      4.1          output.fill(NINF)
  1611         1        11459  11459.0      7.7          putmask(output, (1-cond0)+np.isnan(x), self.badvalue)
  1612         1         1093   1093.0      0.7          if any(cond):
  1613         1        58499  58499.0     39.4              goodargs = argsreduce(cond, *((x,)+args+(scale,)))
  1614         1            6      6.0      0.0              scale, goodargs = goodargs[-1], goodargs[:-1]
  1615         1        46401  46401.0     31.3              place(output, cond, self._logpdf(*goodargs) - log(scale))
  1616         1            4      4.0      0.0          if output.ndim == 0:
  1617                                                       return output[()]
  1618         1            1      1.0      0.0          return output

看起来,在检查和从函数输入中删除无效参数时,花费的时间并不少。如果您可以确定您永远不需要使用该特性,只需编写自己的函数来计算logpdf

另外,如果你要乘以概率(即加对数概率),你可以用代数来简化和分解正态分布的pdf的求和中的公共项。这将减少对np.log等函数调用的次数。我仓促地做了这件事,所以我可能犯了一个数学错误,但是:

代码语言:javascript
复制
def my_logpdf_sum(x, loc, scale):
    root2 = np.sqrt(2)
    root2pi = np.sqrt(2*np.pi)
    prefactor = - x.size * np.log(scale * root2pi)
    summand = -np.square((x - loc)/(root2 * scale))                         
    return  prefactor + summand.sum()



# toy data
y = np.random.uniform(low=-1, high=1, size=1000) # data points
loc = np.zeros(y.shape)
​
# timing
%timeit multivariate_normal.logpdf(y, mean=loc, cov=1)
%timeit np.sum(norm.logpdf(y, loc=loc, scale=1))
%timeit my_logpdf_sum(y, loc, 1)
1 loops, best of 3: 156 ms per loop
10000 loops, best of 3: 125 µs per loop
The slowest run took 4.55 times longer than the fastest. This could mean that an intermediate result is being cached 
100000 loops, best of 3: 16.3 µs per loop
票数 4
EN
页面原文内容由Code Review提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://codereview.stackexchange.com/questions/69718

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档