首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Numba兼容Numpy Meshgrid

Numba兼容Numpy Meshgrid
EN

Stack Overflow用户
提问于 2022-01-06 21:13:36
回答 1查看 707关注 0票数 1

我有一个从3D网格生成扁平数组的函数,如下所示。

我试图在代码中实现Numba,但是当Numba遇到这个函数时,它会抛出一个错误(因为Numba不支持Numpy的meshgrid或mgrid函数)。我有别的办法可以让这个Numba兼容吗?

代码语言:javascript
复制
def meshgrid_flat(max=1.0, sampling=100):
    #
    s = np.linspace(-max,max,sampling)
    X, Y, Z = np.meshgrid(s,s,s,indexing="ij")
    x, y, z = X.ravel(), Y.ravel(), Z.ravel()
    #
    return x, y, z
EN

回答 1

Stack Overflow用户

发布于 2022-01-06 22:04:52

你可以自己实现一个合适的网格,如果你愿意的话,你可以让它麻木。因为您想要一个带有索引ij的3D网格,下面的实现就可以了。

代码语言:javascript
复制
import numba

@numba.jit(nopython=True)
def meshgrid(x, y, z):
    xx = np.empty(shape=(x.size, y.size, z.size), dtype=x.dtype)
    yy = np.empty(shape=(x.size, y.size, z.size), dtype=y.dtype)
    zz = np.empty(shape=(x.size, y.size, z.size), dtype=z.dtype)
    for i in range(z.size):
        for j in range(y.size):
            for k in range(x.size):
                xx[i,j,k] = k  # change to x[k] if indexing xy
                yy[i,j,k] = j  # change to y[j] if indexing xy
                zz[i,j,k] = i  # change to z[i] if indexing xy
    return zz, yy, xx

这似乎比裸皮更快,因为它是特定于你的需要。

代码语言:javascript
复制
 x, y, z = np.arange(100), np.arange(100), np.arange(10)

 %timeit np.meshgrid(x, y, z, indexing="ij")
177 µs ± 9.07 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

 %timeit meshgrid(x, y, z)
47.3 µs ± 3.22 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

您还可以修改上面的实现以返回平面数组。只需将np.empty大小更改为三个向量大小的乘积,并更改循环中的索引:

代码语言:javascript
复制
@numba.jit(nopython=True)
def meshgrid_flat_3d(x):
    xx = np.empty(shape=(x.size * x.size * x.size), dtype=x.dtype)
    yy = np.empty_like(xx)
    zz = np.empty_like(xx)
    for i in range(x.size):
        for j in range(x.size):
            for k in range(x.size):
                xx[i*x.size**2 + j*x.size + k] = k
                yy[i*x.size**2 + j*x.size + k] = j  
                zz[i*x.size**2 + j*x.size + k] = i  
    return zz, yy, xx
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/70613681

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档