Reputation: 4170
I am using torchsummary:
from torchsummary import summary
I want to pass more than one input when printing the model summary, but the examples in Model summary in pytorch take only one input, e.g.:
model = Network().to(device)
summary(model,(1,28,28))
The reason is that my model's forward function takes two inputs:
def forward(self, img1, img2):
How do I pass two arguments here?
Upvotes: 11
Views: 15480
Reputation: 1
It works if you modify the code on line 10 of torchsummary.py.
Instead of this:
total_input_size = abs(np.prod(input_size) * batch_size * 4. / (1024 ** 2.))
use this:
total_input_size = abs(np.prod(input_size[0]) * batch_size * 4. / (1024 ** 2.))
My torchsummary.py version is 1.5.1. I suspect the package author needs to update it to support models with two inputs. For example, with summary(model, [(3, 256, 256), (256, 256)]), the patched line computes the input size from (3, 256, 256) only, the first shape in the list, so the reported input size covers the first input alone.
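If you would rather have the input size account for every input, the formula above can be summed over all shapes instead of using only input_size[0]. The sketch below is a hypothetical helper, input_size_mb (not part of torchsummary), that mirrors the quoted formula under the same assumption of float32 inputs (4 bytes per element); batch_size=-1 stands for "unspecified", so abs() yields a per-sample size.

```python
import math

def input_size_mb(input_size, batch_size=-1):
    """Estimated input size in MB, assuming float32 (4 bytes per element).

    Hypothetical helper: same formula as the torchsummary line quoted above,
    but summing the element counts of every input shape instead of only
    the first one.
    """
    if isinstance(input_size, tuple):  # a single input given as one tuple
        input_size = [input_size]
    n_elements = sum(math.prod(shape) for shape in input_size)
    return abs(n_elements * batch_size * 4.0 / (1024 ** 2.0))

shapes = [(3, 256, 256), (256, 256)]
print(input_size_mb(shapes))     # 1.0  (0.75 MB + 0.25 MB)
print(input_size_mb(shapes[0]))  # 0.75 (only the first input, as with the input_size[0] patch)
```

Summing per-shape products keeps each input's element count separate, which a single np.prod over a ragged list of shapes cannot do.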
Upvotes: 0
Reputation: 4170
You can pass a list of input shapes, one per argument of forward, as in the example given here: pytorch summary multiple inputs
summary(model, [(1, 16, 16), (1, 28, 28)])
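The list form works because, as far as I can tell from torchsummary 1.5.1, summary() creates one random tensor per shape in the list and calls model(*x), so the shapes must be listed in the same order as forward()'s arguments. A minimal sketch of that mechanism follows; TwoInputNet and run_with_shapes are hypothetical names used for illustration, and the code needs only PyTorch, not torchsummary.

```python
import torch
import torch.nn as nn

class TwoInputNet(nn.Module):
    """Hypothetical network whose forward() takes two images."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, kernel_size=3, padding=1)  # branch for img1
        self.conv2 = nn.Conv2d(1, 4, kernel_size=3, padding=1)  # branch for img2
        self.pool = nn.AdaptiveAvgPool2d(1)  # handles different spatial sizes
        self.fc = nn.Linear(8, 10)

    def forward(self, img1, img2):
        a = self.pool(self.conv1(img1)).flatten(1)  # (N, 4)
        b = self.pool(self.conv2(img2)).flatten(1)  # (N, 4)
        return self.fc(torch.cat([a, b], dim=1))    # (N, 10)

def run_with_shapes(model, input_shapes, batch_size=2):
    # One random tensor per shape, in the same order as forward()'s arguments
    xs = [torch.rand(batch_size, *shape) for shape in input_shapes]
    return model(*xs)

model = TwoInputNet()
out = run_with_shapes(model, [(1, 16, 16), (1, 28, 28)])
print(out.shape)  # torch.Size([2, 10])
```

With torchsummary installed, the same list would be passed as summary(model, [(1, 16, 16), (1, 28, 28)]). In 1.5.1, summary() defaults to device="cuda", so on a CPU-only machine you may need to add device="cpu".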
Upvotes: 22