首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >Numba的TypingError

Numba的TypingError
EN

Stack Overflow用户
提问于 2022-03-09 06:14:33
回答 1查看 75关注 0票数 0

我有这段代码,用Numba来加速处理。基本上,定义particle_dtype是为了使代码使用Numba运行。然而,据报道,TypingError说“无法确定的Numba类型”.我不知道问题出在哪里。

代码语言:javascript
复制
import numpy
from numba import njit

particle_dtype = numpy.dtype({'names':['x','y','z','m','phi'], 
                             'formats':[numpy.double, 
                                        numpy.double, 
                                        numpy.double, 
                                        numpy.double, 
                                        numpy.double]}) 


def create_n_random_particles(n, m, domain=1):
    parts = numpy.zeros((n), dtype=particle_dtype)
    parts['x'] = numpy.random.random(size=n) * domain
    parts['y'] = numpy.random.random(size=n) * domain
    parts['z'] = numpy.random.random(size=n) * domain
    parts['m'] = m
    parts['phi'] = 0.0

    return parts


def distance(se, other):
    return numpy.sqrt(numpy.square(se['x'] - other['x']) + 
                      numpy.square(se['y'] - other['y']) + 
                      numpy.square(se['z'] - other['z']))


parts = create_n_random_particles(10, .001, 1)


@njit
def direct_sum(particles):
    for i, target in enumerate(particles):
        for j in range(particles.shape[0]):
            if i == j:
                continue
            source = particles[j]
            r = distance(target, source)
            # target['phi'] += source['m'] / r
            target['phi'] = target['phi'] + source['m'] / r
            return(target['phi'])
            
print(direct_sum(parts) ) 

我想是因为不受支持的函数或操作在某个地方使用,但我找不到它。谢谢你的帮助。

EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-03-09 12:12:14

direct_sum是一个JITed函数,它不能调用distance,因为它不是JITed (纯Python函数)。您还需要在distance上使用装饰器distance

票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71405055

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档