Mahi

Reputation: 21932

How to type-annotate overridden methods in a subclass?

Say I already have a method with type annotations:

class Shape:
    def area(self) -> float:
        raise NotImplementedError

Which I will then subclass multiple times:

import math

class Circle(Shape):
    def area(self) -> float:
        return math.pi * self.radius ** 2

class Rectangle(Shape):
    def area(self) -> float:
        return self.height * self.width

As you can see, I'm duplicating the -> float quite a lot. Say I have 10 different shapes, with multiple methods like this, some of which contain parameters too. Is there a way to just "copy" the annotation from the parent class, similar to what functools.wraps() does with docstrings?
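For reference, overrides do not inherit annotations: each function object carries its own `__annotations__` dict, so an override written without annotations reports an empty dict even though the parent's method is annotated. Here is a minimal check. The `__init__` on `Circle` is added only so that the snippet is self-contained.

```python
import math


class Shape:
    def area(self) -> float:
        raise NotImplementedError


class Circle(Shape):
    def __init__(self, radius):
        self.radius = radius

    # Override written without a return annotation.
    def area(self):
        return math.pi * self.radius ** 2


# The parent's annotation lives on Shape.area only.
print(Shape.area.__annotations__)   # {'return': <class 'float'>}
print(Circle.area.__annotations__)  # {}
```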

Upvotes: 6

Views: 1179

Answers (2)

Sede

Reputation: 61253

You can use a class decorator to update the annotations of your subclass methods. In the decorator, walk through the class definition and update only those methods that are also present in the superclass. To access the superclass, use the class's __mro__ (method resolution order). This is a tuple that starts with the class itself and lists its bases in lookup order, ending with object. Here we are interested in the second element, at index 1, so cls.__mro__[1] or cls.mro()[1]. Under single inheritance, that is the direct parent. Last but not least, the decorator must return the class.
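The following sketch shows which class `cls.mro()[1]` actually picks. The classes `Named` and `NamedCircle` and the helper `find_base_attr` are hypothetical. Under single inheritance, index 1 is the direct parent. With multiple inheritance, it is simply the next class in the MRO, and that class does not necessarily define the method. A more robust lookup therefore walks the rest of the MRO:

```python
class Shape:
    def area(self) -> float:
        raise NotImplementedError


class Named:  # hypothetical mixin that does not define area()
    name = "unnamed"


class Circle(Shape):
    pass


class NamedCircle(Named, Shape):
    pass


# Single inheritance: index 1 is the parent.
assert Circle.__mro__ == (Circle, Shape, object)
# Multiple inheritance: index 1 is the mixin, which has no area().
assert NamedCircle.mro()[1] is Named


def find_base_attr(cls, name):
    """Return the nearest definition of `name` among the bases of `cls`, or None."""
    for base in cls.__mro__[1:]:
        if name in vars(base):
            return vars(base)[name]
    return None


assert find_base_attr(NamedCircle, "area") is Shape.area
```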

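def wraps_annotations(cls):
    base = cls.mro()[1]  # the direct superclass under single inheritance
    base_vars = vars(base)
    for name, value in vars(cls).items():
        # Only update callables that the superclass also defines.
        if callable(value) and name in base_vars:
            base_annotations = getattr(base_vars[name], '__annotations__', {})
            value.__annotations__.update(base_annotations)
    return cls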

Demo:

>>> class Shape:
...     def area(self) -> float:
...         raise NotImplementedError
...
>>> import math
>>>
>>> @wraps_annotations
... class Circle(Shape):
...     def area(self):
...         return math.pi * self.radius ** 2
...
>>> c = Circle()
>>> c.area.__annotations__
{'return': <class 'float'>}
>>> @wraps_annotations
... class Rectangle(Shape):
...     def area(self):
...         return self.height * self.width
...
>>> r = Rectangle()
>>> r.area.__annotations__
{'return': <class 'float'>}
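Because `dict.update` lets the parent's annotations win, the decorator above overwrites any annotation that a subclass method declares itself (say `-> int`) with the parent's. The variant sketched below only fills in names that the override leaves unannotated, by using `setdefault`. It also looks the method up through the whole MRO instead of only index 1. The name `fill_annotations` is made up for this sketch:

```python
import inspect


def fill_annotations(cls):
    """Copy superclass annotations onto overriding methods without
    replacing annotations the subclass method already declares."""
    for name, value in vars(cls).items():
        if not inspect.isfunction(value):
            continue  # skip plain attributes, properties, static/class methods
        for base in cls.__mro__[1:]:
            if name in vars(base):
                base_value = vars(base)[name]
                if inspect.isfunction(base_value):
                    for key, hint in base_value.__annotations__.items():
                        value.__annotations__.setdefault(key, hint)
                break  # only the nearest definition counts
    return cls


class Shape:
    def area(self) -> float:
        raise NotImplementedError

    def scale(self, factor: float) -> "Shape":
        raise NotImplementedError


@fill_annotations
class Square(Shape):
    def __init__(self, side):
        self.side = side

    def area(self):
        return self.side ** 2

    def scale(self, factor) -> "Square":  # keeps its own return annotation
        return Square(self.side * factor)
```

In this sketch, `Square.area` gains `{'return': float}`. `Square.scale` gains the `factor: float` annotation but keeps its own `"Square"` return annotation.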

Upvotes: 0

Ilja Everilä

Reputation: 52949

This might work, though I'm sure I'm missing edge cases, such as additional arguments:

from functools import partial, update_wrapper


def annotate_from(f):
    return partial(update_wrapper,
                   wrapped=f,
                   assigned=('__annotations__',),
                   updated=())

This assigns f.__annotations__ to the "wrapper" function's __annotations__ attribute (keep in mind that both functions then share the same dict; it is not a copy).
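Because `update_wrapper` assigns the attribute rather than copying it, mutating one function's annotations also mutates the other's. The short check below calls `functools.update_wrapper` directly, with the same arguments that `annotate_from` binds. `Shape` and `Circle` are redefined so that the snippet runs on its own:

```python
from functools import update_wrapper


class Shape:
    def area(self) -> float:
        raise NotImplementedError


class Circle(Shape):
    def area(self):
        return 3.14159 * self.radius ** 2


# The same call that annotate_from(Shape.area) performs on the decorated function.
update_wrapper(Circle.area, Shape.area, assigned=("__annotations__",), updated=())

print(Circle.area.__annotations__ is Shape.area.__annotations__)  # True

# Narrowing the return type on the subclass therefore changes the base too.
Circle.area.__annotations__["return"] = int
print(Shape.area.__annotations__["return"])  # <class 'int'>
```

If the two functions need to stay independent, assigning a copy such as `dict(f.__annotations__)` avoids the sharing.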

According to the documentation, the default value of update_wrapper's assigned argument already includes __annotations__. However, I can see why you would not want all the other attributes (such as __name__, __qualname__ and __doc__) assigned from wrapped as well.
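To see what the defaults would bring along, here is a sketch that applies plain `functools.wraps`, which uses the default `assigned` tuple, to an override. The docstring on `Shape.area` and the `Square` class are added for illustration:

```python
import functools


class Shape:
    def area(self) -> float:
        """Return the area of the shape."""
        raise NotImplementedError


class Square(Shape):
    def __init__(self, side):
        self.side = side

    # The default wraps() copies more than just annotations.
    @functools.wraps(Shape.area)
    def area(self):
        return self.side ** 2


print(Square.area.__qualname__)  # Shape.area
print(Square.area.__doc__)       # Return the area of the shape.
```

The override now claims to be `Shape.area` in its repr, which is why `annotate_from` restricts `assigned` to `__annotations__` only.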

With this, you can then define your Circle and Rectangle as:

class Circle(Shape):
    @annotate_from(Shape.area)
    def area(self):
        return math.pi * self.radius ** 2

class Rectangle(Shape):
    @annotate_from(Shape.area)
    def area(self):
        return self.height * self.width

and the result

In [82]: Circle.area.__annotations__
Out[82]: {'return': builtins.float}

In [86]: Rectangle.area.__annotations__
Out[86]: {'return': builtins.float}

As a side effect, your methods will have a __wrapped__ attribute that points to Shape.area in this case.
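The `__wrapped__` attribute has a visible consequence. By default, `inspect.signature` follows `__wrapped__`, so an override that adds a parameter reports the base method's signature instead of its own. The sketch below uses a hypothetical `area(self, scale=1)` override and again calls `update_wrapper` directly, with the arguments that `annotate_from` uses:

```python
import inspect
from functools import update_wrapper


class Shape:
    def area(self) -> float:
        raise NotImplementedError


class Square(Shape):
    def __init__(self, side):
        self.side = side

    # Hypothetical override with an extra parameter.
    def area(self, scale=1):
        return (self.side * scale) ** 2


update_wrapper(Square.area, Shape.area, assigned=("__annotations__",), updated=())

print(Square.area.__wrapped__ is Shape.area)  # True
# By default, inspect unwraps to Shape.area, so `scale` disappears.
print(inspect.signature(Square.area))                        # (self) -> float
# follow_wrapped=False inspects the override itself.
print(inspect.signature(Square.area, follow_wrapped=False))  # (self, scale=1) -> float
```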


A less standard way to handle overridden methods (if you can even call the above use of update_wrapper standard) is a class decorator:

from inspect import getmembers, isfunction


def override(f):
    """
    Mark method overrides.
    """
    f.__override__ = True
    return f


def _is_method_override(m):
    return isfunction(m) and getattr(m, '__override__', False)


def annotate_overrides(cls):
    """
    Copy annotations of overridden methods.
    """
    bases = cls.mro()[1:]
    for name, method in getmembers(cls, _is_method_override):
        # Find the first class after cls in the MRO that has this attribute.
        for base in bases:
            if hasattr(base, name):
                break
        else:
            raise RuntimeError(
                'method {!r} not found in bases of {!r}'.format(name, cls))

        base_method = getattr(base, name)
        method.__annotations__ = base_method.__annotations__.copy()

    return cls

and then:

@annotate_overrides
class Rectangle(Shape):
    @override
    def area(self):
        return self.height * self.width

Again, this will not handle overriding methods with additional arguments.
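One way to soften that limitation is sketched below with a made-up `merge_annotations` helper, which is not part of either answer. It copies only those base annotations whose names also appear in the override's signature (plus `return`), then lets the override's own annotations take precedence. As a result, extra parameters keep whatever they declare, and annotations for parameters the override dropped are not carried over:

```python
import inspect


def merge_annotations(base_func, override_func):
    """Hypothetical helper: merge base_func's annotations into override_func.

    Base annotations are kept only for parameters the override also has
    (and for 'return'); anything the override annotates itself wins.
    """
    # Inspect the override itself, even if it carries a __wrapped__ attribute.
    params = inspect.signature(override_func, follow_wrapped=False).parameters
    merged = {
        key: hint
        for key, hint in base_func.__annotations__.items()
        if key == "return" or key in params
    }
    merged.update(override_func.__annotations__)
    override_func.__annotations__ = merged
    return override_func


class Shape:
    def scale(self, factor: float, *, inplace: bool = False) -> "Shape":
        raise NotImplementedError


class Circle(Shape):
    # Drops `inplace` and adds an annotated `center` parameter.
    def scale(self, factor, center: tuple = (0.0, 0.0)):
        raise NotImplementedError


merge_annotations(Shape.scale, Circle.scale)
# `factor` and `return` now come from Shape, `center` from Circle itself.
print(Circle.scale.__annotations__)
```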

Upvotes: 4
