Reputation: 21
I wrote a simple wrapper to add special methods to a given PyTorch neural network. While the implementation below works well for general objects like strings and lists, I get a RecursionError when applying it to a torch.nn.Module. It seems that in the latter case the call to self.instance inside the __getattr__ method is unsuccessful, so it falls back to __getattr__ again, leading to an infinite loop (I also tried self.__dict__['instance'] without luck).
I assume that this behaviour stems from the implementations of the __getattr__ and __setattr__ methods of torch.nn.Module, but after inspecting those implementations I still don't see how. I would like to understand in detail what is going on and how to fix the error in my implementation.
(I am aware of the similar question in link but it does not answer my question.)
Here is a minimal implementation to recreate my situation.
import torch

class MyWrapper(torch.nn.Module):
    def __init__(self, instance):
        super().__init__()
        self.instance = instance

    def __getattr__(self, name):
        print("trace", name)
        return getattr(self.instance, name)
# Working example
obj = "test string"
obj_wrapped = MyWrapper(obj)
print(obj_wrapped.split(" "))  # prints "trace split", then ['test', 'string']
# Failing example
net = torch.nn.Linear(12, 12)
net.test_attribute = "hello world"
b = MyWrapper(net)
print(b.test_attribute) # RecursionError: maximum recursion depth exceeded
b.instance # RecursionError: maximum recursion depth exceeded
Upvotes: 2
Views: 200
Reputation: 41987
The error comes from the combination of how attribute look-up works in Python and how torch.nn.Module customises it. Python calls the __getattr__ special method only when normal attribute look-up (the instance __dict__ and the class hierarchy) fails. torch.nn.Module overrides __setattr__ so that values which are themselves Module instances (as well as Parameters and buffers) are stored in self._modules (respectively self._parameters / self._buffers) rather than in the instance __dict__. So after self.instance = instance in __init__, there is no "instance" entry in self.__dict__. Every access to self.instance therefore falls through to your overridden __getattr__ in MyWrapper, which again evaluates self.instance, and so on until the recursion limit is hit. That is also why the plain-string example works: a str is not a Module, so it is stored in __dict__ normally and __getattr__ is never invoked for "instance".
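The recursion's root cause can be reproduced without torch at all, using a small class that mimics, in heavily simplified form, how torch.nn.Module.__setattr__ diverts Module-valued attributes into self._modules instead of the instance __dict__ (the name FakeModule and this stripped-down logic are illustrative, not the real torch source):

```python
class FakeModule:
    """Hypothetical, simplified mimic of nn.Module's attribute handling."""

    def __init__(self):
        # Bypass our own __setattr__ while initialising the store.
        object.__setattr__(self, "_modules", {})

    def __setattr__(self, name, value):
        # Like nn.Module: submodules are diverted into self._modules
        # instead of the instance __dict__.
        if isinstance(value, FakeModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def __getattr__(self, name):
        # Called only when normal look-up (class + __dict__) fails.
        modules = self.__dict__["_modules"]
        if name in modules:
            return modules[name]
        raise AttributeError(name)

child = FakeModule()
parent = FakeModule()
parent.instance = child  # diverted into parent._modules

print("instance" in parent.__dict__)          # False: not in the instance dict
print(parent._modules["instance"] is child)   # True: it lives in _modules
print(parent.instance is child)               # True: resolved via __getattr__
```

Because "instance" never lands in the instance __dict__, any __getattr__ override that itself evaluates self.instance re-enters __getattr__ forever, which is exactly the RecursionError in the question.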
Fix:
Since torch.nn.Module stores the wrapped module in self._modules, you can retrieve it through the superclass's __getattr__ (accessible via super()), which knows to look in _parameters, _buffers and _modules. Once instance resolves correctly, a plain getattr can perform the next name look-up as before. For example:
In [259]: class MyWrapper(torch.nn.Module):
...: def __init__(self, instance):
...: super().__init__()
...: self.instance = instance
...:
...: def __getattr__(self, name):
...: instance = super().__getattr__("instance")
...: return getattr(instance, name)
...:
In [260]: # Your failing example - now working
...: net = torch.nn.Linear(12, 12)
...: net.test_attribute = "hello world"
...: b = MyWrapper(net)
In [261]: print(b.test_attribute)
hello world
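As a side note, the original wrapper pattern is perfectly fine for classes that don't intercept __setattr__: with the default behaviour, self.instance = instance lands in the instance __dict__, so the look-up inside __getattr__ succeeds via the normal path and no recursion occurs. A torch-free sketch (PlainWrapper is a made-up name for illustration):

```python
class PlainWrapper:
    # No nn.Module involved: the default object.__setattr__ stores
    # attributes in the instance __dict__.
    def __init__(self, instance):
        self.instance = instance

    def __getattr__(self, name):
        # self.instance succeeds via normal look-up (it's in __dict__),
        # so __getattr__ is not re-entered here.
        return getattr(self.instance, name)

w = PlainWrapper("test string")
print(w.split(" "))              # ['test', 'string']
print("instance" in w.__dict__)  # True
```

This is why the question's string example worked while the nn.Linear one did not.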
Upvotes: 1