I'm new to PyTorch and I'm trying to use hooks, specifically `register_forward_pre_hook`, in my project.
What I've tried is:
```python
def get_features_hook(module, input):
    print(input)

handle_feat = alexnet.features[0].register_forward_pre_hook(get_features_hook)
a = alexnet(input_data)
```
And I got the following error at `a = alexnet(input_data)`:

```
TypeError: get_features_hook() takes 2 positional arguments but 3 were given
```
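For reference, PyTorch calls a forward pre-hook as `hook(module, args)` and a forward hook (registered with `register_forward_hook`) as `hook(module, input, output)`. The sketch below is an illustration, not a diagnosis: it uses a single `Conv2d` as a hypothetical stand-in for `alexnet.features[0]` and a random 224x224 input (both assumptions), and shows that the two-argument hook from the question works as a pre-hook, while the same function registered as a forward hook fails with exactly this `TypeError`. A forward hook left over from an earlier notebook cell is one plausible way that could happen.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for alexnet.features[0]; it matches the Conv2d
# printed later in the question (3 -> 64 channels, 11x11, stride 4, padding 2).
layer = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)
input_data = torch.randn(1, 3, 224, 224)  # assumed AlexNet-sized input

captured_shapes = []

def get_features_hook(module, input):
    # As a forward pre-hook this is called as hook(module, args), where
    # args is the tuple of positional inputs about to reach module.forward.
    captured_shapes.append(tuple(input[0].shape))

# 1) Registered as a forward pre-hook: called with two arguments, works.
pre_handle = layer.register_forward_pre_hook(get_features_hook)
out = layer(input_data)
pre_handle.remove()  # detach the hook once it is no longer needed
print(captured_shapes)  # [(1, 3, 224, 224)]

# 2) Registered as a forward hook by mistake: PyTorch calls it as
#    hook(module, input, output), i.e. with three arguments.
post_handle = layer.register_forward_hook(get_features_hook)
error_message = None
try:
    layer(input_data)
except TypeError as exc:
    error_message = str(exc)
finally:
    post_handle.remove()
print(error_message)
```

Keeping the handle returned by each `register_*` call and calling `remove()` on it makes it easy to be sure which hooks are still attached.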
I've lost a few hours on this problem and I just can't figure it out.
Can anyone help me?
With Shai's help, I tried his code, and I got this output:

```
Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
get_features_hook called with 2 args:
arg of type Conv2d
arg of type tuple
File "<input>", line 2, in get_features_hook
NameError: name 'args' is not defined
```
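Shai's debugging code is not reproduced in the question, so the following is a hypothetical reconstruction of a varargs hook that prints the same kind of output. A `NameError` for `args` usually means the function body uses a name that is not the function's parameter (for example a signature of `*a` with `len(args)` in the body), so the two names need to match. The stand-in `Conv2d` layer and input size are assumptions.

```python
import torch
import torch.nn as nn

calls = []  # records the argument types of every hook call

def get_features_hook(*args):
    # Accept any number of positional arguments so the signature PyTorch
    # actually uses becomes visible. The body must refer to the same
    # parameter name ("args"), otherwise Python raises NameError.
    names = [type(arg).__name__ for arg in args]
    calls.append(names)
    print(f"get_features_hook called with {len(args)} args:")
    for name in names:
        print("arg of type", name)

layer = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)  # stand-in layer
handle = layer.register_forward_pre_hook(get_features_hook)
layer(torch.randn(1, 3, 224, 224))
handle.remove()
```

For a forward pre-hook this reports two arguments, the module and a tuple of inputs, which agrees with the output above.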
If `get_features_hook` is defined inside your `torch.nn.Module`, it should be annotated as `@staticmethod`; otherwise `self` is implicitly passed to it.
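A minimal sketch of that fix, assuming a hypothetical `FeatureNet` module with a single `Conv2d` layer. Because the hook is a `staticmethod`, `self.get_features_hook` is a plain function and PyTorch calls it with exactly `(module, args)`; as a regular method it would be bound, and the call would supply `self` as a third argument.

```python
import torch
import torch.nn as nn

class FeatureNet(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)
        # self.get_features_hook is an unbound plain function here,
        # because get_features_hook is declared as a staticmethod.
        self.handle = self.conv.register_forward_pre_hook(self.get_features_hook)

    @staticmethod
    def get_features_hook(module, args):
        # Without @staticmethod this would be called with (self, module, args),
        # which is three arguments for a two-parameter function.
        module.last_input_shape = tuple(args[0].shape)

    def forward(self, x):
        return self.conv(x)

net = FeatureNet()
net(torch.randn(1, 3, 224, 224))
print(net.conv.last_input_shape)  # (1, 3, 224, 224)
```

Alternatively, keep it a regular method and give it an explicit `self` parameter, `def get_features_hook(self, module, args)`; that variant can also store results on the `FeatureNet` instance.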
I got the same error. Re-running the notebook from scratch solved it.
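A likely reason this helps, though the answer does not say so: a hook stays registered on its module until the handle's `remove()` is called, so re-running a notebook cell that registers a hook adds another one each time, and hooks registered earlier (possibly a different function, or a forward hook instead of a pre-hook) keep firing. Restarting the notebook discards those stale registrations. The sketch below simulates running a registration cell twice on a stand-in `Conv2d` layer (an assumption) and then removes the handles explicitly.

```python
import torch
import torch.nn as nn

layer = nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2)  # stand-in layer
calls = []  # one entry per hook invocation

def get_features_hook(module, args):
    calls.append(type(module).__name__)

# Simulate running the registration cell twice: each call adds another hook,
# and the earlier one stays attached until its handle is removed.
handles = [layer.register_forward_pre_hook(get_features_hook) for _ in range(2)]
layer(torch.randn(1, 3, 224, 224))
calls_after_first_pass = len(calls)
print(calls_after_first_pass)  # 2: the hook ran once per registration

for h in handles:
    h.remove()
layer(torch.randn(1, 3, 224, 224))
print(len(calls))  # still 2: no hooks remain attached
```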