我有这段代码,用Numba来加速处理。基本上,定义particle_dtype是为了使代码使用Numba运行。然而,据报道,TypingError说“无法确定的Numba类型”.我不知道问题出在哪里。
import numpy
from numba import njit
particle_dtype = numpy.dtype({'names':['x','y','z','m','phi'],
'formats':[numpy.double,
numpy.double,
numpy.double,
numpy.double,
numpy.double]})
def create_n_random_particles(n, m, domain=1):
parts = numpy.zeros((n), dtype=particle_dtype)
parts['x'] = numpy.random.random(size=n) * domain
parts['y'] = numpy.random.random(size=n) * domain
parts['z'] = numpy.random.random(size=n) * domain
parts['m'] = m
parts['phi'] = 0.0
return parts
def distance(se, other):
return numpy.sqrt(numpy.square(se['x'] - other['x']) +
numpy.square(se['y'] - other['y']) +
numpy.square(se['z'] - other['z']))
parts = create_n_random_particles(10, .001, 1)
@njit
def direct_sum(particles):
for i, target in enumerate(particles):
for j in range(particles.shape[0]):
if i == j:
continue
source = particles[j]
r = distance(target, source)
# target['phi'] += source['m'] / r
target['phi'] = target['phi'] + source['m'] / r
return(target['phi'])
print(direct_sum(parts) ) 我想是因为不受支持的函数或操作在某个地方使用,但我找不到它。谢谢你的帮助。
发布于 2022-03-09 12:12:14
direct_sum是一个JITed函数,它不能调用distance,因为它不是JITed (纯Python函数)。您还需要在distance上使用装饰器distance。
https://stackoverflow.com/questions/71405055
复制相似问题