jcai

Reputation: 3593

Downcast to pybind11 derived class

I'm using the "Overriding virtual functions in Python" feature of pybind11 to create Python classes that inherit from C++ abstract classes. I have a C++ class State which is subclassed in Python as MyState. In this situation I have some MyState object that lost its type information and Python thinks it's a State. I need to downcast it back to MyState in Python code and I don't know a good way to do this.

Here's the C++ example code:

#include <memory>

#include <pybind11/pybind11.h>

namespace py = pybind11;

// ========== State ==========

class State {
 public:
  virtual ~State() = default;

  virtual void dump() = 0;
};

using StatePtr = std::shared_ptr<State>;

class PyState : public State {
 public:
  using State::State;

  void dump() override {
    PYBIND11_OVERLOAD_PURE(void, State, dump);
  }
};

// ========== Machine ==========

class Machine {
 public:
  virtual ~Machine() = default;

  virtual StatePtr begin() = 0;

  virtual StatePtr step(const StatePtr&) = 0;
};

using MachinePtr = std::shared_ptr<Machine>;

class PyMachine : public Machine {
 public:
  using Machine::Machine;

  StatePtr begin() override {
    PYBIND11_OVERLOAD_PURE(StatePtr, Machine, begin);
  }

  StatePtr step(const StatePtr& state) override {
    PYBIND11_OVERLOAD_PURE(StatePtr, Machine, step, state);
  }
};

// ========== run ==========

void run(const MachinePtr& machine) {
  StatePtr state = machine->begin();
  for (int i = 0; i < 5; ++i) {
    state = machine->step(state);
    state->dump();
  }
}

// ========== pybind11 ==========

PYBIND11_MODULE(example, m) {
  py::class_<State, StatePtr, PyState>(m, "State").def(py::init<>());
  py::class_<Machine, MachinePtr, PyMachine>(m, "Machine")
      .def(py::init<>())
      .def("begin", &Machine::begin)
      .def("step", &Machine::step);
  m.def("run", &run, "Run the machine");
}

And the Python code:

#!/usr/bin/env python3

from example import Machine, State, run


class MyState(State):
    def __init__(self, x):
        State.__init__(self)
        self.x = x

    def dump(self):
        print(self.x)


class MyMachine(Machine):
    def __init__(self):
        Machine.__init__(self)

    def begin(self):
        return MyState(0)

    def step(self, state):
        # problem: when called from C++, `state` is an `example.State`
        # instead of `MyState`. In order to access `state.x` we need
        # some way to downcast it...
        return MyState(state.x + 1)


machine = MyMachine()

print("running machine with python")
state = machine.begin()
for _ in range(5):
    state = machine.step(state)
    state.dump()

print("running machine with C++")
run(machine)  # error

Output and error message:

running machine with python
1
2
3
4
5
running machine with C++
Traceback (most recent call last):
  File "<string>", line 38, in <module>
  File "<string>", line 36, in __run
  File "/usr/local/fbcode/platform007/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/local/fbcode/platform007/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/data/users/jcai/fbsource/fbcode/buck-out/dev/gen/experimental/jcai/pybind/run_example#link-tree/run_example.py", line 38, in <module>
    run(machine)  # error
  File "/data/users/jcai/fbsource/fbcode/buck-out/dev/gen/experimental/jcai/pybind/run_example#link-tree/run_example.py", line 26, in step
    return MyState(state.x + 1)
AttributeError: 'example.State' object has no attribute 'x'

I do have a hacky workaround: I keep a "downcast map" (std::unordered_map<State*, py::object>) and register every MyState I create in it. But I'd prefer not to resort to something like that.

Upvotes: 2

Views: 1491

Answers (1)

eacousineau

Reputation: 3533

I think that you're probably suffering from this suite of problems:

https://github.com/pybind/pybind11/issues/1774

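Ultimately, because you return the new MyState directly and it goes straight to C++, nothing on the Python side holds a reference to it anymore. The shared_ptr on the C++ side keeps the C++ part of the object alive, but the Python interpreter loses track of your instance and garbage-collects the Python portion of the object (including attributes such as x). When C++ later passes that pointer back into step, pybind11 can only wrap it as a plain example.State, which is why your object ends up effectively sliced.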
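To make the lifetime issue concrete, here is a small pure-Python simulation. It is not pybind11 itself, only a model of it under stated assumptions: the hypothetical `_registry`, `to_cpp` and `from_cpp` stand in for pybind11's internal pointer-to-wrapper lookup, which (roughly speaking) does not own the Python wrapper, so a `weakref` models it. The C++ object itself is not modeled.

```python
import gc
import weakref


class State:
    """Stand-in for the bound C++ base class (example.State)."""


class MyState(State):
    def __init__(self, x):
        self.x = x


# Hypothetical stand-in for pybind11's pointer -> wrapper lookup table.
# It does not keep the Python wrapper alive, so it stores weak references.
_registry = {}


def to_cpp(obj):
    """Hypothetical: hand `obj` to "C++", which keeps only a handle."""
    handle = id(obj)
    _registry[handle] = weakref.ref(obj)
    return handle


def from_cpp(handle):
    """Hypothetical: convert the "C++" handle back into a Python object."""
    ref = _registry.get(handle)
    obj = ref() if ref is not None else None
    if obj is None:
        # The Python wrapper is gone, so a fresh base-class wrapper is made;
        # Python-side attributes such as `x` are lost.
        return State()
    return obj


handle = to_cpp(MyState(0))  # nothing on the Python side holds the MyState
gc.collect()
state = from_cpp(handle)
print(type(state).__name__, hasattr(state, "x"))  # State False
```

If anything on the Python side still holds the MyState, `from_cpp` returns that same object with `x` intact.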

Potential solutions:

  • Stash a reference to the MyState you return, at least long enough for the Python interpreter to get a reference to it again.
    • e.g. change return MyState(...) to self._stashed_state = MyState(...); return self._stashed_state
  • See if you can somehow incref the Python instance of your class from C++ (yuck, but it'll work)
  • Review the workarounds listed in the aforementioned issues (can't remember all of 'em)
  • Use our fork of pybind11, which handles this, but also drags in other stuff: overview for RobotLocomotion/pybind11
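The stashing workaround can be sketched with the same kind of pure-Python simulation. Apart from `MyState`, `MyMachine` and `_stashed_state`, every name here is a hypothetical stand-in: `to_cpp`, `from_cpp` and the weak-reference `_registry` model pybind11's non-owning lookup, and `run` mimics the C++ `run` loop by keeping only a handle between calls. The strong reference in `self._stashed_state` is what keeps the latest state alive.

```python
import gc
import weakref


class State:
    """Stand-in for the bound C++ base class (example.State)."""


class MyState(State):
    def __init__(self, x):
        self.x = x


# Hypothetical stand-in for pybind11's non-owning pointer -> wrapper table.
_registry = {}


def to_cpp(obj):
    """Hypothetical: hand `obj` to "C++", which keeps only a handle."""
    handle = id(obj)
    _registry[handle] = weakref.ref(obj)
    return handle


def from_cpp(handle):
    """Hypothetical: wrap the handle again; falls back to a bare State."""
    ref = _registry.get(handle)
    obj = ref() if ref is not None else None
    return obj if obj is not None else State()


class MyMachine:
    def begin(self):
        # Stash the state so Python keeps it alive while "C++" holds it.
        self._stashed_state = MyState(0)
        return self._stashed_state

    def step(self, state):
        self._stashed_state = MyState(state.x + 1)
        return self._stashed_state


def run(machine, steps=5):
    """Hypothetical mimic of the C++ run(): only the handle survives."""
    handle = to_cpp(machine.begin())
    values = []
    for _ in range(steps):
        gc.collect()
        handle = to_cpp(machine.step(from_cpp(handle)))
        values.append(from_cpp(handle).x)
    return values


print(run(MyMachine()))  # [1, 2, 3, 4, 5]
```

Because `run` only ever calls `step` with the most recent state, stashing just the latest one is enough here; if C++ held on to several states at once, each of them would need its own strong reference on the Python side.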

You may also want to post on one of the existing issues mentioning that you encountered this problem as well (just so that it can be tracked).

Upvotes: 1
