GooseFromHell

Reputation: 21

Python's `__getattr__` + `torch.nn.Module` yields infinite recursion

I wrote a simple wrapper to add special methods to a given PyTorch neural network. While the implementation below works well for general objects like strings, lists, etc., I get a RecursionError when applying it to a torch.nn.Module. It seems that in the latter case the call to self.instance inside the __getattr__ method is unsuccessful, so it falls back to __getattr__ again, leading to the infinite loop (I also tried self.__dict__['instance'], without luck).

I assume that this behaviour stems from the __getattr__ and __setattr__ methods of torch.nn.Module, but after inspecting their implementations I still don't see how.

I would like to understand in detail what is going on and how to fix the error in my implementation.

(I am aware of the similar question in link but it does not answer my question.)

Here is a minimal implementation to recreate my situation.

import torch

class MyWrapper(torch.nn.Module):
    def __init__(self, instance):
        super().__init__()
        self.instance = instance

    def __getattr__(self, name):
        print("trace", name)
        return getattr(self.instance, name)

# Working example
obj = "test string"
obj_wrapped = MyWrapper(obj)
print(obj_wrapped.split(" ")) # trace split\n ['test', 'string']

# Failing example
net = torch.nn.Linear(12, 12)
net.test_attribute = "hello world"
b = MyWrapper(net)

print(b.test_attribute) # RecursionError: maximum recursion depth exceeded
b.instance # RecursionError: maximum recursion depth exceeded

Upvotes: 2

Views: 200

Answers (1)

heemayl

Reputation: 41987

The error comes from the interplay of Python's attribute look-up protocol and the custom __setattr__ that torch.nn.Module defines. Python only calls a class's __getattr__ when the normal look-up (instance __dict__, class, base classes) fails.

When you do self.instance = instance in __init__, nn.Module.__setattr__ sees that the value is itself a Module and stores it in self._modules instead of the regular instance __dict__. Later, self.instance inside your __getattr__ therefore fails the normal look-up, so Python falls back to MyWrapper.__getattr__ again (and again), recursing forever. For a plain value like a string, the assignment lands in __dict__, the normal look-up succeeds, and your __getattr__ is never re-entered -- which is why the string example works.
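To see where the attribute actually lands, here is a small probe (a sketch, using a wrapper without the __getattr__ override so nothing recurses):

```python
import torch

class Plain(torch.nn.Module):
    def __init__(self, instance):
        super().__init__()
        self.instance = instance

# A string goes into the regular instance __dict__ ...
w1 = Plain("test string")
print("instance" in w1.__dict__)   # True

# ... but a Module is diverted into _modules by nn.Module.__setattr__,
# so the normal __dict__ look-up fails and Python falls back to __getattr__.
w2 = Plain(torch.nn.Linear(12, 12))
print("instance" in w2.__dict__)   # False
print("instance" in w2._modules)   # True
```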

Fix:

You can call the superclass's __getattr__ directly via super(). nn.Module.__getattr__ knows how to look names up in _parameters, _buffers, and _modules, so it resolves instance without re-entering your own __getattr__; from there, a plain getattr handles the delegated name look-up. For example:


    In [259]: class MyWrapper(torch.nn.Module):
         ...:     def __init__(self, instance):
         ...:         super().__init__()
         ...:         self.instance = instance
         ...: 
         ...:     def __getattr__(self, name):
         ...:         instance = super().__getattr__("instance")
         ...:         return getattr(instance, name)
         ...:         
    
    In [260]: # Your failing example - now working
         ...: net = torch.nn.Linear(12, 12)
         ...: net.test_attribute = "hello world"
         ...: b = MyWrapper(net)
    
    In [261]: print(b.test_attribute)
    hello world
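One subtlety: with this fix, b.instance itself still raises an AttributeError, because the look-up for "instance" is forwarded to the wrapped Linear, which has no such attribute. A small variant (a sketch, not part of the fix above) short-circuits that case:

```python
import torch

class MyWrapper(torch.nn.Module):
    def __init__(self, instance):
        super().__init__()
        self.instance = instance

    def __getattr__(self, name):
        # nn.Module.__getattr__ finds "instance" in self._modules
        instance = super().__getattr__("instance")
        if name == "instance":          # return the wrapped object itself
            return instance
        return getattr(instance, name)  # delegate everything else

net = torch.nn.Linear(12, 12)
net.test_attribute = "hello world"
b = MyWrapper(net)

print(b.test_attribute)   # hello world
print(b.instance is net)  # True
```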

Upvotes: 1
