Matt Billman

Require overriding method to call super()

I'd like to force certain methods in child classes to call super() to invoke the method they're overriding.

@abstractmethod can require that certain methods be implemented; I'd like similar behavior here (i.e., if the overriding method doesn't call super(), refuse to run and complain to the user).
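For comparison, here is the @abstractmethod behavior the question refers to, as a minimal sketch using the standard abc module (Base, Missing and Present are illustrative names). abc performs its check when a class is instantiated, not when it is defined:

```python
from abc import ABC, abstractmethod


class Base(ABC):
    @abstractmethod
    def i_do_things(self):
        """Subclasses must override this."""


class Missing(Base):
    pass  # does not implement i_do_things


class Present(Base):
    def i_do_things(self):
        print('implemented')


Present().i_do_things()  # fine: the abstract method is overridden

try:
    Missing()  # the check happens here, at instantiation time
except TypeError as exc:
    # The exact wording of the message differs between Python versions.
    print(f'TypeError: {exc}')
```

Note that abc only checks that the method exists; it has no way to check what the override does, which is why the super() requirement needs a different mechanism.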

Example:

class Foo:
    @must_call_super
    def i_do_things(self):
        print('called')

class Good(Foo):
    def i_do_things(self):
        # super().i_do_things() is called; will run.
        super().i_do_things()
        print('called as well')

class Bad(Foo):
    def i_do_things(self):
        # should complain that super().i_do_things isn't called here
        print('called as well')

# should work fine
good = Good()

# should error
bad = Bad()

Answers (1)

Fanchen Bao

Thanks for sending me down the rabbit hole.

Below is my solution to this problem. It uses a metaclass, the ast module, and some hacking to detect whether a child class calls super().some_func() in its own version of some_func.
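The approach rests on two properties of metaclasses, shown here in a minimal sketch (TracingMeta, Parent and Child are illustrative names, not part of the solution): a metaclass's __new__ runs when the class statement executes, before any instance exists, and a subclass inherits its parent's metaclass, so every descendant of Foo passes through the same check.

```python
class TracingMeta(type):
    # Records the name of every class built with this metaclass.
    created = []

    def __new__(mcs, name, bases, dct):
        # Runs once per class statement, at definition time.
        TracingMeta.created.append(name)
        return super().__new__(mcs, name, bases, dct)


class Parent(metaclass=TracingMeta):
    pass


class Child(Parent):  # no metaclass= here; it is inherited from Parent
    pass


print(TracingMeta.created)         # ['Parent', 'Child']
print(type(Child) is TracingMeta)  # True
```

Neither class was instantiated, yet both names were recorded.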

Core classes

These are controlled by the developer of the base class.

import inspect
import ast
import textwrap


class Analyzer(ast.NodeVisitor):
    def __init__(self, ast_sig: str):
        self.func_exists = False
        self.sig = ast_sig

    def visit_Call(self, node):
        """Traverse the ast tree. Once a node's signature matches the given
        method call's signature, we consider that the method call exists.
        """
        # print(ast.dump(node))
        if ast.dump(node) == self.sig:
            self.func_exists = True
        self.generic_visit(node)


class FooMeta(type):
    # _ast_sig_super_methods maps each method whose overridden version in a
    # child class must contain a `super().method()` call to the ast signature
    # of that call. To enforce more methods, add their names to the tuple.
    # The signatures are generated here with ast.dump() rather than hardcoded,
    # so they always match the running interpreter's dump format.
    _ast_sig_super_methods = {
        method: ast.dump(ast.parse(f'super().{method}()', mode='eval').body)
        for method in ('i_do_things',)
    }

    def __new__(cls, name, bases, dct):
        # cls = FooMeta
        # name = current class name
        # bases = any parents of the current class
        # dct = namespace dict of the current class
        for method, ast_sig in FooMeta._ast_sig_super_methods.items():
            if name != 'Foo' and method in dct:  # desired method in subclass
                source = inspect.getsource(dct[method])  # get source code
                formatted_source = textwrap.dedent(source)  # correct indentation
                tree = ast.parse(formatted_source)  # obtain ast tree
                analyzer = Analyzer(ast_sig)
                analyzer.visit(tree)
                if not analyzer.func_exists:
                    raise RuntimeError(f'super().{method} is not called in {name}.{method}!')
        return super().__new__(cls, name, bases, dct)


class Foo(metaclass=FooMeta):
    def i_do_things(self):
        print('called')
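Analyzer matches calls by comparing ast.dump() strings, which ties it to the dump format and to the literal form super().i_do_things(). A sketch of an alternative, assuming one prefers to match on node types instead: the hypothetical helper calls_super_method below checks the AST structure directly. It takes a source string (like the one inspect.getsource returns in FooMeta.__new__) so it can be tried on its own:

```python
import ast
import textwrap


def calls_super_method(source: str, method: str) -> bool:
    """Return True if `source` contains a call of the form super(...).<method>(...).

    Hypothetical helper: it inspects node types and names directly instead of
    comparing ast.dump() strings.
    """
    tree = ast.parse(textwrap.dedent(source))  # dedent: methods are indented
    for node in ast.walk(tree):
        if (
            isinstance(node, ast.Call)                  # outer call: (...)
            and isinstance(node.func, ast.Attribute)    # .<method>
            and node.func.attr == method
            and isinstance(node.func.value, ast.Call)   # super(...)
            and isinstance(node.func.value.func, ast.Name)
            and node.func.value.func.id == 'super'
        ):
            return True
    return False


good_src = '''
    def i_do_things(self):
        super().i_do_things()
        print('called as well')
'''
bad_src = '''
    def i_do_things(self):
        print('called as well')
'''
print(calls_super_method(good_src, 'i_do_things'))  # True
print(calls_super_method(bad_src, 'i_do_things'))   # False
```

Because the arguments of the inner super(...) call are not inspected, this version also accepts the explicit form super(Good, self).i_do_things(). Inside FooMeta.__new__, calls_super_method(inspect.getsource(dct[method]), method) could take the place of the Analyzer call.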

Usage and Effect

This part is written by other people, on whom we want to impose the rule that super().i_do_things() must be called in any overriding i_do_things in their subclasses.

Good

class Good(Foo):
    def i_do_things(self):
        # super().i_do_things() is called; will run.
        super().i_do_things()
        print('called as well')

good = Good()
good.i_do_things()

# output:
# called
# called as well

Bad

class Bad(Foo):
    def i_do_things(self):
        # should complain that super().i_do_things isn't called here
        print('called as well')

# Error output:
# RuntimeError: super().i_do_things is not called in Bad.i_do_things!

Secretly Bad

class Good(Foo):
    def i_do_things(self):
        # super().i_do_things() is called; will run.
        super().i_do_things()
        print('called as well')


class SecretlyBad(Good):
    def i_do_things(self):
        # also shall complain super().i_do_things isn't called
        print('called as well')


# Error output:
# RuntimeError: super().i_do_things is not called in SecretlyBad.i_do_things!

Note

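  1. Because FooMeta runs when a subclass is defined, not when it is instantiated, the error is raised as soon as the Bad or SecretlyBad class statement executes, before any instance is created or i_do_things() is called. This is not exactly what the OP asked for (an error on instantiation), but it achieves the same end goal.
  2. To inspect the ast signature of super().i_do_things() (for instance, when adding another method to _ast_sig_super_methods), uncomment the print statement in Analyzer, define a class such as Good, and look for the dumped Call node. The dump format has changed between Python versions (Python 3.13, for example, omits empty fields such as args=[] by default), so a signature copied from one interpreter may not match another.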
