首页
学习
活动
专区
圈层
工具
发布
社区首页 >问答首页 >如何使用pyo3将锈蚀函数作为回调传递给Python

如何使用pyo3将锈蚀函数作为回调传递给Python
EN

Stack Overflow用户
提问于 2022-03-04 21:34:17
回答 1查看 546关注 0票数 1

我使用Pyo3从Python调用Rust函数,反之亦然。

我正努力做到以下几点:

rust_function_1调用

  • Python

  • 锈蚀函数rust_function_1调用Python函数python_function,将锈蚀函数rust_function_2作为回调参数

传递

  • Python函数python_function调用回调,在本例中是Rust函数python_function

我不知道如何将rust_function_2作为回调参数传递给python_function

我有以下Python代码:

代码语言:javascript
复制
import rust_module

def python_function(callback):
    print("This is python_function")
    callback()

if __name__ == '__main__':
    rust_module.rust_function_1()

我有以下未编译的Rust代码:

代码语言:javascript
复制
use pyo3::prelude::*;

#[pyfunction]
fn rust_function_1() -> PyResult<()> {
    println!("This is rust_function_1");
    Python::with_gil(|py| {
        let python_module = PyModule::import(py, "python_module")?;
        python_module
            .getattr("python_function")?
            .call1((rust_function_2.into_py(py),))?;  // Compile error
        Ok(())
    })
}

#[pyfunction]
fn rust_function_2() -> PyResult<()> {
    println!("This is rust_function_2");
    Ok(())
}

#[pymodule]
#[pyo3(name = "rust_module")]
fn quantum_network_stack(_python: Python, module: &PyModule) -> PyResult<()> {
    module.add_function(wrap_pyfunction!(rust_function_1, module)?)?;
    module.add_function(wrap_pyfunction!(rust_function_2, module)?)?;
    Ok(())
}

错误信息是:

代码语言:javascript
复制
error[E0599]: the method `into_py` exists for fn item `fn() -> Result<(), PyErr> {rust_function_2}`, but its trait bounds were not satisfied
  --> src/lib.rs:10:37
   |
10 |             .call1((rust_function_2.into_py(py),))?;
   |                                     ^^^^^^^ method cannot be called on `fn() -> Result<(), PyErr> {rust_function_2}` due to unsatisfied trait bounds
   |
   = note: `rust_function_2` is a function, perhaps you wish to call it
   = note: the following trait bounds were not satisfied:
           `fn() -> Result<(), PyErr> {rust_function_2}: AsPyPointer`
           which is required by `&fn() -> Result<(), PyErr> {rust_function_2}: pyo3::IntoPy<Py<PyAny>>`
EN

回答 1

Stack Overflow用户

回答已采纳

发布于 2022-03-05 15:08:52

PitaJ的评论让我找到了解决方案。

起作用的锈蚀代码:

代码语言:javascript
复制
use pyo3::prelude::*;

#[pyclass]
struct Callback {
    #[allow(dead_code)] // callback_function is called from Python
    callback_function: fn() -> PyResult<()>,
}

#[pymethods]
impl Callback {
    fn __call__(&self) -> PyResult<()> {
        (self.callback_function)()
    }
}

#[pyfunction]
fn rust_function_1() -> PyResult<()> {
    println!("This is rust_function_1");
    Python::with_gil(|py| {
        let python_module = PyModule::import(py, "python_module")?;
        let callback = Box::new(Callback {
            callback_function: rust_function_2,
        });
        python_module
            .getattr("python_function")?
            .call1((callback.into_py(py),))?;
        Ok(())
    })
}

#[pyfunction]
fn rust_function_2() -> PyResult<()> {
    println!("This is rust_function_2");
    Ok(())
}

#[pymodule]
#[pyo3(name = "rust_module")]
fn quantum_network_stack(_python: Python, module: &PyModule) -> PyResult<()> {
    module.add_function(wrap_pyfunction!(rust_function_1, module)?)?;
    module.add_function(wrap_pyfunction!(rust_function_2, module)?)?;
    module.add_class::<Callback>()?;
    Ok(())
}

工作的Python代码(与问题相同):

代码语言:javascript
复制
import rust_module

def python_function(callback):
    print("This is python_function")
    callback()

if __name__ == '__main__':
    rust_module.rust_function_1()

以下解决方案以多种方式对上述孤岛进行了改进:

  • 由Rust提供的callback被存储并在稍后调用,而不是立即被调用(对于现实的用例来说这更现实)

每次调用Rust时,它都会传入一个PythonApi对象,从而消除了每当调用Rust函数时都需要使用Rust函数执行Python import

  • 由Rust提供的回调可以是闭包,除了普通函数之外,还可以捕获变量(仅移动语义)。

更通用的锈菌代码如下:

代码语言:javascript
复制
use pyo3::prelude::*;

#[pyclass]
struct Callback {
    #[allow(dead_code)] // callback_function is called from Python
    callback_function: Box<dyn Fn(&PyAny) -> PyResult<()> + Send>,
}

#[pymethods]
impl Callback {
    fn __call__(&self, python_api: &PyAny) -> PyResult<()> {
        (self.callback_function)(python_api)
    }
}

#[pyfunction]
fn rust_register_callback(python_api: &PyAny) -> PyResult<()> {
    println!("This is rust_register_callback");
    let message: String = "a captured variable".to_string();
    Python::with_gil(|py| {
        let callback = Box::new(Callback {
            callback_function: Box::new(move |python_api| {
                rust_callback(python_api, message.clone())
            }),
        });
        python_api
            .getattr("set_callback")?
            .call1((callback.into_py(py),))?;
        Ok(())
    })
}

#[pyfunction]
fn rust_callback(python_api: &PyAny, message: String) -> PyResult<()> {
    println!("This is rust_callback");
    println!("Message = {}", message);
    python_api.getattr("some_operation")?.call0()?;
    Ok(())
}

#[pymodule]
#[pyo3(name = "rust_module")]
fn quantum_network_stack(_python: Python, module: &PyModule) -> PyResult<()> {
    module.add_function(wrap_pyfunction!(rust_register_callback, module)?)?;
    module.add_function(wrap_pyfunction!(rust_callback, module)?)?;
    module.add_class::<Callback>()?;
    Ok(())
}

更通用的Python代码如下:

代码语言:javascript
复制
import rust_module


class PythonApi:

    def __init__(self):
        self.callback = None

    def set_callback(self, callback):
        print("This is PythonApi::set_callback")
        self.callback = callback

    def call_callback(self):
        print("This is PythonApi::call_callback")
        assert self.callback is not None
        self.callback(self)

    def some_operation(self):
        print("This is PythonApi::some_operation")

def python_function(python_api, callback):
    print("This is python_function")
    python_api.callback = callback


def main():
    print("This is main")
    python_api = PythonApi()
    print("Calling rust_register_callback")
    rust_module.rust_register_callback(python_api)
    print("Returned from rust_register_callback; back in main")
    print("Calling callback")
    python_api.call_callback()


if __name__ == '__main__':
    main()

后一版本的代码的输出如下:

代码语言:javascript
复制
This is main
Calling rust_register_callback
This is rust_register_callback
This is PythonApi::set_callback
Returned from rust_register_callback; back in main
Calling callback
This is PythonApi::call_callback
This is rust_callback
Message = a captured variable
This is PythonApi::some_operation
票数 2
EN
页面原文内容由Stack Overflow提供。腾讯云小微IT领域专用引擎提供翻译支持
原文链接:

https://stackoverflow.com/questions/71357427

复制
相关文章

相似问题

领券
问题归档专栏文章快讯文章归档关键词归档开发者手册归档开发者手册 Section 归档