Eunice TT

Reputation: 59

TypeError: hook() takes 2 positional arguments but 3 were given

I'm new to PyTorch and I'm trying to use a hook with register_forward_pre_hook in my project.

Here is what I've tried:

def get_features_hook(module, input):
    print(input)

handle_feat = alexnet.features[0].register_forward_pre_hook(get_features_hook)


a = alexnet(input_data)

I got the error below at a = alexnet(input_data):

TypeError: get_features_hook() takes 2 positional arguments but 3 were given

I've lost a few hours on this problem and I just can't figure it out.

Could anyone help me?


With Shai's help, I tried his code and got this:

Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
get_features_hook called with 2 args:
    arg of type Conv2d
    arg of type tuple

File "<input>", line 2, in get_features_hook
NameError: name 'args' is not defined

Upvotes: 0

Views: 1014

Answers (2)

thiagocrepaldi

Reputation: 141

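If get_features_hook is defined inside your torch.nn.Module subclass, it should be decorated with @staticmethod. Otherwise, when it is registered as self.get_features_hook, self is implicitly passed as an extra first argument, so the two-parameter function receives three arguments.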
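
To illustrate the mechanism, here is a minimal sketch in plain Python, without PyTorch. It assumes the hook was defined as a method on a class; call_pre_hook is a hypothetical stand-in for the framework calling hook(module, input), and Broken and Fixed are illustrative names that do not appear in the question.

```python
def call_pre_hook(hook, module, inputs):
    # Hypothetical stand-in for the framework side: a forward pre-hook
    # is invoked as hook(module, input), with the inputs packed in a tuple.
    return hook(module, inputs)


class Broken:
    # Looks like a two-argument hook, but as a plain method the instance is
    # bound as the first argument, so the call supplies three arguments.
    def get_features_hook(module, input):
        return type(input).__name__


class Fixed:
    # @staticmethod disables the implicit self, leaving exactly two parameters.
    @staticmethod
    def get_features_hook(module, input):
        return type(input).__name__


layer, inputs = object(), (1.0,)

try:
    call_pre_hook(Broken().get_features_hook, layer, inputs)
    error_message = None
except TypeError as exc:
    # Same kind of message as in the question.
    error_message = str(exc)

result = call_pre_hook(Fixed().get_features_hook, layer, inputs)
print(error_message)
print(result)
```

Another option, under the same assumption, is to keep it a regular method and declare self explicitly as the first parameter, i.e. get_features_hook(self, module, input).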

Upvotes: 1

Achyut Srivastava

Reputation: 1

I got the same error. Re-running the notebook solved it.

Upvotes: 0
