我使用Numba0.46.0,我想将类中的一个对象作为参数传递给我的函数,并使用CUDA在我的GPU上运行这个函数。如果我想使用一个简单的Python (比如int),我会使用如下内容:
from numba import jit, cuda
from numba.types import void, int32
@jit(void(int32), target='cuda')
def f(int_object):
pass
f(123)这个很好用。现在我试着对一个类做同样的事情:
from numba import jit, cuda
from numba,types import void
@jitclass([])
class MyClass:
def __init__(self):
pass
@jit(void(MyClass), target='cuda')
def f(MyClass_object):
pass对于没有任何评论的NotImplementedError来说,这是失败的。我还试图以一种懒散的方式编译它:
@jit(target='cuda')
def f(MyClass_object):
pass
f(MyClass())这是因为
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/dispatcher.py", line 42, in __call__
return self.compiled(*args, **kws)
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/compiler.py", line 801, in __call__
cfg(*args)
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/compiler.py", line 537, in __call__
sharedmem=self.sharedmem)
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/compiler.py", line 604, in _kernel_call
self._prepare_args(t, v, stream, retr, kernelargs)
File "/usr/local/lib/python3.6/dist-packages/numba/cuda/compiler.py", line 715, in _prepare_args
raise NotImplementedError(ty, val)
NotImplementedError: (instance.jitclass.MyClass#7f983418fc88<>, <numba.jitclass.boxing.MyClass object at 0x7f983416ca10>)我可以使用jitclass对象作为jit函数参数吗?如果是,上面的例子有什么问题?
UPD:顺便说一下,我已经用numpy数组检查过了:
import numpy as np
from numba import jit, cuda
from numba.types import void
@jit(void(np.ndarray), target='cuda')
def f1(ndarray_object):
pass
# Fails with NotImplementedError with no comments
@jit(target='cuda')
def f2(ndarray_object):
pass
a = np.asarray([])
f2(a) # Executes with no errors, only a warning about autojit为什么这个方法适用于numpy,而不适用于我的班级?为什么这适用于惰性模式(f2)中的numpy,而不适用于给定的签名(f1)?
发布于 2019-12-03 07:49:32
根据相关的文档 (编写本报告时的Numba0.47):
只在CPU上提供对jitclasses的支持。(注:计划在未来发布对GPU设备的支持。)
https://stackoverflow.com/questions/59123932
复制相似问题