I am using PyO3 to call Rust functions from Python and vice versa.
I am trying to achieve the following:

1. Python calls the Rust function `rust_function_1`.
2. `rust_function_1` calls the Python function `python_function`, passing the Rust function `rust_function_2` as a callback argument.
3. `python_function` calls the callback, which in this case is `rust_function_2`.

I cannot figure out how to pass `rust_function_2` as a callback argument to `python_function`.
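For comparison, if all three functions were written in Python the flow would be trivial, because Python functions are first-class objects that can be passed around directly. The sketch below replaces the two Rust functions with hypothetical Python stand-ins (the `calls` list is also only for illustration). The difficulty in the Rust version is that a Rust function is not automatically a Python object.

```python
calls = []  # records the order in which the functions run


def python_function(callback):
    calls.append("python_function")
    callback()  # works for any callable object


# Hypothetical Python stand-ins for the two Rust functions.
def rust_function_2():
    calls.append("rust_function_2")


def rust_function_1():
    calls.append("rust_function_1")
    # The function object itself is passed, not the result of calling it.
    python_function(rust_function_2)


rust_function_1()
print(calls)  # ['rust_function_1', 'python_function', 'rust_function_2']
```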
I have the following Python code:
```python
import rust_module


def python_function(callback):
    print("This is python_function")
    callback()


if __name__ == '__main__':
    rust_module.rust_function_1()
```
And I have the following non-compiling Rust code:
```rust
use pyo3::prelude::*;

#[pyfunction]
fn rust_function_1() -> PyResult<()> {
    println!("This is rust_function_1");
    Python::with_gil(|py| {
        let python_module = PyModule::import(py, "python_module")?;
        python_module
            .getattr("python_function")?
            .call1((rust_function_2.into_py(py),))?; // Compile error
        Ok(())
    })
}

#[pyfunction]
fn rust_function_2() -> PyResult<()> {
    println!("This is rust_function_2");
    Ok(())
}

#[pymodule]
#[pyo3(name = "rust_module")]
fn quantum_network_stack(_python: Python, module: &PyModule) -> PyResult<()> {
    module.add_function(wrap_pyfunction!(rust_function_1, module)?)?;
    module.add_function(wrap_pyfunction!(rust_function_2, module)?)?;
    Ok(())
}
```
The error message is:
```
error[E0599]: the method `into_py` exists for fn item `fn() -> Result<(), PyErr> {rust_function_2}`, but its trait bounds were not satisfied
  --> src/lib.rs:10:37
   |
10 |             .call1((rust_function_2.into_py(py),))?;
   |                                     ^^^^^^^ method cannot be called on `fn() -> Result<(), PyErr> {rust_function_2}` due to unsatisfied trait bounds
   |
   = note: `rust_function_2` is a function, perhaps you wish to call it
   = note: the following trait bounds were not satisfied:
           `fn() -> Result<(), PyErr> {rust_function_2}: AsPyPointer`
           which is required by `&fn() -> Result<(), PyErr> {rust_function_2}: pyo3::IntoPy<Py<PyAny>>`
```
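The comment from PitaJ led me to the solution: wrap the Rust function in a `#[pyclass]` that implements `__call__`, so that Python receives an ordinary callable object.
Rust code that works: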
```rust
use pyo3::prelude::*;

// Wrapper that makes a Rust function pointer callable from Python.
#[pyclass]
struct Callback {
    // Invoked from Python through `__call__`.
    callback_function: fn() -> PyResult<()>,
}

#[pymethods]
impl Callback {
    fn __call__(&self) -> PyResult<()> {
        (self.callback_function)()
    }
}

#[pyfunction]
fn rust_function_1() -> PyResult<()> {
    println!("This is rust_function_1");
    Python::with_gil(|py| {
        let python_module = PyModule::import(py, "python_module")?;
        // The fn item `rust_function_2` coerces to the field's fn pointer type.
        let callback = Callback {
            callback_function: rust_function_2,
        };
        python_module
            .getattr("python_function")?
            .call1((callback.into_py(py),))?;
        Ok(())
    })
}

#[pyfunction]
fn rust_function_2() -> PyResult<()> {
    println!("This is rust_function_2");
    Ok(())
}

#[pymodule]
#[pyo3(name = "rust_module")]
fn quantum_network_stack(_python: Python, module: &PyModule) -> PyResult<()> {
    module.add_function(wrap_pyfunction!(rust_function_1, module)?)?;
    module.add_function(wrap_pyfunction!(rust_function_2, module)?)?;
    module.add_class::<Callback>()?;
    Ok(())
}
```
Python code that works (same as in the question):
```python
import rust_module


def python_function(callback):
    print("This is python_function")
    callback()


if __name__ == '__main__':
    rust_module.rust_function_1()
```
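Why this works: Python's call syntax accepts any object whose type defines `__call__`, and the `#[pymethods]` block gives the `Callback` class exactly that method. A rough pure-Python analogue of the wrapper follows. It is an illustration only, not what PyO3 generates, and the `log` list and the Python version of `rust_function_2` are hypothetical.

```python
log = []  # records which functions ran, for illustration


class Callback:
    """Pure-Python analogue of the #[pyclass] Callback: it stores a plain
    function and forwards calls to it through __call__."""

    def __init__(self, callback_function):
        self.callback_function = callback_function

    def __call__(self):
        return self.callback_function()


def rust_function_2():  # hypothetical stand-in for the Rust function
    log.append("rust_function_2")


def python_function(callback):
    log.append("python_function")
    callback()  # calls Callback.__call__, which calls rust_function_2


wrapper = Callback(rust_function_2)
python_function(wrapper)
print(callable(wrapper), log)  # True ['python_function', 'rust_function_2']
```

`python_function` never needs to know that it received a wrapper rather than a function; the same holds for the PyO3 class on the Python side.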
The following solution improves on the above solution in a number of ways:

- The callback provided by Rust is stored and called later instead of being called immediately, which is closer to real-life use cases.
- Each time Python calls Rust, it passes in a `PythonApi` object. This removes the need for the Rust functions to do a Python import every time they are called.
- The callback provided by Rust can be a closure that captures variables (move semantics only), in addition to a plain function.
The more general Rust code is as follows:
```rust
use pyo3::prelude::*;

#[pyclass]
struct Callback {
    // Boxed so that closures, not just plain functions, can be stored.
    // `Send` is needed because PyO3 requires a #[pyclass] to be Send
    // unless it is marked `unsendable`.
    callback_function: Box<dyn Fn(&PyAny) -> PyResult<()> + Send>,
}

#[pymethods]
impl Callback {
    fn __call__(&self, python_api: &PyAny) -> PyResult<()> {
        (self.callback_function)(python_api)
    }
}

#[pyfunction]
fn rust_register_callback(python_api: &PyAny) -> PyResult<()> {
    println!("This is rust_register_callback");
    let message: String = "a captured variable".to_string();
    Python::with_gil(|py| {
        let callback = Callback {
            // `move` transfers ownership of `message` into the closure; it is
            // cloned on each call because an `Fn` closure may run many times.
            callback_function: Box::new(move |python_api| {
                rust_callback(python_api, message.clone())
            }),
        };
        python_api
            .getattr("set_callback")?
            .call1((callback.into_py(py),))?;
        Ok(())
    })
}

#[pyfunction]
fn rust_callback(python_api: &PyAny, message: String) -> PyResult<()> {
    println!("This is rust_callback");
    println!("Message = {}", message);
    python_api.getattr("some_operation")?.call0()?;
    Ok(())
}

#[pymodule]
#[pyo3(name = "rust_module")]
fn quantum_network_stack(_python: Python, module: &PyModule) -> PyResult<()> {
    module.add_function(wrap_pyfunction!(rust_register_callback, module)?)?;
    module.add_function(wrap_pyfunction!(rust_callback, module)?)?;
    module.add_class::<Callback>()?;
    Ok(())
}
```
The more general Python code is as follows:
```python
import rust_module


class PythonApi:
    def __init__(self):
        self.callback = None

    def set_callback(self, callback):
        print("This is PythonApi::set_callback")
        self.callback = callback

    def call_callback(self):
        print("This is PythonApi::call_callback")
        assert self.callback is not None
        self.callback(self)

    def some_operation(self):
        print("This is PythonApi::some_operation")


def python_function(python_api, callback):
    print("This is python_function")
    python_api.callback = callback


def main():
    print("This is main")
    python_api = PythonApi()
    print("Calling rust_register_callback")
    rust_module.rust_register_callback(python_api)
    print("Returned from rust_register_callback; back in main")
    print("Calling callback")
    python_api.call_callback()


if __name__ == '__main__':
    main()
```
The output of the more general version is as follows:
```
This is main
Calling rust_register_callback
This is rust_register_callback
This is PythonApi::set_callback
Returned from rust_register_callback; back in main
Calling callback
This is PythonApi::call_callback
This is rust_callback
Message = a captured variable
This is PythonApi::some_operation
```
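To exercise the Python side without building the extension, one option is a hypothetical pure-Python stand-in for `rust_module` that mirrors the two Rust functions. This is only a sketch: the stub functions, the `types.SimpleNamespace` module object and the `run_captured` helper are assumptions for testing, not part of PyO3. The inner `callback` closure plays the role of the Rust `move` closure by keeping `message` alive after `rust_register_callback` returns, so `main()` prints the same lines as the output above.

```python
import contextlib
import io
import types


# Hypothetical pure-Python stand-ins for the compiled Rust functions.
def rust_callback(python_api, message):
    print("This is rust_callback")
    print("Message = {}".format(message))
    python_api.some_operation()


def rust_register_callback(python_api):
    print("This is rust_register_callback")
    message = "a captured variable"

    # Plays the role of the Rust `move` closure: it keeps `message`
    # alive after rust_register_callback returns.
    def callback(api):
        return rust_callback(api, message)

    python_api.set_callback(callback)


# Hypothetical stub standing in for the compiled extension module.
rust_module = types.SimpleNamespace(
    rust_register_callback=rust_register_callback,
    rust_callback=rust_callback,
)


class PythonApi:
    def __init__(self):
        self.callback = None

    def set_callback(self, callback):
        print("This is PythonApi::set_callback")
        self.callback = callback

    def call_callback(self):
        print("This is PythonApi::call_callback")
        assert self.callback is not None
        self.callback(self)

    def some_operation(self):
        print("This is PythonApi::some_operation")


def main():
    print("This is main")
    python_api = PythonApi()
    print("Calling rust_register_callback")
    rust_module.rust_register_callback(python_api)
    print("Returned from rust_register_callback; back in main")
    print("Calling callback")
    python_api.call_callback()


def run_captured():
    """Run main() and return everything it printed, for comparison."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        main()
    return buffer.getvalue()


if __name__ == "__main__":
    main()
```

Swapping the stub for the real `import rust_module` should leave `PythonApi` and `main()` unchanged, which is one way to check the Python half independently of the Rust build.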