Yaroslav Bulatov

Reputation: 57973

Why do I need a __get__ that returns types.MethodType?

I'm trying to understand this block of code

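import sys
import types

class primitive(object):
    ...  # __init__, __call__ and the rest of the class elided
    # The right __get__ is chosen once, when the class body runs:
    if sys.version_info >= (3,):
        def __get__(self, obj, objtype):
            # Python 3: MethodType(callable, instance)
            return types.MethodType(self, obj)
    else:
        def __get__(self, obj, objtype):
            # Python 2: MethodType(callable, instance, class)
            return types.MethodType(self, obj, objtype)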

Does anyone have an example of when this would come up, or of why I need it?

Upvotes: 2

Views: 316

Answers (2)

CodenameLambda

Reputation: 1496

The code you posted is used as a descriptor. That has the following effect: if a class has a descriptor object as a class attribute, then its instances have an attribute with the same name as that class attribute.

If you set that attribute, the descriptor's __set__(self, instance, value) method is called (if the descriptor defines one).

If you delete it, the descriptor's __delete__(self, instance) method is called.

And if you try to read that attribute, the descriptor's __get__(self, instance, owner) method is called (owner is the class that contains the descriptor object).

The self argument is the descriptor itself (just like for any other object in Python), and the instance argument is the object whose attribute is being accessed. For __get__, instance is None when the attribute is looked up on the class itself rather than on an instance.
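To make the call sequence concrete, here is a small sketch with made-up names (Traced, Holder, and the "_value" storage key are all hypothetical). The descriptor just logs which protocol method Python invokes and with which arguments:

```python
class Traced(object):
    """Data descriptor that logs which protocol method Python invokes."""

    def __init__(self):
        self.calls = []  # log of (method, args...) tuples

    def __get__(self, instance, owner):
        self.calls.append(("get", instance, owner))
        if instance is None:  # looked up on the class, not an instance
            return self
        return instance.__dict__.get("_value")

    def __set__(self, instance, value):
        self.calls.append(("set", instance, value))
        instance.__dict__["_value"] = value

    def __delete__(self, instance):
        self.calls.append(("delete", instance))
        instance.__dict__.pop("_value", None)


class Holder(object):
    attr = Traced()  # the descriptor object lives on the class


h = Holder()
h.attr = 42          # Traced.__set__(descriptor, h, 42)
print(h.attr)        # Traced.__get__(descriptor, h, Holder) -> prints 42
del h.attr           # Traced.__delete__(descriptor, h)
_ = Holder.attr      # Traced.__get__(descriptor, None, Holder)

log = Holder.__dict__["attr"].calls  # raw class-dict access skips __get__
print([c[0] for c in log])           # ['set', 'get', 'delete', 'get']
```

Reading the descriptor out of Holder.__dict__ is the usual way to get at the descriptor object itself without triggering __get__.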

So in this case, reading an attribute backed by a primitive returns types.MethodType(self, instance) in Python 3, or types.MethodType(self, instance, owner) in Python 2, where self is the primitive, instance is the object whose attribute is retrieved, and owner is the class holding the primitive object (as explained earlier).
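As for why this is needed: plain functions are themselves descriptors, which is how they turn into bound methods on instances, but an instance of an ordinary callable class is not, unless its class defines __get__. Without it, the object stored on the class is returned as-is and never receives the instance as its first argument. A minimal Python 3 sketch; NoGet, WithGet, describe and A are hypothetical names, and the `obj is None` guard is an addition that the posted code does not have:

```python
import types


class NoGet(object):
    """Callable wrapper without __get__: NOT bound to instances."""

    def __init__(self, fun):
        self.fun = fun

    def __call__(self, *args):
        return self.fun(*args)


class WithGet(NoGet):
    """Same wrapper plus a __get__ that binds it like a method (Python 3 form)."""

    def __get__(self, obj, objtype=None):
        if obj is None:  # looked up on the class: hand back the wrapper itself
            return self
        return types.MethodType(self, obj)  # obj becomes the first argument


def describe(self, suffix):
    return "%s%s" % (type(self).__name__, suffix)


class A(object):
    unbound = NoGet(describe)
    bound = WithGet(describe)


a = A()
print(a.bound("!"))   # describe(a, "!") -> prints A!
try:
    a.unbound("!")    # describe("!") -> TypeError: 'suffix' is missing
except TypeError as exc:
    print("TypeError:", exc)
```

In other words, the __get__ in primitive is what lets a primitive stored on a class behave like a regular method.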

I hope I could help,

CodenameLambda

Upvotes: 0

tijko

Reputation: 8322

In Python, whenever a class defines __get__, __set__, or __delete__, it is called a descriptor class. These methods give a class attribute "binding" behavior: whenever you access an instance of that class as a class attribute using the usual dot notation, Python runs one of those methods depending on whether you are reading, setting, or deleting the attribute. The code you posted defines only __get__, which makes it a non-data descriptor.
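The data/non-data distinction decides which one wins when an instance's own __dict__ has an entry with the same name as the descriptor. A small sketch with made-up class names:

```python
class NonData(object):
    """Defines only __get__, so it is a non-data descriptor."""

    def __get__(self, obj, objtype=None):
        return "from descriptor"


class Data(NonData):
    """Adding __set__ turns it into a data descriptor."""

    def __set__(self, obj, value):
        raise AttributeError("read-only")


class C(object):
    nd = NonData()
    d = Data()


c = C()
print(c.nd)                         # from descriptor
c.__dict__["nd"] = "from instance"  # bypass normal assignment on purpose
c.__dict__["d"] = "from instance"
print(c.nd)                         # from instance: instance dict beats a non-data descriptor
print(c.d)                          # from descriptor: a data descriptor beats the instance dict
```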

There is another dunder method defined here that comes into play: __call__, which makes instances of your class callable objects:

class CallableClass(object):

    def __init__(self, fun):
        self.fun = fun  # the wrapped callable

    def __call__(self, *args):
        # calling the instance forwards the arguments to the wrapped callable
        return self.fun(*args)

>>> cc = CallableClass(lambda *args: sum(args))
>>> cc(1, 2, 3)
6
>>> cc(0)
0

As you can see, you can call the instance as often as you like, just like any other callable (e.g. a function). I'm going over this because the descriptor class returns types.MethodType(self, obj) or types.MethodType(self, obj, objtype), depending on which Python version you are using, and MethodType needs a callable as its first argument.

MethodType binds its first argument, which must be callable, to its second argument, a class instance. Essentially you are creating a new bound method on the instance every time you access the primitive descriptor through it.
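Calling types.MethodType directly shows the binding without any descriptor in the way. Adder and Thing are hypothetical names, and this uses the Python 3 two-argument form (Python 2's version also took the class as a third argument):

```python
import types


class Adder(object):
    """A callable object that is not a function."""

    def __call__(self, instance, x):
        return instance.base + x


class Thing(object):
    def __init__(self, base):
        self.base = base


adder = Adder()
t = Thing(10)

bound = types.MethodType(adder, t)  # Python 3: MethodType(callable, instance)
print(bound(5))                     # calls adder(t, 5) -> 15
print(bound.__self__ is t)          # True: the instance it is bound to
print(bound.__func__ is adder)      # True: the callable it wraps

try:
    types.MethodType(42, t)         # the first argument must be callable
except TypeError as exc:
    print("TypeError:", exc)
```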

The "descriptor" features only come into play when a primitive is stored as a class attribute. Reading through the primitive docstring, it mentions that the class wraps functions as a decorator.

Some lines down you can see it in action as a decorator:

@primitive
def merge_tapes(x, y): return x
merge_tapes.defgrad(lambda ans, x, y : lambda g : g)
merge_tapes.defgrad(lambda ans, x, y : lambda g : g, argnum=1)
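When primitive decorates a module-level function like this, nothing is looked up on a class, so __get__ never runs and every call goes straight to __call__. The class below is a hypothetical, stripped-down stand-in, not autograd's real primitive: its defgrad just records the gradient maker per argnum, mirroring the two calls above, and does no differentiation at all:

```python
import types


class primitive(object):
    """Hypothetical stand-in: wraps a function and stores gradient makers."""

    def __init__(self, fun):
        self.fun = fun
        self.grads = {}  # argnum -> gradient maker

    def defgrad(self, gradmaker, argnum=0):
        self.grads[argnum] = gradmaker

    def __call__(self, *args, **kwargs):
        return self.fun(*args, **kwargs)

    def __get__(self, obj, objtype=None):
        # only used when a primitive is stored on a class
        if obj is None:
            return self
        return types.MethodType(self, obj)


@primitive
def merge_tapes(x, y): return x
merge_tapes.defgrad(lambda ans, x, y: lambda g: g)
merge_tapes.defgrad(lambda ans, x, y: lambda g: g, argnum=1)

print(type(merge_tapes).__name__)  # primitive
print(merge_tapes(3, 4))           # 3
print(sorted(merge_tapes.grads))   # [0, 1]
```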

But here it is used as a descriptor:

differentiable_ops = ['__add__', '__sub__', '__mul__', '__pow__', '__mod__',
                      '__neg__', '__radd__', '__rsub__', '__rmul__', '__rpow__',
                      '__rmod__', DIV, RDIV]

nondifferentiable_ops = ['__eq__', '__ne__', '__gt__', '__ge__', '__lt__', '__le__',]
for float_op in differentiable_ops + nondifferentiable_ops:
    setattr(FloatNode, float_op, primitive(getattr(float, float_op)))

Here the loop calls setattr on the class FloatNode for every string in the two "ops" lists. In the same setattr call, getattr(float, float_op) retrieves the built-in float method of the same name, and that method is passed to primitive as its initial func argument. Now whenever you access any of those operations on a FloatNode instance, you get a bound method.
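The setattr loop can be reproduced at small scale. Node and this primitive are hypothetical stand-ins for FloatNode and the real class: __call__ only unboxes Node arguments to plain floats and re-wraps float results, with no gradient tracking. The point is that operator lookup on the class goes through __get__, which binds the node as the first argument before float.__add__ and friends ever see it:

```python
import types


class primitive(object):
    """Hypothetical minimal primitive: unbox Node args, call fun, rebox floats."""

    def __init__(self, fun):
        self.fun = fun

    def __call__(self, *args):
        raw = [a.value if isinstance(a, Node) else a for a in args]
        result = self.fun(*raw)
        return Node(result) if isinstance(result, float) else result

    def __get__(self, obj, objtype=None):
        if obj is None:
            return self
        return types.MethodType(self, obj)  # bind: obj becomes the first argument


class Node(object):
    """Hypothetical stand-in for FloatNode: just boxes a float."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return "Node(%r)" % (self.value,)


for op in ['__add__', '__mul__', '__lt__']:
    setattr(Node, op, primitive(getattr(float, op)))

n = Node(1.5)
print(n.__add__)      # a bound method whose __self__ is n
print(n + 2.0)        # Node(3.5)
print(n * Node(2.0))  # Node(3.0)
print(n < 1.0)        # False
```

Without the __get__ method, n + 2.0 would end up calling the primitive with only 2.0, so float.__add__ would be missing its first operand and the call would fail.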

So if you access one of those "ops" that were set as attributes of FloatNode:

>>> FloatNode(1, []).__add__
<bound method __add__ of <__main__.FloatNode object at 0xb6fd61ec>>

You will get a bound method that encapsulates all the benefits that primitive holds (i.e. the gradient functions).

Upvotes: 2
