PixelPioneer

Reputation: 4170

PyTorch model summary: forward function has more than one argument

I am using torchsummary:

from torchsummary import summary

I want to pass more than one argument when printing the model summary, but the examples mentioned here: Model summary in pytorch take only one argument, e.g.:

model = Network().to(device)
summary(model,(1,28,28))

The reason is that my forward function takes two inputs:

def forward(self, img1, img2):

How do I pass two arguments here?

Upvotes: 11

Views: 15480

Answers (2)

caide xiao

Reputation: 1

It works if you modify the code on line 10 of torchsummary.py.

Instead of this:

total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))

use this:

total_input_size = abs(np.prod(input_size[0]) * batch_size * 4. / (1024 ** 2.))

My torchsummary version is 1.5.1. I think the package author should update it to support models with two inputs. For example, with summary(model, [(3, 256, 256), (256, 256)]), the patched line uses only (3, 256, 256), the first shape in the list, to compute the total input size; the second input is not counted.
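If you would rather count every input instead of only the first, the size estimate can sum the element counts over all shapes. Below is a minimal sketch of that arithmetic, assuming float32 inputs (4 bytes per element) as in the formula above. `total_input_size_mb` is a hypothetical helper, not part of torchsummary, and it needs Python 3.8+ for `math.prod`.

```python
import math


def total_input_size_mb(input_sizes, batch_size=-1, bytes_per_elem=4):
    """Hypothetical helper: estimated size of ALL inputs in MB.

    Sums the element counts of every input shape, instead of using only
    the first one as the patched line does. Assumes float32 (4 bytes).
    """
    # Accept a single shape tuple as well as a list of shapes
    if isinstance(input_sizes, tuple):
        input_sizes = [input_sizes]
    elements = sum(math.prod(size) for size in input_sizes)
    # abs() as in the original formula; torchsummary's default batch_size is -1
    return abs(elements * batch_size * bytes_per_elem / (1024 ** 2))


print(total_input_size_mb([(3, 256, 256), (256, 256)], batch_size=1))  # 1.0
print(total_input_size_mb([(3, 256, 256)], batch_size=1))  # 0.75 (first input only)
```

The difference between the two printed values is exactly the second input, (256, 256), which the first-element patch leaves out.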

Upvotes: 0

PixelPioneer

Reputation: 4170

Pass a list of input shapes, one per argument of forward and in the same order, as in the example given here: pytorch summary multiple inputs

summary(model, [(1, 16, 16), (1, 28, 28)])
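To see why the list form works, here is a minimal sketch that needs only PyTorch, not torchsummary. `TwoInputNet` is a hypothetical model whose forward takes img1 and img2 with the shapes from the call above, and `run_with_dummy_inputs` is a hypothetical helper that mimics the idea behind multi-input summaries: build one random tensor per shape with a batch dimension prepended, then unpack them into forward as positional arguments.

```python
import torch
import torch.nn as nn


class TwoInputNet(nn.Module):
    """Hypothetical two-input model matching the shapes in the summary call."""

    def __init__(self):
        super().__init__()
        # One small branch per input: (1, 16, 16) -> 8 features, (1, 28, 28) -> 8 features
        self.branch1 = nn.Sequential(nn.Flatten(), nn.Linear(1 * 16 * 16, 8))
        self.branch2 = nn.Sequential(nn.Flatten(), nn.Linear(1 * 28 * 28, 8))
        self.head = nn.Linear(16, 10)

    def forward(self, img1, img2):
        # Concatenate the two branch outputs along the feature dimension
        features = torch.cat([self.branch1(img1), self.branch2(img2)], dim=1)
        return self.head(features)


def run_with_dummy_inputs(model, input_sizes, batch_size=2):
    """Hypothetical helper: one random tensor per shape, unpacked into forward()."""
    dummy = [torch.rand(batch_size, *size) for size in input_sizes]
    return model(*dummy)


model = TwoInputNet()
out = run_with_dummy_inputs(model, [(1, 16, 16), (1, 28, 28)])
print(out.shape)  # torch.Size([2, 10])
```

Because the tensors are unpacked positionally, the order of shapes in the list must match the order of forward's parameters. With torchsummary installed, the same model should then work with `summary(model, [(1, 16, 16), (1, 28, 28)], device="cpu")`; passing `device="cpu"` matters on a machine without a GPU, since torchsummary's default device is "cuda".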

Upvotes: 22
