
Numba array of function pointers

Stack Overflow user
Asked 2021-12-06 08:13:13
Answers: 1, Views: 143, Followers: 0, Score: 1

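I recently found out that Numba does support function pointers (see How to update jitclass variable with its string name by passing setter function to the jitclass itself?). Is it possible to construct an array of such function pointers?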

Here is an MCVE:

import numpy as np
from numba import float64
from numba.experimental import jitclass

t = None  # placeholder for the function-pointer element type I am trying to express

@jitclass(spec={'a': float64,
                'ptrs': t[:]})
class Test:
    def __init__(self):
        self.a = 0
        self.ptrs = np.array([self.x, self.y])
    def x(self):
        self.a += 1
    def y(self):
        self.a += 2
    def increment(self, n):
        self.ptrs[n](self)

t = Test()
print(t.a)     # Desired: 0
t.increment(0)
print(t.a)     # Desired: 1
t.increment(1)
print(t.a)     # Desired: 3

Obviously, with t = None this raises an error, even without the [:] indexing that marks it as an array.

If I set t = void, I get an error on the line t = Test():

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\experimental\jitclass\base.py", line 122, in __call__
    return cls._ctor(*bind.args[1:], **bind.kwargs)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\dispatcher.py", line 439, in _compile_for_args
    raise e
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\dispatcher.py", line 372, in _compile_for_args
    return_val = self.compile(tuple(argtypes))
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\dispatcher.py", line 909, in compile
    cres = self._compiler.compile(args, return_type)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\dispatcher.py", line 79, in compile
    status, retval = self._compile_cached(args, return_type)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\dispatcher.py", line 93, in _compile_cached
    retval = self._compile_core(args, return_type)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\dispatcher.py", line 106, in _compile_core
    cres = compiler.compile_extra(self.targetdescr.typing_context,
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler.py", line 606, in compile_extra
    return pipeline.compile_extra(func)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler.py", line 353, in compile_extra
    return self._compile_bytecode()
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler.py", line 415, in _compile_bytecode
    return self._compile_core()
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler.py", line 395, in _compile_core
    raise e
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler.py", line 386, in _compile_core
    pm.run(self.state)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler_machinery.py", line 339, in run
    raise patched_exception
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler_machinery.py", line 330, in run
    self._runPass(idx, pass_inst, state)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler_machinery.py", line 289, in _runPass
    mutated |= check(pss.run_pass, internal_state)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler_machinery.py", line 262, in check
    mangled = func(compiler_state)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\typed_passes.py", line 463, in run_pass
    NativeLowering().run_pass(state)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\typed_passes.py", line 386, in run_pass
    lower.create_cpython_wrapper(flags.release_gil)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\lowering.py", line 242, in create_cpython_wrapper
    self.context.create_cpython_wrapper(self.library, self.fndesc,
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\cpu.py", line 162, in create_cpython_wrapper
    builder.build()
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\callwrapper.py", line 122, in build
    self.build_wrapper(api, builder, closure, args, kws)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\callwrapper.py", line 187, in build_wrapper
    obj = api.from_native_return(retty, retval, env_manager)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\pythonapi.py", line 1396, in from_native_return
    out = self.from_native_value(typ, val, env_manager)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\pythonapi.py", line 1410, in from_native_value
    return impl(typ, val, c)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\experimental\jitclass\boxing.py", line 139, in _box_class_instance
    box_subclassed = _specialize_box(typ)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\experimental\jitclass\boxing.py", line 122, in _specialize_box
    fast_fget = fget.compile((typ,))
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\dispatcher.py", line 909, in compile
    cres = self._compiler.compile(args, return_type)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\dispatcher.py", line 79, in compile
    status, retval = self._compile_cached(args, return_type)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\dispatcher.py", line 93, in _compile_cached
    retval = self._compile_core(args, return_type)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\dispatcher.py", line 106, in _compile_core
    cres = compiler.compile_extra(self.targetdescr.typing_context,
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler.py", line 606, in compile_extra
    return pipeline.compile_extra(func)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler.py", line 353, in compile_extra
    return self._compile_bytecode()
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler.py", line 415, in _compile_bytecode
    return self._compile_core()
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler.py", line 395, in _compile_core
    raise e
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler.py", line 386, in _compile_core
    pm.run(self.state)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler_machinery.py", line 339, in run
    raise patched_exception
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler_machinery.py", line 330, in run
    self._runPass(idx, pass_inst, state)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler_lock.py", line 35, in _acquire_compile_lock
    return func(*args, **kwargs)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler_machinery.py", line 289, in _runPass
    mutated |= check(pss.run_pass, internal_state)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\compiler_machinery.py", line 262, in check
    mangled = func(compiler_state)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\typed_passes.py", line 463, in run_pass
    NativeLowering().run_pass(state)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\typed_passes.py", line 386, in run_pass
    lower.create_cpython_wrapper(flags.release_gil)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\lowering.py", line 242, in create_cpython_wrapper
    self.context.create_cpython_wrapper(self.library, self.fndesc,
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\cpu.py", line 162, in create_cpython_wrapper
    builder.build()
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\callwrapper.py", line 122, in build
    self.build_wrapper(api, builder, closure, args, kws)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\callwrapper.py", line 187, in build_wrapper
    obj = api.from_native_return(retty, retval, env_manager)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\pythonapi.py", line 1396, in from_native_return
    out = self.from_native_value(typ, val, env_manager)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\pythonapi.py", line 1410, in from_native_value
    return impl(typ, val, c)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\boxing.py", line 397, in box_array
    np_dtype = numpy_support.as_dtype(typ.dtype)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\np\numpy_support.py", line 156, in as_dtype
    raise NotImplementedError("%r cannot be represented as a Numpy dtype"
NotImplementedError: Failed in nopython mode pipeline (step: nopython mode backend)
Failed in nopython mode pipeline (step: nopython mode backend)
none cannot be represented as a Numpy dtype
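The last line of this traceback follows from the spec: with t = void, the entry t[:] describes a 1-d array whose element type is Numba's none type, and none has no NumPy dtype, so boxing the array fails. The small check below is a sketch that assumes Numba is importable (it simply returns None otherwise) and prints the array type Numba builds from void[:]:

```python
def void_array_type():
    """Return Numba's description of void[:], or None if Numba is not installed."""
    try:
        from numba import void
    except ImportError:
        return None
    # void is Numba's none type; slicing a Numba type builds an array type,
    # so void[:] is a 1-d, any-layout array whose elements are none.
    return str(void[:])


print(void_array_type())  # array(none, 1d, A) when Numba is installed
```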

I then tried setting t to the function pointer type of x and y:

from numba import float64, intp, void, deferred_type
tt = deferred_type()
t = void(tt, intp)

The error from 'ptrs': t[:]}) is:

Traceback (most recent call last):
  File "<stdin>", line 7, in <module>
TypeError: 'Signature' object is not subscriptable

Finally, I tried t = void. This gave an error while compiling the t = Test() line:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\experimental\jitclass\base.py", line 122, in __call__
    return cls._ctor(*bind.args[1:], **bind.kwargs)
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\dispatcher.py", line 420, in _compile_for_args
    error_rewrite(e, 'typing')
  File "C:\Users\madphysicist\Anaconda3\lib\site-packages\numba\core\dispatcher.py", line 361, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Failed in nopython mode pipeline (step: nopython frontend)
NameError: name 'np' is not defined
During: resolving callee type: jitclass.Test#1bbda3d32e0<a:float64,ptrs:array(none, 1d, A)>
During: typing of call at <string> (3)

During: resolving callee type: jitclass.Test#1bbda3d32e0<a:float64,ptrs:array(none, 1d, A)>
During: typing of call at <string> (3)


File "<string>", line 3:
<source missing, REPL/exec in use?>

So, how can I create an array of function pointers?


1 Answer

Stack Overflow user

Posted 2021-12-07 23:53:53

Storing Numba functions in NumPy arrays is not supported yet, at least not in nopython mode as of version 0.54.1 (and earlier versions).

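One way to verify this experimentally is to take the code from the documentation, store the resulting functions in an array, and let Numba infer the corresponding type. Numba fails because the array's dtype is not supported.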
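To see which dtype such an array gets, it is enough to build one with NumPy alone. The sketch below uses two plain Python functions as hypothetical stand-ins for the question's x and y; since a function is not a number, NumPy falls back to its generic object dtype:

```python
import numpy as np


def x():
    """Hypothetical stand-in for the question's x method."""


def y():
    """Hypothetical stand-in for the question's y method."""


# NumPy has no native dtype for functions, so it builds an array of Python objects.
ptrs = np.array([x, y])
print(ptrs.dtype)                            # object
print(np.issubdtype(ptrs.dtype, np.number))  # False
```

An object array like this is exactly the generic type that nopython mode cannot infer anything about.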

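Another way is to read the numpy_support.py code in the repository (numba/np/numpy_support.py, the module whose as_dtype function raised the error in the first traceback of the question). This file contains the currently supported types. I list them below: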

Basic types:
    - bool
    - int8
    - int16
    - int32
    - int64
    - uint8
    - uint16
    - uint32
    - uint64
    - float32
    - float64
    - complex64
    - complex128

Advanced types:
    - Datetime
    - Timedelta
    - CharSeq (dtype 'S' in NumPy)
    - UnicodeCharSeq (dtype 'U' in NumPy)
    - Numpy structured types (containing other ones like Numpy does)

Generic type:
    - object (unsupported in nopython mode since the type inference cannot work with it)

Other (unknown) types:
    - EnumMember
    - NumberClass
    - NestedArray
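Since a function-pointer array is not available, one workaround (not part of the answer above; a sketch assuming the question's Test class only needs a fixed set of operations) is to keep x and y as ordinary methods and select one with explicit branches on n instead of ptrs[n](self). The try/except fallback is an added assumption so the sketch also runs as plain Python where Numba is not installed:

```python
try:
    from numba import float64
    from numba.experimental import jitclass
except ImportError:
    # Assumed fallback: without Numba, run the same class as plain Python.
    float64 = float

    def jitclass(spec):
        return lambda cls: cls


@jitclass([('a', float64)])
class Test:
    def __init__(self):
        self.a = 0.0

    def x(self):
        self.a += 1

    def y(self):
        self.a += 2

    def increment(self, n):
        # Dispatch on the integer n with branches instead of indexing an
        # array of function pointers; unknown values of n do nothing.
        if n == 0:
            self.x()
        elif n == 1:
            self.y()


t = Test()
print(t.a)  # 0.0
t.increment(0)
print(t.a)  # 1.0
t.increment(1)
print(t.a)  # 3.0
```

The printed sequence matches the desired output in the question. The trade-off is that every new operation needs its own branch, whereas an array would only need a new entry.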
Score: 1
Page content originally from Stack Overflow; translation provided by Tencent Cloud Xiaowei's IT-domain engine.
Original link:

https://stackoverflow.com/questions/70242521
