Arya McCarthy

Reputation: 8829

Template classes in Python

I'm attempting to create algebras in Python, and I'm finding it difficult to create parameterized classes.

As an example, consider this ProductWeight class. It holds two other Weight objects, with enforced types (at least statically, via mypy).

This will fail because (by design) I cannot access the classes for W1 and W2 to call their classmethods like zero. (They're not specified; ProductWeight is not templated.) ProductWeight doesn't know, when I create an instance, which types to bind to it.

from typing import Generic, TypeVar
W1 = TypeVar("W1")
W2 = TypeVar("W2")

class ProductWeight(Generic[W1, W2]):
    def __init__(self, value1: W1, value2: W2):
        self.value1_ = value1
        self.value2_ = value2

    @classmethod
    def zero(cls):
        return cls(W1.zero(), W2.zero())  # Will fail - no access to W1 and W2.
    

By contrast, this is straightforward in C++: because each instantiation of the class template knows its type arguments, it can look up W1::Zero().

template<typename W1, typename W2>
class ProductWeight {
public:
    ProductWeight(W1 w1, W2 w2) : value1_(w1), value2_(w2) {}
    static ProductWeight Zero() {
        return ProductWeight(W1::Zero(), W2::Zero());
    }
private:
    W1 value1_;
    W2 value2_;
};

Is there a workaround for this in Python? Either creating an inner class, or otherwise somehow providing types to the class (rather than to class instances)?

For the sake of a minimal reproducible example, you can use this implementation of another Weight type.

class SimpleWeight:
    def __init__(self, value):
        assert value in {True, False}
        self.value_ = value

    @classmethod
    def zero(cls):
        return cls(False)

weight1 = SimpleWeight(False)
weight2 = SimpleWeight(True)
product = ProductWeight(weight1, weight2)

print(SimpleWeight.zero())  # So far, so good.
print(ProductWeight.zero())  # Oof.
# Predictably, it failed because `ProductWeight` is not specialized.

And here's the predictable error message:

Traceback (most recent call last):
  File "garbage.py", line 32, in <module>
    print(ProductWeight.zero())  # Oof.
  File "garbage.py", line 14, in zero
    return cls(W1.zero(), W2.zero())  # Will fail - no access to W1 and W2.
AttributeError: 'TypeVar' object has no attribute 'zero'
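The error comes from how `typing.Generic` works at runtime. `W1` and `W2` are ordinary `TypeVar` objects. Subscripting the class records the type arguments on a separate alias object and does not bind them into the class. The minimal sketch below shows this, using `int` and `str` as stand-in types. It assumes Python 3.8+ for `get_args`/`get_origin`, and `which_cls` is a hypothetical helper added only to show which class a classmethod receives.

```python
from typing import Generic, TypeVar, get_args, get_origin

W1 = TypeVar("W1")
W2 = TypeVar("W2")


class ProductWeight(Generic[W1, W2]):
    def __init__(self, value1: W1, value2: W2):
        self.value1_ = value1
        self.value2_ = value2

    @classmethod
    def which_cls(cls):
        # Hypothetical helper: report which class the classmethod was bound to.
        return cls


# At runtime W1 is just a TypeVar instance; it has no zero() to call.
print(isinstance(W1, TypeVar))  # True
print(hasattr(W1, "zero"))      # False

# Subscripting records the arguments on a separate alias object...
alias = ProductWeight[int, str]
print(get_args(alias))                     # (<class 'int'>, <class 'str'>)
# ...whose origin is still the single, unparameterized ProductWeight class.
print(get_origin(alias) is ProductWeight)  # True
# Attribute lookups on the alias are forwarded to that origin class,
# so a classmethod called through the alias still gets plain ProductWeight.
print(alias.which_cls() is ProductWeight)  # True
```

This means that even `ProductWeight[SimpleWeight, SimpleWeight].zero()` would not help with a plain `Generic` class: `cls` inside `zero` is still the unparameterized class.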

Ideally, it would be possible to create a parametric type like this:

product = ProductWeight[SimpleWeight, SimpleWeight](weight1, weight2)
# And similarly:
print(ProductWeight[SimpleWeight, SimpleWeight].zero())

Upvotes: 5

Views: 15073

Answers (1)

sanitizedUser

Reputation: 2105

To solve this, it's essential to realize that C++ templates create a separate class for every combination of type arguments. In Python there is always only one class, no matter what types of values are passed to its constructor.

In other words, in C++ vector<int> and vector<string> are two classes. If you bound them to the Python interpreter, you would have to give them two different names, for example VectorInt and VectorString.
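The difference is easy to see in Python itself. A `typing.Generic` class stays one class object however it is subscripted. Building a class per item type with the three-argument form of `type()` gives distinct class objects instead, which is closer to what C++ does. This is a minimal sketch; `Vector` and `make_vector_class` are hypothetical names used only for illustration.

```python
from typing import Generic, TypeVar

T = TypeVar("T")


class Vector(Generic[T]):
    """Python-style: one runtime class, whatever T is."""

    def __init__(self, *items: T):
        self.items = list(items)


# Both "specializations" produce instances of the very same class object.
print(type(Vector[int](1, 2)) is type(Vector[str]("a")))  # True


def make_vector_class(item_type):
    """Hypothetical C++-style helper: build a separate class per item type."""

    def init(self, *items):
        # Convert each item to this class's own item type.
        self.items = [item_type(x) for x in items]

    name = "Vector" + item_type.__name__.capitalize()
    return type(name, (), {"__init__": init, "item_type": item_type})


VectorInt = make_vector_class(int)
VectorStr = make_vector_class(str)
print(VectorInt is VectorStr)                  # False: two distinct classes
print(VectorInt.__name__, VectorStr.__name__)  # VectorInt VectorStr
print(VectorStr(1, 2).items)                   # ['1', '2']
```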

The code

def ProductWeight(value1, value2):
    def init(self, value1, value2):
        self.value1 = value1
        self.value2 = value2

    def zero(cls):
        return cls(cls.W1.zero(), cls.W2.zero())

    W1 = type(value1)
    W2 = type(value2)
    name = f'ProductWeight{W1}{W2}'

    try:
        return ProductWeight.types[name](value1, value2)
    except KeyError:
        pass

    cls = type(name, (), {'__init__': init})
    cls.W1 = W1
    cls.W2 = W2
    cls.zero = classmethod(zero)
    ProductWeight.types[name] = cls

    return cls(value1, value2)

ProductWeight.types = {}

class SimpleWeight:
    def __init__(self, value):
        assert value in {True, False}
        self.value_ = value

    @classmethod
    def zero(cls):
        return cls(False)

weight1 = SimpleWeight(False)
weight2 = SimpleWeight(True)
product = ProductWeight(weight1, weight2)

print(SimpleWeight.zero())
print(type(product).zero())

The code tries to respect your original API: you can create instances of ProductWeight with arguments of any type. However, to call the classmethod zero, you have to go through the type of an instance (notice the change from ProductWeight to type(product) in the last line).

You may save this reference to a variable for convenience.

The function ProductWeight serves as a generic factory. Each time you call it, it builds a name for the new class from the types of the arguments. If such a class already exists, it just returns a new instance of it. Otherwise, it creates the class with the three-argument form of type, caches it, and returns a new instance.

ProductWeight itself, being a function object, also carries an attribute (ProductWeight.types) holding a dictionary of the classes it has already created.
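If you're fine with passing the types explicitly, you can get the same caching by letting `functools.lru_cache` memoize a class-building function. This is a minimal sketch, and `make_product_weight` is a hypothetical name. One design difference: the cache is keyed by the type objects themselves rather than by a generated name string, so two unrelated classes that happen to print the same way cannot collide.

```python
import functools


class SimpleWeight:
    def __init__(self, value):
        assert value in {True, False}
        self.value_ = value

    @classmethod
    def zero(cls):
        return cls(False)


@functools.lru_cache(maxsize=None)
def make_product_weight(W1, W2):
    """Hypothetical helper: build (once) the ProductWeight class for (W1, W2)."""

    def init(self, value1, value2):
        self.value1 = value1
        self.value2 = value2

    def zero(cls):
        # W1 and W2 are stored on the generated class, so a classmethod sees them.
        return cls(cls.W1.zero(), cls.W2.zero())

    name = f"ProductWeight<{W1.__name__}, {W2.__name__}>"
    return type(name, (), {"__init__": init, "zero": classmethod(zero),
                           "W1": W1, "W2": W2})


PW = make_product_weight(SimpleWeight, SimpleWeight)
print(PW is make_product_weight(SimpleWeight, SimpleWeight))  # True: cached
z = PW.zero()
print(z.value1.value_, z.value2.value_)  # False False
```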

Conclusion

You may notice that this solution uses significantly more memory than its C++ counterpart. However, given that you chose to use Python instead of C++, you will probably not worry about that too much.

More importantly, you will have to decide whether this is the right path for you. Remember that Python has no templates in the C++ sense: type parameters from typing mostly matter to static checkers such as mypy, while everything at runtime is dynamic. So a C++ way of thinking may bring in more obstacles than it removes in the long run.

Making Python sound like C++

This part answers this edit to your question:

Ideally, it would be possible to create a parametric type like this:

product = ProductWeight[SimpleWeight, SimpleWeight](weight1, weight2)
# And similarly:
print(ProductWeight[SimpleWeight, SimpleWeight].zero())

Fear not, this is actually possible. The following code uses __class_getitem__, which was added in Python 3.7. With a workaround, it can also run on older versions (see this question about a static getitem method).
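For interpreters older than 3.7, a common workaround is to define `__getitem__` on a metaclass. I'm assuming this is roughly what the linked question describes. It works because `Cls[...]` looks up `__getitem__` on the type of `Cls`, which is its metaclass. This is a minimal sketch; `SubscriptableMeta` is a hypothetical name.

```python
class SimpleWeight:
    def __init__(self, value):
        assert value in {True, False}
        self.value_ = value

    @classmethod
    def zero(cls):
        return cls(False)


class SubscriptableMeta(type):
    """Hypothetical metaclass: Cls[A, B] calls this __getitem__ on type(Cls)."""

    def __getitem__(cls, key):
        W1, W2 = key
        cache = cls._types
        if key not in cache:
            name = "{}<{}, {}>".format(cls.__name__, W1.__name__, W2.__name__)
            # Build the specialization as a subclass with the same metaclass.
            cache[key] = type(cls)(name, (cls,), {"W1": W1, "W2": W2})
        return cache[key]


class ProductWeight(metaclass=SubscriptableMeta):
    _types = {}

    def __init__(self, value1, value2):
        self.value1 = value1
        self.value2 = value2

    @classmethod
    def zero(cls):
        return cls(cls.W1.zero(), cls.W2.zero())


SW = ProductWeight[SimpleWeight, SimpleWeight]
print(SW is ProductWeight[SimpleWeight, SimpleWeight])  # True: cached
print(SW.__name__)              # ProductWeight<SimpleWeight, SimpleWeight>
print(SW.zero().value1.value_)  # False
```

Making each specialization a subclass means `isinstance(x, ProductWeight)` still holds for any specialized instance.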

Code rewritten

# SimpleWeight didn't change
from simple_weight import SimpleWeight

class ProductWeight:
    types = {}

    def __class_getitem__(cls, key):
        try:
            W1, W2 = key
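        except (TypeError, ValueError):
            # A single type arrives as-is (unpacking it raises TypeError);
            # a wrong count arrives as a tuple of the wrong length (ValueError).
            raise TypeError('ProductWeight[] takes exactly two arguments.') from None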

        name = f'{ProductWeight.__name__}<{W1.__name__}, {W2.__name__}>'

        try:
            return cls.types[name]
        except KeyError:
            pass

        new_type = type(name, (), {'__init__': cls.init})
        new_type.W1 = W1
        new_type.W2 = W2
        new_type.zero = classmethod(cls.zero)
        cls.types[name] = new_type

        return new_type

    def __init__(self):
        raise Exception('ProductWeight is a static class and cannot be instantiated.')

    def init(self, value1, value2):
        self.value1 = value1
        self.value2 = value2

    def zero(cls):
        return cls(cls.W1.zero(), cls.W2.zero())

weight1 = SimpleWeight(False)
weight2 = SimpleWeight(True)
product = ProductWeight[SimpleWeight, SimpleWeight](weight1, weight2)

print(SimpleWeight.zero())
print(ProductWeight[SimpleWeight, SimpleWeight].zero())

This takes advantage of how subscription works. In an expression such as ProductWeight[SimpleWeight, SimpleWeight], Python packs the comma-separated items into a single tuple and passes it as the one key argument of __class_getitem__. You can unpack the tuple to get all the types. This can be extended to any number of types, even a number chosen at runtime.
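The sketch below shows the tuple packing directly and then extends the idea to any number of types. `Probe` and `ProductWeightN` are hypothetical names. Note that `Cls[A]` passes `A` itself, not `(A,)`, so the variadic version normalizes a single key into a one-element tuple.

```python
class SimpleWeight:
    def __init__(self, value):
        assert value in {True, False}
        self.value_ = value

    @classmethod
    def zero(cls):
        return cls(False)


class Probe:
    """Hypothetical helper that just echoes what a subscription receives."""

    def __class_getitem__(cls, key):
        return key


print(Probe[int, str])  # (<class 'int'>, <class 'str'>)
print(Probe[int])       # <class 'int'>  (a single item is not wrapped)


class ProductWeightN:
    """Hypothetical variadic variant: ProductWeightN[W1, ..., Wn]."""

    _types = {}

    def __class_getitem__(cls, key):
        types = key if isinstance(key, tuple) else (key,)
        if types not in cls._types:
            name = "ProductWeight<{}>".format(", ".join(t.__name__ for t in types))

            def init(self, *values):
                if len(values) != len(types):
                    raise TypeError(f"{name} takes {len(types)} values, got {len(values)}")
                self.values = values

            def zero(new_cls):
                return new_cls(*(t.zero() for t in new_cls.types))

            cls._types[types] = type(name, (), {"__init__": init,
                                                "zero": classmethod(zero),
                                                "types": types})
        return cls._types[types]


Triple = ProductWeightN[SimpleWeight, SimpleWeight, SimpleWeight]
print(len(Triple.zero().values))  # 3

n = 5  # a count only known at runtime
Wide = ProductWeightN[(SimpleWeight,) * n]
print(len(Wide.zero().values))    # 5
```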

Inferred type

The last version lost the ability to infer the types from the arguments passed to the constructor. By creating an abstract factory, we can get this functionality back.

# SimpleWeight didn't change
from simple_weight import SimpleWeight

class ProductWeightAbstractFactory:
    def __call__(self, value1, value2):
        return self[type(value1), type(value2)](value1, value2)

    def __getitem__(self, types):
        W1, W2 = types
        name = f'ProductWeight<{W1.__name__}, {W2.__name__}>'

        try:
            return self.types[name]
        except KeyError:
            pass

        cls = type(self)
        new_type = type(name, (), {'__init__': cls.init})
        new_type.W1 = W1
        new_type.W2 = W2
        new_type.zero = classmethod(cls.zero)
        self.types[name] = new_type

        return new_type

    def __init__(self):
        self.types = {}

    def init(self, value1, value2):
        self.value1 = value1
        self.value2 = value2

    def zero(cls):
        return cls(cls.W1.zero(), cls.W2.zero())

ProductWeight = ProductWeightAbstractFactory()

weight1 = SimpleWeight(False)
weight2 = SimpleWeight(True)
product = ProductWeight[SimpleWeight, SimpleWeight](weight1, weight2)
inferred_product = ProductWeight(weight1, weight2)

print(SimpleWeight.zero())
print(ProductWeight[SimpleWeight, SimpleWeight].zero())
print(type(inferred_product).zero())

Note that you have to create an instance of the factory before using it:

ProductWeight = ProductWeightAbstractFactory()

Now you can create an object with an explicit type using the brackets:

product = ProductWeight[SimpleWeight, SimpleWeight](weight1, weight2)

Or you can infer the type to make the code concise:

product = ProductWeight(weight1, weight2)

Now the syntax is closer to C++ than ever.
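A quick check is that inference and explicit subscription end up with the same cached class. The block below is a condensed restatement of the factory above, with the cache keyed by the type pair instead of a name string. Treat it as an illustration rather than a drop-in replacement; `StrictWeight` is a hypothetical subclass used only for the last check.

```python
class SimpleWeight:
    def __init__(self, value):
        assert value in {True, False}
        self.value_ = value

    @classmethod
    def zero(cls):
        return cls(False)


class ProductWeightAbstractFactory:
    def __init__(self):
        self.types = {}

    def __call__(self, value1, value2):
        # Infer the type arguments from the exact runtime types of the values.
        return self[type(value1), type(value2)](value1, value2)

    def __getitem__(self, types):
        W1, W2 = types
        if types not in self.types:
            def init(instance, value1, value2):
                instance.value1 = value1
                instance.value2 = value2

            def zero(cls):
                return cls(cls.W1.zero(), cls.W2.zero())

            name = f"ProductWeight<{W1.__name__}, {W2.__name__}>"
            self.types[types] = type(name, (), {"__init__": init, "W1": W1, "W2": W2,
                                                "zero": classmethod(zero)})
        return self.types[types]


ProductWeight = ProductWeightAbstractFactory()
w1, w2 = SimpleWeight(False), SimpleWeight(True)
inferred = ProductWeight(w1, w2)
explicit = ProductWeight[SimpleWeight, SimpleWeight](w1, w2)
print(type(inferred) is type(explicit))  # True
print(type(inferred).__name__)           # ProductWeight<SimpleWeight, SimpleWeight>


class StrictWeight(SimpleWeight):
    pass


# Inference uses type(value), so a subclass instance picks a different class.
print(type(ProductWeight(StrictWeight(True), w2)) is type(explicit))  # False
```

Because inference uses the exact runtime type, pass the types explicitly when you want subclass instances to share the base specialization.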

Type checking

To provide more safety during development, you can also add type checking to the constructor.

def init(self, value1, value2):
    def check_types(objects, required_types):
        for index, (obj, t) in enumerate(zip(objects, required_types)):
            if not issubclass(type(obj), t):
                raise Exception(f'Parameter {index + 1} is not a subclass of its required type {t}.')

    cls = type(self)
    check_types((value1, value2), (cls.W1, cls.W2))

    self.value1 = value1
    self.value2 = value2

With this constructor, the following line raises an exception, because 1 is not a SimpleWeight:

product = ProductWeight[SimpleWeight, SimpleWeight](1, 2)
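If you prefer the built-in exception type for this, here is a minimal standalone variant of the check. It raises `TypeError` and uses `isinstance`, which for ordinary classes behaves like `issubclass(type(obj), t)`. It also rejects a wrong number of values, since `zip` would silently ignore extras. `check_types` is written here as a hypothetical module-level helper rather than the nested function above.

```python
class SimpleWeight:
    def __init__(self, value):
        assert value in {True, False}
        self.value_ = value


def check_types(objects, required_types):
    """Hypothetical helper: raise TypeError unless each object matches its type."""
    if len(objects) != len(required_types):
        # zip() would silently drop the extras, so check the count explicitly.
        raise TypeError(f"Expected {len(required_types)} values, got {len(objects)}.")
    for index, (obj, t) in enumerate(zip(objects, required_types), start=1):
        if not isinstance(obj, t):
            raise TypeError(f"Parameter {index} is {type(obj).__name__}, expected {t.__name__}.")


# Passes silently: both values are SimpleWeight instances.
check_types((SimpleWeight(False), SimpleWeight(True)), (SimpleWeight, SimpleWeight))

try:
    check_types((1, 2), (SimpleWeight, SimpleWeight))
except TypeError as exc:
    print(exc)  # Parameter 1 is int, expected SimpleWeight.
```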

Upvotes: 8
