首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Cython中自适应拒绝采样的实现

Cython中自适应拒绝采样的实现
EN

Code Review用户
提问于 2017-09-04 19:25:23
回答 1查看 292关注 0票数 6

我试图用Cython重写一段Fortran代码，它是自适应拒绝采样（adaptive rejection sampling）方法的一个实现。我希望最终能在Python代码中使用这个Cython版本。我想知道我的代码转换是否正确，以及是否有更高效的方式来实现这段代码。

AdaptiveRejectionSampling.pyx

代码语言:javascript
复制
import cython
import numpy as np
import ctypes
cimport numpy as np
cdef extern from "math.h":
     cpdef double log(double x)
     cpdef double exp(double x)

cdef extern from "stdlib.h":
     cpdef int rand()
     cpdef enum: RAND_MAX

from libc.math cimport fabs    

# Pointer to a C function taking one double and returning a double; this is
# the type of the h / hprima arguments of sample() and spl1().
# f_type used to be declared twice in this file (once without and once with
# nogil), which is a redeclaration. Only the nogil form is kept; such a
# pointer can still be called while the GIL is held.
ctypedef double (*f_type)(double) nogil

######UPDATE#######
# Three raw double pointers bundled together.
# NOTE(review): presumably the abscissae, ordinates and derivatives that
# initial() receives as x, hx and hpx -- confirm against the caller.
ctypedef struct Data:
    double* x
    double* hx
    double* hpx

# Bound flags, bound values and a diagnostic code bundled together.
# NOTE(review): the field names mirror the lb/xlb/ub/xub/ifault arguments of
# initial(); no function in this file takes a Bounds yet -- confirm intent.
ctypedef struct Bounds:
    bint lb
    double xlb
    bint ub
    double xub
    int ifault
#########END UPDATE ######
cdef void initial(int ns, int m, double emax, np.ndarray[double, mode="c", ndim=1] x, np.ndarray[double, mode="c", ndim=1] hx, np.ndarray[double, mode="c", ndim=1] hpx, bint lb, double xlb, bint ub, double xub, int ifault, np.ndarray[int, mode="c", ndim=1] iwv, np.ndarray[double, mode="c", ndim=1] rwv):
    """
    Fill the working vectors iwv / rwv from the starting points.

    Takes the number of starting values m and the starting values
    x[i], hx[i], hpx[i], stores the bounds, emax, eps, the initial hull
    integral and a copy of the starting points in iwv and rwv, and then
    calls update() to build the hulls.

    ifault codes assigned by this routine:
    0: successful initialisation
    1: not enough starting points
    2: ns is less than m
    3: no abscissae to left of mode (if lb = false)
    4: no abscissae to right of mode (if ub = false)
    5: non-log-concavity detected

    NOTE(review): ifault, xlb and xub are C scalars received by value, so
    the values assigned to them here are not visible to the caller.
    """
    cdef int nn, ilow, ihigh, i
    cdef int iipt, iz, ihuz, iscum, ix, ihx, ihpx
    cdef bint horiz
    cdef double hulb, huub, eps, cu, alcu, huzmax

    # DESCRIPTION OF PARAMETERS and place of storage
    #
    # lb   iwv[4] : boolean indicating if there is a lower bound to the
    #               domain
    # ub   iwv[5] : boolean indicating if there is an upper bound
    # xlb  rwv[7] : value of the lower bound
    # xub  rwv[8] : value of the upper bound
    # emax rwv[2] : large value for which it is possible to compute
    #               an exponential, eps=exp(-emax) is taken as a small
    #               value used to test for numerical instability
    # m    iwv[3] : number of starting points
    # ns   iwv[2] : maximum number of points defining the hulls
    # x    rwv[ix+1]   : vector containing the abscissae of the starting
    #                    points
    # hx   rwv[ihx+1]  : vector containing the ordinates
    # hpx  rwv[ihpx+1] : vector containing the derivatives
    # ifault      : diagnostic
    # iwv,rwv     : integer and real working vectors

    eps = expon(-emax, emax)
    ifault = 0
    ilow = 0
    ihigh = 0
    nn = ns + 1
    # alcu is only recomputed below when cu > 0; give it a defined value so
    # that rwv[5] never receives an uninitialised double.
    alcu = 0.0
    # at least one starting point
    if m < 1:
        ifault = 1

    huzmax = hx[0]
    if not ub:
        xub = 0.0
    if not lb:
        xlb = 0.0

    hulb = (xlb - x[0])*hpx[0] + hx[0]
    huub = (xub - x[0])*hpx[0] + hx[0]
    if ub and lb:
        # bounded on both sides
        huzmax = max(huub, hulb)
        horiz = (fabs(hpx[0]) < eps)
        if horiz:
            cu = expon((huub + hulb)*0.5 - huzmax, emax)*(xub - xlb)
        else:
            cu = expon(huub - huzmax, emax)*(1 - expon(hulb - huub, emax))/hpx[0]
    elif ub and (not lb):
        # bounded on the right and unbounded on the left
        huzmax = huub
        cu = 1.0/hpx[0]
    elif (not ub) and lb:
        # bounded on the left and unbounded on the right
        huzmax = hulb
        cu = -1.0/hpx[0]
    else:
        # unbounded: at least 2 starting points are required
        cu = 0.0
        if m < 2:
            ifault = 1

    if cu > 0.0:
        alcu = log(cu)

    # set pointers (offsets of the sub-vectors inside iwv / rwv)
    iipt = 5
    iz = 8
    ihuz = nn + iz
    iscum = nn + ihuz
    ix = nn + iscum
    ihx = nn + ix
    ihpx = nn + ihx

    # store values in working vectors
    iwv[0] = ilow
    iwv[1] = ihigh
    iwv[2] = ns
    iwv[3] = 1
    if lb:
        iwv[4] = 1
    else:
        iwv[4] = 0
    if ub:
        iwv[5] = 1
    else:
        iwv[5] = 0

    if ns < m:
        ifault = 2

    iwv[iipt+1] = 0
    rwv[0] = hulb
    rwv[1] = huub
    rwv[2] = emax
    rwv[3] = eps
    rwv[4] = cu
    rwv[5] = alcu
    rwv[6] = huzmax
    rwv[7] = xlb
    rwv[8] = xub
    rwv[iscum+1] = 1.0
    # The x / hx / hpx regions start at offset+1 (that is what is handed to
    # update() below and what the storage table above documents), so element
    # i goes to offset+1+i. Writing to offset+i would put x[0] in the last
    # slot of the preceding region.
    for i in range(m):
        rwv[ix+1+i] = x[i]
        rwv[ihx+1+i] = hx[i]
        rwv[ihpx+1+i] = hpx[i]

    # create lower and upper hulls
    # NOTE(review): update() declares ndarray parameters for ipt, scum, x,
    # hx, hpx, z and huz but receives single elements here, and its scalar
    # parameters (n, ifault, ...) are passed by value, so neither iwv[3] nor
    # ifault is changed by the call. The loop below relies on both; this
    # needs an interface change to update() and is left as found.
    i = 0
    while i < m:
        update(iwv[3],iwv[0],iwv[1],iwv[iipt+1],rwv[iscum+1],rwv[4],rwv[ix+1],rwv[ihx+1],rwv[ihpx+1],rwv[iz+1],rwv[ihuz+1],rwv[6],rwv[2],lb,rwv[7],rwv[0],ub,rwv[8],rwv[1],ifault,rwv[3],rwv[5])
        i = iwv[3]
        if ifault != 0:
            return

    # test for wrong starting points
    if (not lb) and (hpx[iwv[0]] < eps):
        ifault = 3
    if (not ub) and (hpx[iwv[1]] > -eps):
        ifault = 4
    return


cdef void sample(np.ndarray[int, mode="c", ndim=1] iwv, np.ndarray[double, mode="c", ndim=1] rwv, f_type h, f_type hprima, double beta, int ifault):
    """
    Unpack the working vectors and hand everything to spl1().

    ifault codes used by the sampling routines:
    0: successful sampling
    5: non-concavity detected
    6: random number generator generated zero
    7: numerical instability

    NOTE(review): beta and ifault are C scalars received by value, and
    single elements of iwv / rwv are passed where spl1() declares ndarray
    parameters -- confirm how results are meant to reach the caller.
    """
    # maximum number of hull points, read back from the integer work vector
    cdef int ns = iwv[2]
    cdef int nn = ns + 1
    # offsets of the sub-vectors inside iwv / rwv
    cdef int iipt = 5
    cdef int iz = 8
    cdef int ihuz = nn + iz
    cdef int iscum = nn + ihuz
    cdef int ix = nn + iscum
    cdef int ihx = nn + ix
    cdef int ihpx = nn + ihx
    # bound flags are stored as 1 / 0 in iwv[4] and iwv[5]
    cdef bint lb = (iwv[4] == 1)
    cdef bint ub = (iwv[5] == 1)

    # call sampling subroutine
    spl1(ns,iwv[3],iwv[0],iwv[1],iwv[iipt+1],rwv[iscum+1],rwv[4],rwv[ix+1],rwv[ihx+1],rwv[ihpx+1],rwv[iz+1],rwv[ihuz+1],rwv[6],lb,rwv[7],rwv[0],ub,rwv[8],rwv[1], h, hprima,beta,ifault,rwv[2],rwv[3],rwv[5])

cdef void spl1(int ns, int n, int ilow, int ihigh, np.ndarray[int, mode="c", ndim=1] ipt, np.ndarray[double, mode="c", ndim=1] scum, double cu, np.ndarray[double, mode="c", ndim=1] x, np.ndarray[double, mode="c", ndim=1] hx, np.ndarray[double, mode="c", ndim=1] hpx, np.ndarray[double, mode="c", ndim=1] z, np.ndarray[double, mode="c", ndim=1] huz, double huzmax, bint lb, double xlb, double hulb, bint ub, double xub, double huub, f_type h, f_type hprima, double beta, int ifault, double emax, double eps, double alcu):
    """
    Perform the adaptive rejection sampling loop.

    Each iteration calls splhull() to sample from the upper hull, applies
    the squeezing test against the lower hull and, if that fails,
    evaluates h / hprima at the candidate, applies the rejection test and
    calls update() while fewer than ns points define the hulls.  The loop
    gives up after 3*ns attempts.

    ifault is set to 6 when a uniform draw is exactly zero.

    NOTE(review): beta, ifault, i and j are C scalars passed by value to
    splhull() / update(), so values assigned inside those routines do not
    come back here; this requires an interface change and is left as found.
    NOTE(review): this is a cdef void function without an exception
    clause, so the ValueError raised at the end cannot propagate to a
    Python caller -- confirm the intended error reporting.
    """
    cdef int i, j, n1
    cdef bint sampld = False
    cdef double u1, u2, alu1, fx
    cdef double alhl, alhu
    cdef int max_attempt = 3*ns
    cdef int attempts = 0

    # i and j are used as array indices below; start them at a valid index
    # instead of reading uninitialised C ints.
    i = ilow
    j = ilow

    while (not sampld) and (attempts < max_attempt):
        # rand() and RAND_MAX are both C ints: cast before dividing so the
        # quotient is a double in (0, 1] rather than an integer 0 or 1.
        u2 = rand()/<double>RAND_MAX
        # test for zero random number
        if u2 == 0.0:
            ifault = 6
            return
        splhull(u2,ipt,ilow,lb,xlb,hulb,huzmax,alcu,x,hx,hpx,z,huz,scum,eps,emax,beta,i,j)
        # sample u1 to compute rejection
        u1 = rand()/<double>RAND_MAX
        if u1 == 0.0:
            # same handling as for u2: do not go on to log(0)
            ifault = 6
            return
        alu1 = log(u1)
        # compute alhu: upper hull at point u1
        alhu = hpx[i]*(beta - x[i]) + hx[i] - huzmax
        if (beta > x[ilow]) and (beta < x[ihigh]):
            # compute alhl: value of the lower hull at point u1
            if beta > x[i]:
                j = i
                i = ipt[i]
            # chord through the two neighbouring points j and i; using i for
            # both ends would always evaluate 0/0
            alhl = hx[i] + (beta - x[i])*(hx[i] - hx[j])/(x[i] - x[j]) - huzmax
            # squeezing test
            if (alhl - alhu) > alu1:
                sampld = True
        # if not sampled evaluate the function, do the rejection test and update
        if not sampld:
            n1 = n + 1
            x[n1] = beta
            hx[n1] = h(x[n1])
            hpx[n1] = hprima(x[n1])
            fx = hx[n1] - huzmax
            if alu1 < (fx - alhu):
                sampld = True
            # update while the number of points defining the hulls is lower than ns
            if n < ns:
                update(n,ilow,ihigh,ipt,scum,cu,x,hx,hpx,z,huz,huzmax,emax,lb,xlb,hulb,ub,xub,huub,ifault,eps,alcu)
            if ifault != 0:
                return
        attempts += 1
    if attempts >= max_attempt:
        raise ValueError("Trap in ARS: Maximum number of attempts reached by routine spl1_\n")
    return

cdef void splhull(double u2, np.ndarray[int, mode="c", ndim=1] ipt, int ilow, bint lb, double xlb, double hulb, double huzmax, double alcu, np.ndarray[double, mode="c", ndim=1] x, np.ndarray[double, mode="c", ndim=1] hx, np.ndarray[double, mode="c", ndim=1] hpx, np.ndarray[double, mode="c", ndim=1] z, np.ndarray[double, mode="c", ndim=1] huz, np.ndarray[double, mode="c", ndim=1] scum, double eps, double emax, double beta, int i, int j):
    """
    Sample beta from the normalised upper hull for the uniform draw u2.

    Walks the linked list ipt from ilow until scum[i] >= u2 to find the
    exponential piece, then inverts that piece for u2.  On exit i is the
    index of the selected piece and j the index visited just before it.

    NOTE(review): beta, i and j are C scalars received by value, so the
    values computed here are not visible to the caller -- confirm how the
    result is meant to be returned.
    """
    cdef double eh, logdu, logtg, sign
    cdef bint horiz

    i = ilow
    # find from which exponential piece you sample
    while u2 > scum[i]:
        j = i
        i = ipt[i]

    if i == ilow:
        # sample below z(ilow), depending on the existence of a lower bound
        if lb:
            eh = hulb - huzmax - alcu
            horiz = (fabs(hpx[ilow]) < eps)
            if horiz:
                beta = xlb + u2*expon(-eh, emax)
            else:
                sign = fabs(hpx[i])/hpx[i]
                logtg = log(fabs(hpx[i]))
                logdu = log(u2)
                eh = logdu + logtg - eh
                if eh < emax:
                    beta = xlb + log(1.0 + sign*expon(eh, emax))/hpx[i]
                else:
                    beta = xlb + eh/hpx[i]
        else:
            # hpx(i) must be positive, x(ilow) is left of the mode
            beta = (log(hpx[i]*u2) + alcu - hx[i] + x[i]*hpx[i] + huzmax)/hpx[i]
    else:
        # sample above z(j)
        eh = huz[j] - huzmax - alcu
        horiz = (fabs(hpx[i]) < eps)
        if horiz:
            beta = z[j] + (u2 - scum[j])*expon(-eh, emax)
        else:
            sign = fabs(hpx[i])/hpx[i]
            logtg = log(fabs(hpx[i]))
            logdu = log(u2 - scum[j])
            eh = logdu + logtg - eh
            # The piece starting at z[j] has slope hpx[i] (sign and logtg
            # above are built from it), so the inversion divides by hpx[i],
            # exactly as the lower-bound branch does; dividing by hpx[j]
            # mixed in the slope of the neighbouring piece.
            if eh < emax:
                beta = z[j] + (log(1.0 + sign*expon(eh, emax)))/hpx[i]
            else:
                beta = z[j] + eh/hpx[i]
    return

cdef void intersection(double x1,double y1,double yp1,double x2,double y2,double yp2,double z1,double hz1,double eps,int ifault):
    """
    Intersect two tangent lines.

    The first tangent passes through (x1, y1) with slope yp1, the second
    through (x2, y2) with slope yp2.  The crossing point is assigned to
    z1 (abscissa) and hz1 (ordinate).

    ifault is set to 5 when one tangent, evaluated at the other abscissa,
    lies below the ordinate given there, and to 7 when z1 ends up outside
    the interval [x1, x2].
    """
    cdef double slope_gap
    # each tangent evaluated at the other point's abscissa
    cdef double tangent1_at_x2 = y1+yp1*(x2-x1)
    cdef double tangent2_at_x1 = y2+yp2*(x1-x2)

    # the non-concavity test comes before anything else
    if (tangent2_at_x1<y1) or (tangent1_at_x2<y2):
        ifault=5
        return

    slope_gap=yp2-yp1
    if fabs(slope_gap)<=eps:
        # nearly parallel lines: take the midpoint
        z1=0.5*(x1+x2)
        hz1=0.5*(y1+y2)
    elif fabs(yp1)<fabs(yp2):
        # work from the right-hand point for greater numerical precision
        z1=x2+(y1-y2+yp1*(x2-x1))/slope_gap
        hz1=yp1*(z1-x1)+y1
    else:
        # work from the left-hand point
        z1=x1+(y1-y2+yp2*(x2-x1))/slope_gap
        hz1=yp2*(z1-x2)+y2

    # misbehaviour due to numerical imprecision
    if (z1>x2) or (z1<x1):
        ifault=7

cdef void update(int n,int ilow,int ihigh,np.ndarray[int, mode="c", ndim=1] ipt,np.ndarray[double, mode="c", ndim=1] scum,double cu,np.ndarray[double, mode="c", ndim=1] x,np.ndarray[double, mode="c", ndim=1] hx,np.ndarray[double, mode="c", ndim=1] hpx,np.ndarray[double, mode="c", ndim=1] z,np.ndarray[double, mode="c", ndim=1] huz,double huzmax,double emax,bint lb,double xlb,double hulb,bint ub,double xub,double huub,int ifault,double eps,double alcu):
    """
    Increment n and update all the parameters which define the lower and
    the upper hull after the point x[n] has been added.

    ifault is set to 5 when the new derivative is inconsistent with its
    neighbours (non-concavity).

    NOTE(review): n, ilow, ihigh, cu, huzmax, hulb, huub, ifault and alcu
    are C scalars received by value, so the values assigned to them here
    are not visible to the caller; likewise intersection() receives z[.],
    huz[.] and ifault by value.  This needs an interface change and is
    left as found.
    NOTE(review): this is a cdef void function without an exception
    clause, so the ValueError raised below cannot propagate to a Python
    caller -- confirm the intended error reporting.
    """
    cdef int i, j
    cdef bint horiz
    cdef double dh, u
    cdef double zero = 1e-2
    # guard counter for the cumulative-area loop
    cdef int control_count = 0

    # DESCRIPTION OF PARAMETERS and place of storage
    #
    # ilow iwv[0]    : index of the smallest x(i)
    # ihigh iwv[1]   : index of the largest x(i)
    # n    iwv[3]    : number of points defining the hulls
    # ipt  iwv[iipt] : pointer array:  ipt(i) is the index of the x(.)
    #                  immediately larger than x(i)
    # hulb rwv[0]    : value of the upper hull at xlb
    # huub rwv[1]    : value of the upper hull at xub
    # cu   rwv[4]    : integral of the exponentiated upper hull divided
    #                  by exp(huzmax)
    # alcu rwv[5]    : logarithm of cu
    # huzmax rwv[6]  : maximum of huz(i); i=1,n
    # z    rwv[iz+1] : z(i) is the abscissa of the intersection between
    #                  the tangents at x(i) and x(ipt(i))
    # huz  rwv[ihuz+1]: huz(i) is the ordinate of the intersection
    #                   defined above
    # scum rwv[iscum]: scum(i) is the cumulative probability of the
    #                  normalised exponential of the upper hull
    #                  calculated at z(i)
    # eps  rwv[3]    : =exp(-emax) a very small number

    n = n+1
    # update z, huz and ipt
    if x[n] < x[ilow]:
        # insert x(n) below x(ilow)
        # test for non-concavity
        if hpx[ilow] > hpx[n]:
            ifault = 5
        ipt[n] = ilow
        intersection(x[n],hx[n],hpx[n],x[ilow],hx[ilow],hpx[ilow], z[n],huz[n],eps,ifault)
        if ifault != 0:
            return
        if lb:
            hulb = hpx[n]*(xlb-x[n])+hx[n]
        ilow = n
    else:
        i = ilow
        j = i
        # find where to insert x(n)
        while (x[n] >= x[i]) and (ipt[i] != 0):
            j = i
            i = ipt[i]
        if x[n] > x[i]:
            # insert above x(ihigh)
            # test for non-concavity
            if hpx[i] < hpx[n]:
                ifault = 5
            ihigh = n
            ipt[i] = n
            ipt[n] = 0
            intersection(x[i],hx[i],hpx[i],x[n],hx[n],hpx[n],z[i],huz[i],eps,ifault)
            if ifault != 0:
                return
            huub = hpx[n]*(xub-x[n])+hx[n]
            z[n] = 0.0
            huz[n] = 0.0
        else:
            # insert x(n) between x(j) and x(i)
            # test for non-concavity
            if (hpx[j] < hpx[n]) or (hpx[i] > hpx[n]):
                ifault = 5
            ipt[j] = n
            ipt[n] = i
            # insert z(j) between x(j) and x(n)
            intersection(x[j],hx[j],hpx[j],x[n],hx[n],hpx[n],z[j],huz[j],eps,ifault)
            if ifault != 0:
                return
            # insert z(n) between x(n) and x(i)
            intersection(x[n],hx[n],hpx[n],x[i],hx[i],hpx[i],z[n],huz[n],eps,ifault)
            if ifault != 0:
                return

    # update huzmax
    j = ilow
    i = ipt[j]
    huzmax = huz[j]
    while (huz[j] < huz[i]) and (ipt[i] != 0):
        j = i
        i = ipt[i]
        huzmax = max(huzmax, huz[j])
    if lb:
        huzmax = max(huzmax, hulb)
    if ub:
        huzmax = max(huzmax, huub)

    # update cu
    # scum receives area below exponentiated upper hull left of z(i)
    i = ilow
    horiz = (fabs(hpx[ilow]) < eps)
    if (not lb) and (not horiz):
        cu = expon(huz[i]-huzmax, emax)/hpx[i]
    elif lb and horiz:
        cu = (z[ilow]-xlb)*expon(hulb-huzmax, emax)
    elif lb and (not horiz):
        dh = hulb-huz[i]
        if dh > emax:
            cu = -expon(hulb-huzmax, emax)/hpx[i]
        else:
            cu = expon(huz[i]-huzmax, emax)*(1-expon(dh, emax))/hpx[i]
    else:
        cu = 0
    scum[i] = cu
    j = i
    i = ipt[i]
    while ipt[i] != 0:
        # bail out if the linked list is walked more often than it has points
        if control_count > n:
            raise ValueError('Trap in ARS: infinite while in update near ...\n')
        control_count += 1
        dh = huz[j]-huz[i]
        horiz = (fabs(hpx[i]) < eps)
        if horiz:
            cu += (z[i]-z[j])*expon((huz[i]+huz[j])*0.5-huzmax, emax)
        else:
            if dh < emax:
                cu += expon(huz[i]-huzmax, emax)*(1-expon(dh, emax))/hpx[i]
            else:
                cu -= expon(huz[j]-huzmax, emax)/hpx[i]
        j = i
        i = ipt[i]
        scum[j] = cu
    horiz = (fabs(hpx[i]) < eps)
    # if the derivative is very small the tangent is nearly horizontal
    if not (ub or horiz):
        cu -= expon(huz[j]-huzmax, emax)/hpx[i]
    elif ub and horiz:
        cu += (xub-x[i])*expon((huub+hx[i])*0.5-huzmax, emax)
    elif ub and (not horiz):
        dh = huz[j]-huub
        if dh > emax:
            cu -= expon(huz[j]-huzmax, emax)/hpx[i]
        else:
            cu += expon(huub-huzmax, emax)*(1-expon(dh, emax))/hpx[i]
    scum[i] = cu
    if cu > 0:
        alcu = log(cu)

    # normalize scum to obtain a cumulative probability while excluding
    # unnecessary points
    i = ilow
    u = (cu-scum[i])/cu
    if (u == 1.0) and (hpx[ipt[i]] > zero):
        ilow = ipt[i]
        scum[i] = 0.0
    else:
        scum[i] = 1.0-u
    j = i
    i = ipt[i]
    while ipt[i] != 0:
        j = i
        i = ipt[i]
        u = (cu-scum[j])/cu
        if (u == 1.0) and (hpx[i] > zero):
            ilow = i
        else:
            scum[j] = 1.0-u
    scum[i] = 1.0
    if ub:
        huub = hpx[ihigh]*(xub-x[ihigh])+hx[ihigh]
    if lb:
        # hx is an array: index it (the previous hx(ilow) tried to call it)
        hulb = hpx[ilow]*(xlb-x[ilow])+hx[ilow]
    return


cdef double expon(double x, double emax):
    # Exponential without underflow: any argument below -emax yields 0.0,
    # everything else is passed to exp().
    return 0.0 if x < -emax else exp(x)

setup.py

代码语言:javascript
复制
from distutils.core import setup
from distutils.extension import Extension

import numpy
from Cython.Distutils import build_ext

# -Wall is a compiler diagnostic switch, so it is passed with the compile
# flags; it was previously given to the linker, where it does not enable
# any compile-time warnings.
extra_compile_args = ['-fPIC', '-Wall']
extra_link_args = []

# Build the AdaptiveRejectionSampling extension from its .pyx source; the
# NumPy headers are needed because the module cimports numpy.
setup(
    cmdclass={'build_ext': build_ext},
    ext_modules=[
        Extension("AdaptiveRejectionSampling",
                  sources=["AdaptiveRejectionSampling.pyx"],
                  include_dirs=[numpy.get_include()],
                  extra_compile_args=extra_compile_args,
                  extra_link_args=extra_link_args)
    ]
)

更新: Test.py

代码语言:javascript
复制
import ctypes

import numpy as np

from ars import initial, sample

# Problem size: m starting points, at most ns hull points, emax for expon().
m = 3
ns = 100
emax = 64

# Starting abscissae and room for their ordinates / derivatives.
x = np.zeros(10, float)
hx = np.zeros(10, float)
hpx = np.zeros(10, float)
x[0] = 0
x[1] = 1.0
x[2] = -1.0


def normal(x):
    """Return the log-density -x^2/2 and its derivative -x."""
    return -x*x*0.5, -x


hx[0], hpx[0] = normal(x[0])
hx[1], hpx[1] = normal(x[1])
hx[2], hpx[2] = normal(x[2])
testlib = ctypes.cdll.LoadLibrary('./ars.so')


class Data(ctypes.Structure):
    # ctypes mirror of the Cython struct Data: three double pointers.
    _fields_ = [("x", ctypes.POINTER(ctypes.c_double)),
                ("hx", ctypes.POINTER(ctypes.c_double)),
                ("hpx", ctypes.POINTER(ctypes.c_double))]


data = Data(np.ctypeslib.as_ctypes(x),
            np.ctypeslib.as_ctypes(hx),
            np.ctypeslib.as_ctypes(hpx))


class Bounds(ctypes.Structure):
    # ctypes mirror of the Cython struct Bounds.
    _fields_ = [("lb", ctypes.c_bool),
                ("xlb", ctypes.c_double),
                ("ub", ctypes.c_bool),
                ("xub", ctypes.c_double),
                ("ifault", ctypes.c_int)]


b = Bounds(lb=False, ub=False, ifault=0)

# Working vectors, allocated once.  initial()/sample() declare iwv as a
# buffer of C int, so the array uses np.intc rather than the platform
# default integer.
iwv = np.zeros(200, dtype=np.intc)
rwv = np.zeros(700, float)
# NOTE(review): initial() as declared in the .pyx takes ndarray arguments
# for x, hx and hpx, while ctypes pointers are passed here -- confirm which
# signature the built module exposes.
initial(ns, m, emax, data.x, data.hx, data.hpx, b.lb, b.xlb, b.ub, b.xub, b.ifault, iwv, rwv)


def h(x):
    """Log-density evaluated at x."""
    yu = -x*x*0.5
    return yu


def hprima(x):
    """Derivative of the log-density at x."""
    ypu = -x
    return ypu


num = 200
sp = np.empty(num, dtype=float)
# sim must exist before it is handed to sample().
# NOTE(review): sample() takes beta as a C double by value, so sim is not
# changed by the call; the routine has to return the draw (or write it into
# an array) before sp receives distinct values.
sim = 0.0
for i in range(num):
    sample(iwv, rwv, h, hprima, sim, b.ifault)
    sp[i] = sim

更新：我希望循环执行 num 次，每次调用 sample 都得到一个不同的 sim 值；更准确地说，每次调用后 sim 的值都会被更新，并保存到 sp 数组中。

EN

回答 1

Code Review用户

回答已采纳

发布于 2017-09-04 22:11:28

Cython指令

使用以下Cython指令可以去掉代码中许多地方的大量运行时开销，使其运行得更快；不过这也把更多的责任交给了您：在调用这些函数之前需要自行检查输入（这本身是一件好事）。

代码语言:javascript
复制
#cython: wraparound=False
#cython: boundscheck=False
#cython: cdivision=True
#cython: nonecheck=False

原始指针或内存视图

使用C原始指针而不是np.ndarray类。这样您可以轻松地把代码的很大一部分写成原生C，因此在需要时也更容易移植。请再检查一下<int>强制类型转换是否有副作用，不过它应该可以正常工作。

No GIL

在改用原始指针（或者另一种选择：Cython的内存视图）之后，您就可以释放Python的GIL，并以nogil的方式运行所有代码。

下面是对产生的代码的建议:

代码语言:javascript
复制
#cython: wraparound=False
#cython: boundscheck=False
#cython: cdivision=True
#cython: nonecheck=False

import cython
import numpy as np
import ctypes
cimport numpy as np
cdef extern from "math.h":
     cpdef double log(double x) nogil
     cpdef double exp(double x) nogil

cdef extern from "stdlib.h":
     cpdef int rand() nogil
     cpdef enum: RAND_MAX

from libc.math cimport fabs


ctypedef double (*f_type)(double) nogil

cdef void initial(int ns, int m, double emax, double* x, double* hx, double*
        hpx, bint lb, double xlb, bint ub, double xub, int ifault, int* iwv,
        double* rwv) nogil:
      """
      This subroutine takes as input the number of starting values m
      and the starting values x(i), hx(i), hpx(i)  i = 1, m
      As output we have pointer iipt along with ilow and ihigh and the lower
      and upper hulls defined  by z, hz, scum, cu, hulb, huub stored in working
      vectors iwv and rwv
      Ifault detects wrong starting points or non-concavity
      ifault codes, subroutine initial
      0:successful initialisation
      1:not enough starting points
      2:ns is less than m
      3:no abscissae to left of mode (if lb = false)
      4:no abscissae to right of mode (if ub = false)
      5:non-log-concavity detect

      NOTE(review): ifault, xlb and xub are declared as plain by-value
      parameters, so the values assigned to them in this function are not
      visible to the caller -- confirm how the fault code is meant to be
      reported.
      """
      cdef int nn, ilow, ihigh, i
      cdef int iipt, iz, ihuz, iscum, ix, ihx, ihpx
      cdef bint horiz
      cdef double hulb, huub, eps, cu, alcu, huzmax
      """
      DESCRIPTION OF PARAMETERS and place of storage

      lb   iwv[4] : boolean indicating if there is a lower bound to the
                     domain
      ub   iwv[5] : boolean indicating if there is an upper bound
      xlb  rwv[7] : value of the lower bound
      xub  rwv[8] : value of the upper bound
      emax rwv[2] : large value for which it is possible to compute
                    an exponential, eps = exp(-emax) is taken as a small
                    value used to test for numerical unstability
      m    iwv[3] : number of starting points
      ns   iwv[2] : maximum number of points defining the hulls
      x    rwv(ix+1)  : vector containing the abscissae of the starting
                        points
      hx   rwv(ihx+1) : vector containing the ordinates
      hpx  rwv(ihpx+1): vector containing the derivatives
      ifault      : diagnostic
      iwv, rwv     : integer and real working vectors
      """
      # eps is the small threshold used by the slope tests further down
      eps = expon(-emax, emax)
      ifault = 0
      ilow = 0
      ihigh = 0
      # nn is the spacing between the sub-array offsets computed below
      nn = ns+1
      #at least one starting point
      # (only the code is set here; execution continues)
      if (m < 1):
         ifault = 1

      huzmax = hx[0]
      # an absent bound is replaced by 0.0 before it is used in the formulas
      if not ub:
         xub = 0.0

      if not lb:
         xlb = 0.0

      # tangent at x[0] evaluated at the lower and at the upper bound
      hulb = (xlb-x[0])*hpx[0] + hx[0]
      huub = (xub-x[0])*hpx[0] + hx[0]
      #if bounded on both sides
      if (ub and lb):
         huzmax = max(huub, hulb)
         horiz = (fabs(hpx[0]) < eps)
         if (horiz):
           cu = expon((huub+hulb)*0.5-huzmax, emax)*(xub-xlb)
         else:
           cu = expon(huub-huzmax, emax)*(1-expon(hulb-huub, emax))/hpx[0]
      elif ((ub) and (not lb)):
         #if bounded on the right and unbounded on the left
         huzmax = huub
         cu = 1.0/hpx[0]

      elif ((not ub) and (lb)):
         #if bounded on the left and unbounded on the right
         huzmax = hulb
         cu = -1.0/hpx[0]

         #if unbounded at least 2 starting points
      else:
         cu = 0.0
         if (m < 2):
             ifault = 1

      # NOTE(review): alcu is left unassigned when cu <= 0.0 (the unbounded
      # branch above sets cu = 0.0) and is still copied into rwv[5] below.
      if (cu > 0.0):
          alcu = log(cu)
      #set pointers
      # offsets into iwv (iipt) and rwv (all others); consecutive rwv
      # sub-arrays are nn slots apart
      iipt = 5
      iz = 8
      ihuz = nn+iz
      iscum = nn+ihuz
      ix = nn+iscum
      ihx = nn+ix
      ihpx = nn+ihx
      #store values in working vectors
      iwv[0] = ilow
      iwv[1] = ihigh
      iwv[2] = ns
      # the constant 1 is stored here, not m
      iwv[3] = 1
      if lb:
         iwv[4] = 1
      else:
         iwv[4] = 0

      if ub:
         iwv[5] = 1
      else:
         iwv[5] = 0

      if ( ns < m):
         ifault = 2

      iwv[iipt+1] = 0
      rwv[0] = hulb
      rwv[1] = huub
      rwv[2] = emax
      rwv[3] = eps
      rwv[4] = cu
      rwv[5] = alcu
      rwv[6] = huzmax
      rwv[7] = xlb
      rwv[8] = xub
      rwv[iscum+1] = 1.0
      # copy the m starting points into rwv
      # NOTE(review): the copy starts at rwv[ix], rwv[ihx], rwv[ihpx] while
      # update() below is handed &rwv[ix+1], &rwv[ihx+1], &rwv[ihpx+1] --
      # confirm the intended base index.
      for i in range(m):
         rwv[ix+i] = x[i]
         rwv[ihx+i] = hx[i]
         rwv[ihpx+i] = hpx[i]
      #create lower and upper hulls
      # NOTE(review): update() receives iwv[3] and ifault by value, so this
      # loop only sees changes that update() makes through its pointer
      # arguments -- confirm that i = iwv[3] can reach m and that a fault
      # raised inside update() can be observed here.
      i = 0
      while (i < m):
            update(iwv[3], iwv[0], iwv[1], &iwv[iipt+1], &rwv[iscum+1], rwv[4],
                    &rwv[ix+1], &rwv[ihx+1], &rwv[ihpx+1], &rwv[iz+1],
                    &rwv[ihuz+1], rwv[6], rwv[2], lb, rwv[7], rwv[0], ub,
                    rwv[8], rwv[1], ifault, rwv[3], rwv[5])
            i = iwv[3]
            if (ifault != 0):
               return

      #test for wrong starting points
      # the slope at index iwv[0] must be at least eps when there is no lower
      # bound, and the slope at index iwv[1] at most -eps when there is no
      # upper bound
      if ((not lb) and (hpx[iwv[0]] < eps)):
         ifault = 3
      if ((not ub) and (hpx[iwv[1]] > -eps)):
         ifault = 4
      return


cdef void sample(double* iwv, double* rwv, f_type h, f_type hprima, double beta, int ifault) nogil:
      """
      Unpack the working vectors iwv and rwv and forward their contents,
      together with h, hprima, beta and ifault, to spl1.

      ifault
      0:successful sampling
      5:non-concavity detected
      6:random number generator generated zero
      7:numerical instability
      """
      # maximum number of hull points and the spacing between rwv sub-arrays
      cdef int ns = <int>iwv[2]
      cdef int nn = ns+1
      # offsets of the sub-arrays inside iwv (iipt) and rwv (all others)
      cdef int iipt = 5
      cdef int iz = 8
      cdef int ihuz = iz+nn
      cdef int iscum = ihuz+nn
      cdef int ix = iscum+nn
      cdef int ihx = ix+nn
      cdef int ihpx = ihx+nn
      # bound flags are stored as 1 (present) in iwv[4] and iwv[5]
      cdef bint lb = (iwv[4] == 1)
      cdef bint ub = (iwv[5] == 1)

      # NOTE(review): iwv is declared double* here, its elements are cast to
      # int and its address is reinterpreted as int* -- confirm this matches
      # the element type of the buffer the caller actually provides.
      spl1(ns, <int>iwv[3], <int>iwv[0], <int>iwv[1], <int *>&iwv[iipt+1],
              &rwv[iscum+1], rwv[4],
              &rwv[ix+1], &rwv[ihx+1], &rwv[ihpx+1], &rwv[iz+1], &rwv[ihuz+1],
              rwv[6], lb, rwv[7], rwv[0], ub, rwv[8], rwv[1], h, hprima, beta,
              ifault, rwv[2], rwv[3], rwv[5])

cdef void spl1(int ns, int n, int ilow, int ihigh, int* ipt, double* scum,
        double cu, double* x, double* hx, double* hpx, double* z, double* huz,
        double huzmax, bint lb, double xlb, double hulb, bint ub, double xub,
        double huub, f_type h, f_type hprima, double beta, int ifault, double
        emax, double eps, double alcu) nogil:
     """
     this subroutine performs the adaptive rejection sampling, it calls
     subroutine splhull to sample from the upper hull, if the sampling
     involves a function evaluation it calls the updating subroutine
     ifault is a diagnostic of any problem: non concavity, 0 random number
     or numerical imprecision

     Each attempt draws u2, calls splhull, draws u1, applies the squeezing
     test against the lower hull and, when that does not accept, evaluates
     h and hprima at beta, applies the rejection test and calls update while
     n < ns.  At most 3*ns attempts are made; a ValueError is raised once
     the attempt counter reaches that limit.

     NOTE(review): n, ilow, ihigh, beta and ifault are by-value parameters
     here and in splhull/update, so values computed in those routines (the
     sampled beta, the piece indices, the incremented n, a fault code) do not
     come back to this function or to its caller -- confirm the intended
     calling convention.
     """
     cdef int i, j, n1
     cdef bint sampld
     cdef double u1, u2, alu1, fx
     cdef double alhl, alhu
     cdef int max_attempt = 3*ns
     cdef int attempts = 0
     sampld = False
     while ((not sampld) and (attempts < max_attempt)):
         # Cast to double before dividing: rand() and RAND_MAX are C
         # integers, and an integer quotient would be 0 for nearly every draw.
         u2 = rand()/<double>RAND_MAX
         #test for zero random number
         if (u2 == 0.0):
            ifault = 6
            return
         # splhull takes i and j by value and cannot hand them back; start
         # both at ilow (the index splhull itself starts from) so the
         # indexing below never uses indeterminate values.
         i = ilow
         j = ilow
         splhull(u2, ipt, ilow, lb, xlb, hulb, huzmax, alcu, x, hx, hpx, z, huz, scum, eps, emax, beta, i, j)
         #sample u1 to compute rejection
         u1 = rand()/<double>RAND_MAX
         if (u1 == 0.0):
            ifault = 6
         alu1 = log(u1)
         # compute alhu: upper hull at point u1
         alhu = hpx[i]*(beta-x[i])+hx[i]-huzmax
         if ((beta > x[ilow]) and (beta < x[ihigh])):
            # compute alhl: value of the lower hull at point u1
            if (beta > x[i]):
               j = i
               i = ipt[i]
            # chord between the neighbouring points i and j; using i for
            # both ends would evaluate 0/0
            alhl = hx[i]+(beta-x[i])*(hx[i]-hx[j])/(x[i]-x[j])-huzmax
            #squeezing test
            if ((alhl-alhu) > alu1):
               sampld = True
            #if not sampled evaluate the function, do the rejection test and update
         if (not sampld):
            n1 = n+1
            x[n1]=beta
            hx[n1]=h(x[n1])
            hpx[n1] = hprima(x[n1])
            fx = hx[n1]-huzmax
            if (alu1 < (fx-alhu)):
               sampld = True
            # update while the number of points defining the hulls is lower than ns
            if (n < ns):
               update(n, ilow, ihigh, ipt, scum, cu, x, hx, hpx, z, huz, huzmax, emax, lb, xlb, hulb, ub, xub, huub, ifault, eps, alcu)
            if (ifault != 0):
               return
         attempts += 1
     if (attempts >= max_attempt):
        with gil:
           raise ValueError("Trap in ARS: Maximum number of attempts reached by routine spl1_\n")
     return

cdef void splhull(double u2, int* ipt, int ilow,
        bint lb, double xlb, double hulb, double huzmax, double alcu,
        double* x, double* hx, double* hpx,
        double* z, double* huz, double* scum, double eps,
        double emax, double beta, int i, int j) nogil:
      """
      this subroutine samples beta from the normalised upper hull

      Starting at ilow, the linked list ipt is followed until the first
      index i with u2 <= scum[i]; j is the index visited just before i.
      beta is then obtained by inverting the cumulative distribution of the
      exponential piece whose slope is hpx[i], starting from xlb (first
      piece with a lower bound) or from z[j] (any later piece).

      NOTE(review): beta, i and j are by-value parameters, so the values
      computed here stay local to this function -- confirm how the caller is
      expected to receive them.
      """
      cdef double eh, logdu, logtg, sign
      cdef bint horiz
      #
      i = ilow
      #
      #find from which exponential piece you sample
      while (u2 > scum[i]):
        j = i
        i = ipt[i]

      if (i==ilow):
        #sample below z(ilow), depending on the existence of a lower bound
        if (lb) :
          eh = hulb-huzmax-alcu
          horiz = (fabs(hpx[ilow]) < eps)
          if (horiz):
             beta = xlb+u2*expon(-eh, emax)
          else:
             sign = fabs(hpx[i])/hpx[i]
             logtg = log(fabs(hpx[i]))
             logdu = log(u2)
             eh = logdu+logtg-eh
             if (eh < emax):
                beta = xlb+log(1.0+sign*expon(eh, emax))/hpx[i]
             else:
                beta = xlb+eh/hpx[i]
        else:
          #hpx(i) must be positive, x(ilow) is left of the mode
          beta = (log(hpx[i]*u2)+alcu-hx[i]+x[i]*hpx[i]+huzmax)/hpx[i]

      else:
        #sample above(j)
        eh = huz[j]-huzmax-alcu
        horiz = (fabs(hpx[i]) < eps)
        if (horiz):
           beta = z[j]+(u2-scum[j])*expon(-eh, emax)
        else:
            sign = fabs(hpx[i])/hpx[i]
            logtg = log(fabs(hpx[i]))
            logdu = log(u2-scum[j])
            eh = logdu+logtg-eh
            # The piece between z[j] and z[i] is the tangent at x[i], so the
            # inversion divides by hpx[i] -- the same slope used for sign and
            # logtg above -- not by hpx[j].
            if (eh < emax):
              beta = z[j]+(log(1.0+sign*expon(eh, emax)))/hpx[i]
            else:
              beta = z[j]+eh/hpx[i]
      return

cdef void intersection(double x1, double y1, double yp1, double x2, double y2,
        double yp2, double z1, double hz1, double eps, int ifault) nogil:
     """
     Work out where the tangent through (x1, y1) with slope yp1 meets the
     tangent through (x2, y2) with slope yp2; the abscissa is assigned to z1
     and the ordinate to hz1.

     ifault is set to 5 when the concavity pre-check fails and to 7 when the
     computed z1 lies outside [x1, x2].

     NOTE(review): z1, hz1 and ifault are by-value parameters, so the values
     assigned here stay local to this function -- confirm how callers are
     expected to receive them.
     """
     # each tangent evaluated at the other point's abscissa
     cdef double tangent1_at_x2 = y1+yp1*(x2-x1)
     cdef double tangent2_at_x1 = y2+yp2*(x1-x2)
     cdef double slope_gap

     # concavity pre-check: neither tangent may pass below the other point
     if ((tangent2_at_x1 < y1) or (tangent1_at_x2 < y2)):
         ifault = 5
         return

     slope_gap = yp2-yp1
     if (fabs(slope_gap) <= eps):
        # nearly parallel tangents: use the midpoint of the two points
        z1 = 0.5*(x1+x2)
        hz1 = 0.5*(y1+y2)
     elif (fabs(yp1) < fabs(yp2)):
        # work from the second point and read the ordinate off tangent 1
        z1 = x2+(y1-y2+yp1*(x2-x1))/slope_gap
        hz1 = yp1*(z1-x1)+y1
     else:
        # work from the first point and read the ordinate off tangent 2
        z1 = x1+(y1-y2+yp2*(x2-x1))/slope_gap
        hz1 = yp2*(z1-x2)+y2

     # an intersection outside [x1, x2] signals numerical imprecision
     if ((z1 > x2) or (z1 < x1)):
        ifault = 7

cdef void update(int n, int ilow, int ihigh, int* ipt, double* scum, double cu,
        double* x, double* hx, const double* hpx, double* z, double* huz, double
        huzmax, double emax, bint lb, double xlb, double hulb, bint ub, double
        xub, double huub, int ifault, double eps, double alcu) nogil:
      """
      this subroutine increments n and updates all the parameters which
      define the lower and the upper hull

      The point at index n (after the increment) is linked into the list ipt,
      intersection() is called for the affected neighbouring tangents, and
      huzmax, cu, alcu and the cumulative array scum are then recomputed.

      NOTE(review): n, ilow, ihigh, cu, huzmax, hulb, huub, alcu and ifault
      are by-value parameters, so the values assigned to them here are not
      visible to the caller; only writes through ipt, scum, z and huz are.
      Confirm the intended calling convention.
      """
      cdef int i, j
      cdef bint horiz
      cdef double dh, u
      # threshold compared against hpx in the normalisation step at the end
      cdef double zero = 1e-2
      """

      DESCRIPTION OF PARAMETERS and place of storage

      ilow iwv[0]    : index of the smallest x(i)
      ihigh iwv[1]   : index of the largest x(i)
      n    iwv[3]    : number of points defining the hulls
      ipt  iwv[iipt] : pointer array:  ipt(i) is the index of the x(.)
                       immediately larger than x(i)
      hulb rwv[0]    : value of the upper hull at xlb
      huub rwv[1]    : value of the upper hull at xub
      cu   rwv[4]    : integral of the exponentiated upper hull divided
                       by exp(huzmax)
      alcu rwv[5]    : logarithm of cu
      huzmax rwv[6]  : maximum of huz(i); i = 1, n
      z    rwv[iz+1] : z(i) is the abscissa of the intersection between
                       the tangents at x(i) and x(ipt(i))
      huz  rwv[ihuz+1]: huz(i) is the ordinate of the intersection
                         defined above
      scum rwv[iscum]: scum(i) is the cumulative probability of the
                       normalised exponential of the upper hull
                       calculated at z(i)
      eps  rwv[3]    : =exp(-emax) a very small number
      """
      n = n+1
      #update z, huz and ipt
      # NOTE(review): every intersection() call below passes z[..], huz[..]
      # and ifault by value, so the call cannot store its result in z/huz or
      # change ifault here -- confirm against intersection()'s signature.
      if (x[n] < x[ilow]):
         #insert x(n) below x(ilow)
         #test for non-concavity
         if (hpx[ilow] > hpx[n]):
             ifault = 5
         ipt[n]=ilow
         intersection(x[n], hx[n], hpx[n], x[ilow], hx[ilow], hpx[ilow], z[n], huz[n], eps, ifault)
         if (ifault != 0):
             return
         if (lb):
            hulb = hpx[n]*(xlb-x[n])+hx[n]
         ilow = n
      else:
        i = ilow
        j = i
        #find where to insert x(n)
        # ipt[i] == 0 marks the last element of the linked list
        while ((x[n]>=x[i]) and (ipt[i] != 0)):
          j = i
          i = ipt[i]
        if (x[n] > x[i]):
           # insert above x(ihigh)
           # test for non-concavity
           if (hpx[i] < hpx[n]):
              ifault = 5
           ihigh = n
           ipt[i] = n
           ipt[n] = 0
           intersection(x[i], hx[i], hpx[i], x[n], hx[n], hpx[n], z[i], huz[i], eps, ifault)
           if (ifault != 0):
              return
           huub = hpx[n]*(xub-x[n])+hx[n]
           z[n] = 0.0
           huz[n] = 0.0
        else:
           # insert x(n) between x(j) and x(i)
           # test for non-concavity
           if ((hpx[j] < hpx[n]) or (hpx[i] > hpx[n])):
              ifault = 5
           ipt[j]=n
           ipt[n]=i
           # insert z(j) between x(j) and x(n)
           intersection(x[j], hx[j], hpx[j], x[n], hx[n], hpx[n], z[j], huz[j], eps, ifault)
           if (ifault != 0):
              return
           #insert z(n) between x(n) and x(i)
           intersection(x[n], hx[n], hpx[n], x[i], hx[i], hpx[i], z[n], huz[n], eps, ifault)
           if (ifault != 0):
              return
      #update huzmax
      # walk the list while huz keeps increasing, then fold in the hull
      # values at the bounds
      j = ilow
      i = ipt[j]
      huzmax = huz[j]
      while ((huz[j] < huz[i]) and (ipt[i] != 0)):
        j = i
        i = ipt[i]
        huzmax = max(huzmax, huz[j])
      if (lb):
          huzmax = max(huzmax, hulb)
      if (ub):
          huzmax = max(huzmax, huub)
      #update cu
      #scum receives area below exponentiated upper hull left of z(i)
      # first piece: the formula depends on whether a lower bound exists and
      # whether the slope at ilow is below eps in absolute value
      i = ilow
      horiz = (fabs(hpx[ilow]) < eps)
      if ((not lb) and (not horiz)):
        cu = expon(huz[i]-huzmax, emax)/hpx[i]
      elif (lb and horiz):
        cu = (z[ilow]-xlb)*expon(hulb-huzmax, emax)
      elif (lb and (not horiz)):
        dh = hulb-huz[i]
        if (dh > emax):
          cu = -expon(hulb-huzmax, emax)/hpx[i]
        else:
          cu = expon(huz[i]-huzmax, emax)*(1-expon(dh, emax))/hpx[i]
      else:
        cu = 0
      scum[i]=cu
      j = i
      i = ipt[i]
      # interior pieces: accumulate cu and record the running total in scum;
      # control_count guards against a list that never reaches its end marker
      cdef int control_count = 0
      while (ipt[i] != 0):
        if (control_count > n):
           with gil:
              raise ValueError('Trap in ARS: infinite while in update near ...\n')
        control_count += 1
        dh = huz[j]-huz[i]
        horiz = (fabs(hpx[i]) < eps)
        if (horiz):
          cu += (z[i]-z[j])*expon((huz[i]+huz[j])*0.5-huzmax, emax)
        else:
          if (dh < emax):
            cu += expon(huz[i]-huzmax, emax)*(1-expon(dh, emax))/hpx[i]
          else:
            cu -= expon(huz[j]-huzmax, emax)/hpx[i]
        j = i
        i = ipt[i]
        scum[j]=cu
      # last piece: the formula depends on the upper bound and on the slope
      # at the last point
      horiz = (fabs(hpx[i]) < eps)
      #if the derivative is very small the tangent is nearly horizontal
      if (not(ub or horiz)):
         cu -= expon(huz[j]-huzmax, emax)/hpx[i]
      elif (ub and horiz):
         cu += (xub-x[i])*expon((huub+hx[i])*0.5-huzmax, emax)
      elif (ub and (not horiz)):
         dh = huz[j]-huub
         if (dh > emax):
          cu -= expon(huz[j]-huzmax, emax)/hpx[i]
         else:
          cu += expon(huub-huzmax, emax)*(1-expon(dh, emax))/hpx[i]
      scum[i]=cu
      # alcu is only refreshed for a positive cu
      if (cu > 0):
         alcu = log(cu)
      #normalize scum to obtain a cumulative probability while excluding
      #unnecessary points
      # a piece whose cumulative share rounds to exactly 0 (u == 1.0) and
      # whose successor slope exceeds the threshold moves ilow forward
      i = ilow
      u = (cu-scum[i])/cu
      if ((u == 1.0) and (hpx[ipt[i]] > zero)):
        ilow = ipt[i]
        scum[i] = 0.0
      else:
        scum[i] = 1.0-u
      j = i
      i = ipt[i]
      while (ipt[i] != 0):
        j = i
        i = ipt[i]
        u = (cu-scum[j])/cu
        if ((u == 1.0) and (hpx[i] > zero)):
          ilow = i
        else:
          scum[j] = 1.0 - u
      scum[i] = 1.0
      # re-evaluate the hull at the bounds from the current end points
      if (ub):
          huub = hpx[ihigh]*(xub-x[ihigh])+hx[ihigh]
      if (lb):
          hulb = hpx[ilow]*(xlb-x[ilow])+hx[ilow]
      return


cdef double expon(double x, double emax) nogil:
     """
     Exponential without underflow: return 0.0 when x < -emax, otherwise
     exp(x).
     """
     if (x < -emax):
        return 0.0
     return exp(x)
票数 6
EN
页面原文内容由Code Review提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://codereview.stackexchange.com/questions/174814

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档