Azat Ibrakov

Reputation: 10990

Instantiate a child in __new__ with different __new__ signature for a child

Preface

I want to have 2 classes Interval and Segment with the following properties:

  1. Interval has start and end points, each of which can be included or excluded (I've implemented this with required flag parameters start_inclusive/end_inclusive).
  2. Segment is an Interval with both endpoints included, so the user does not need to specify those flags.
  3. If the user tries to create an Interval with both endpoints included, they get a Segment, like

    >>> Interval(0, 1, start_inclusive=True, end_inclusive=True)
    Segment(0, 1)
    

    (this doesn't look impossible)

Problem

My MCVE implementation so far is

Interval class:

class Interval:
    def __new__(cls, start: int, end: int,
                *,
                start_inclusive: bool,
                end_inclusive: bool) -> 'Interval':
        if cls is not __class__:
            return super().__new__(cls)
        if start == end:
            raise ValueError('Degenerate interval found.')
        if start_inclusive and end_inclusive:
            return Segment(start, end)
        return super().__new__(cls)

    def __init__(self,
                 start: int,
                 end: int,
                 *,
                 start_inclusive: bool,
                 end_inclusive: bool) -> None:
        self.start = start
        self.end = end
        self.start_inclusive = start_inclusive
        self.end_inclusive = end_inclusive

Segment class:

class Segment(Interval):
    def __new__(cls, start: int, end: int) -> 'Interval':
        return super().__new__(cls, start, end,
                               start_inclusive=True,
                               end_inclusive=True)

    def __init__(self, start: int, end: int) -> None:
        super().__init__(start, end,
                         start_inclusive=True,
                         end_inclusive=True)

Creation kinda works

>>> Interval(0, 1, start_inclusive=False, end_inclusive=True)
<__main__.Interval object at ...>
>>> Interval(0, 1, start_inclusive=False, end_inclusive=False)
<__main__.Interval object at ...>
>>> Segment(0, 1)
<__main__.Segment object at ...>

but

>>> Interval(0, 1, start_inclusive=True, end_inclusive=True)

fails with the following TypeError

Traceback (most recent call last):
  File "<input>", line 1, in <module>
TypeError: __init__() got an unexpected keyword argument 'end_inclusive'

So my question is:

Is there an idiomatic way of instantiating a child class in the parent's __new__ when some parameters of __new__ and __init__ are "bound" by the child?

Upvotes: 5

Views: 1071

Answers (2)

javidcf

Reputation: 59731

You can solve that with a metaclass to customize when __init__ is called after __new__:

class IntervalMeta(type):
    def __call__(cls, *args, **kwargs):
        obj = cls.__new__(cls, *args, **kwargs)
        # Only call __init__ if class of object is exactly this class
        if type(obj) is cls:
            cls.__init__(obj, *args, **kwargs)
        # As opposed to default behaviour:
        # if isinstance(obj, cls):
        #     type(obj).__init__(obj, *args, **kwargs)
        return obj

# Code below does not change except for metaclass
class Interval(metaclass=IntervalMeta):
    def __new__(cls, start: int, end: int,
                *,
                start_inclusive: bool,
                end_inclusive: bool) -> 'Interval':
        if cls is not __class__:
            return super().__new__(cls)
        if start == end:
            raise ValueError('Degenerate interval found.')
        if start_inclusive and end_inclusive:
            return Segment(start, end)
        return super().__new__(cls)

    def __init__(self,
                 start: int,
                 end: int,
                 *,
                 start_inclusive: bool,
                 end_inclusive: bool) -> None:
        self.start = start
        self.end = end
        self.start_inclusive = start_inclusive
        self.end_inclusive = end_inclusive

class Segment(Interval):
    def __new__(cls, start: int, end: int) -> 'Interval':
        return super().__new__(cls, start, end,
                               start_inclusive=True,
                               end_inclusive=True)

    def __init__(self, start: int, end: int) -> None:
        super().__init__(start, end,
                         start_inclusive=True,
                         end_inclusive=True)

print(Interval(0, 1, start_inclusive=True, end_inclusive=True))
# <__main__.Segment object at ...>
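As a quick check of the approach above, here is a condensed sketch (type hints dropped, and the degenerate-interval check applied to both classes for brevity, which is a simplification and not part of the answer's code). It confirms that the returned Segment is initialized exactly by Segment.__init__, and that an Interval with mixed flags is still initialized normally:

```python
class IntervalMeta(type):
    def __call__(cls, *args, **kwargs):
        obj = cls.__new__(cls, *args, **kwargs)
        # Run __init__ only when __new__ returned an object of exactly cls.
        if type(obj) is cls:
            cls.__init__(obj, *args, **kwargs)
        return obj


class Interval(metaclass=IntervalMeta):
    def __new__(cls, start, end, *, start_inclusive, end_inclusive):
        if start == end:
            raise ValueError('Degenerate interval found.')
        if cls is Interval and start_inclusive and end_inclusive:
            # Segment(...) goes through IntervalMeta.__call__ again and
            # comes back fully initialized.
            return Segment(start, end)
        return super().__new__(cls)

    def __init__(self, start, end, *, start_inclusive, end_inclusive):
        self.start = start
        self.end = end
        self.start_inclusive = start_inclusive
        self.end_inclusive = end_inclusive


class Segment(Interval):
    def __new__(cls, start, end):
        return super().__new__(cls, start, end,
                               start_inclusive=True, end_inclusive=True)

    def __init__(self, start, end):
        super().__init__(start, end,
                         start_inclusive=True, end_inclusive=True)


segment = Interval(0, 1, start_inclusive=True, end_inclusive=True)
half_open = Interval(0, 1, start_inclusive=True, end_inclusive=False)
print(type(segment).__name__, segment.start_inclusive, segment.end_inclusive)
# Segment True True
print(type(half_open).__name__, half_open.start_inclusive, half_open.end_inclusive)
# Interval True False
```

Note that the `type(obj) is cls` test also applies to any further subclasses created with this metaclass, so a subclass whose __new__ returns an instance of another class will likewise skip __init__.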

Upvotes: 2

Mad Physicist

Reputation: 114440

Let's look at why you get the error first. When you call a class derived from object, the __call__ method of the metaclass (type) is called. That usually goes something like

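self = cls.__new__(cls, *args, **kwargs)
if isinstance(self, cls):
    # Note: the *same* arguments are forwarded to __init__
    type(self).__init__(self, *args, **kwargs)
return self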

This is only approximate, but enough to convey what is happening here:

  1. type.__call__ calls Interval.__new__
  2. Since start_inclusive and end_inclusive are both True, Interval.__new__ correctly returns an instance of Segment
  3. Since issubclass(Segment, Interval), type.__call__ calls Segment.__init__ with all the parameters that you had passed to the call to Interval
  4. Segment.__init__ does not accept the start_inclusive and end_inclusive keyword arguments, and raises the error you see.
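The sequence above can be observed directly. This is a minimal sketch, assuming a deliberately permissive Segment.__init__: the `**kwargs` parameter and the `init_calls` list are not part of the question's code, and exist only so that the second call is recorded instead of raising TypeError.

```python
init_calls = []  # keyword arguments seen by each Segment.__init__ call


class Interval:
    def __new__(cls, start, end, *, start_inclusive, end_inclusive):
        if cls is Interval and start_inclusive and end_inclusive:
            return Segment(start, end)
        return super().__new__(cls)

    def __init__(self, start, end, *, start_inclusive, end_inclusive):
        self.start = start
        self.end = end
        self.start_inclusive = start_inclusive
        self.end_inclusive = end_inclusive


class Segment(Interval):
    def __new__(cls, start, end):
        return super().__new__(cls, start, end,
                               start_inclusive=True, end_inclusive=True)

    # Permissive on purpose (assumption for this demo only): accept and
    # record whatever keyword arguments type.__call__ forwards.
    def __init__(self, start, end, **kwargs):
        init_calls.append(kwargs)
        super().__init__(start, end,
                         start_inclusive=True, end_inclusive=True)


obj = Interval(0, 1, start_inclusive=True, end_inclusive=True)
print(type(obj).__name__)
# Segment
print(init_calls)
# [{}, {'start_inclusive': True, 'end_inclusive': True}]
```

The first entry comes from the inner Segment(start, end) call; the second is type.__call__ re-running Segment.__init__ with the arguments originally passed to Interval, which is the call that fails in the question.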

There are a number of workarounds to this situation. @jdehesa's answer shows how to override the behavior of type so that type.__call__ checks type(obj) is cls instead of using isinstance.

Another alternative would be to dissociate the hierarchy of Interval and Segment. You could do something like

class MyBase:
    # put common functionality here
    ...

class Interval(MyBase):
    # __new__ and __init__ same as before
    ...

class Segment(MyBase):
    # __init__ only: it can no longer delegate to Interval,
    # and no custom __new__ is needed
    ...

With this arrangement isinstance(Segment(...), Interval) will be False, so type.__call__ will not attempt to call __init__ a second time on the Segment returned from Interval.__new__.
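A runnable sketch of this arrangement might look as follows. It assumes the shared attribute setup lives in MyBase.__init__, which is one possible way to fill in the "common functionality" and is not spelled out in the outline above:

```python
class MyBase:
    """Hypothetical shared base; holds the common initialization."""

    def __init__(self, start, end, *, start_inclusive, end_inclusive):
        if start == end:
            raise ValueError('Degenerate interval found.')
        self.start = start
        self.end = end
        self.start_inclusive = start_inclusive
        self.end_inclusive = end_inclusive


class Interval(MyBase):
    def __new__(cls, start, end, *, start_inclusive, end_inclusive):
        if start_inclusive and end_inclusive:
            # Segment is not a subclass of Interval, so type.__call__
            # does not call __init__ again on the object returned here.
            return Segment(start, end)
        return super().__new__(cls)


class Segment(MyBase):
    def __init__(self, start, end):
        super().__init__(start, end,
                         start_inclusive=True, end_inclusive=True)


segment = Interval(0, 1, start_inclusive=True, end_inclusive=True)
half_open = Interval(0, 1, start_inclusive=False, end_inclusive=True)
print(type(segment).__name__, isinstance(segment, Interval))
# Segment False
print(type(half_open).__name__, half_open.start_inclusive)
# Interval False
```

The cost is visible in the first line of output: a Segment is no longer an Interval as far as isinstance is concerned, so code that needs to treat both uniformly has to check against MyBase instead.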

The simplest way to do this, in my opinion, would be to use a factory pattern. Have an external function that determines what type of object to return based on the input. That way, you do not need to implement __new__ at all, and your class construction process will be much simpler:

def factory(start, end, *, start_inclusive, end_inclusive):
    if start_inclusive and end_inclusive:
        return Segment(start, end)
    return Interval(start, end, start_inclusive=start_inclusive, end_inclusive=end_inclusive)
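For completeness, here is a self-contained sketch of the factory together with simplified classes that no longer define __new__. Moving the degenerate-interval check into Interval.__init__ is an assumption made here to keep the example short:

```python
class Interval:
    def __init__(self, start, end, *, start_inclusive, end_inclusive):
        if start == end:
            raise ValueError('Degenerate interval found.')
        self.start = start
        self.end = end
        self.start_inclusive = start_inclusive
        self.end_inclusive = end_inclusive


class Segment(Interval):
    def __init__(self, start, end):
        super().__init__(start, end,
                         start_inclusive=True, end_inclusive=True)


def factory(start, end, *, start_inclusive, end_inclusive):
    # Pick the concrete class from the flags; each class then uses the
    # ordinary __new__/__init__ protocol with its own signature.
    if start_inclusive and end_inclusive:
        return Segment(start, end)
    return Interval(start, end,
                    start_inclusive=start_inclusive,
                    end_inclusive=end_inclusive)


closed = factory(0, 1, start_inclusive=True, end_inclusive=True)
half_open = factory(0, 1, start_inclusive=False, end_inclusive=True)
print(type(closed).__name__, isinstance(closed, Interval))
# Segment True
print(type(half_open).__name__)
# Interval
```

Unlike the previous alternative, Segment stays a subclass of Interval here. The trade-off is that calling Interval(0, 1, start_inclusive=True, end_inclusive=True) directly still yields a plain Interval, so callers need to go through the factory to get the automatic conversion.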

Upvotes: 3
