How do I print the model summary in PyTorch?

How do I print the summary of a model in PyTorch like what model.summary() does in Keras:

Model Summary:
Layer (type)                     Output Shape          Param #     Connected to                     
input_1 (InputLayer)             (None, 1, 15, 27)     0                                            
convolution2d_1 (Convolution2D)  (None, 8, 15, 27)     872         input_1[0][0]                    
maxpooling2d_1 (MaxPooling2D)    (None, 8, 7, 27)      0           convolution2d_1[0][0]            
flatten_1 (Flatten)              (None, 1512)          0           maxpooling2d_1[0][0]             
dense_1 (Dense)                  (None, 1)             1513        flatten_1[0][0]                  
Total params: 2,385
Trainable params: 2,385
Non-trainable params: 0

While you will not get as detailed information about the model as in Keras' model.summary, simply printing the model will give you some idea about the different layers involved and their specifications.

For instance:

from torchvision import models
model = models.vgg16()
print(model)

The output in this case would be something like the following:

  (features): Sequential (
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU (inplace)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU (inplace)
    (4): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU (inplace)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU (inplace)
    (9): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU (inplace)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU (inplace)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU (inplace)
    (16): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (18): ReLU (inplace)
    (19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (20): ReLU (inplace)
    (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (22): ReLU (inplace)
    (23): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
    (24): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (25): ReLU (inplace)
    (26): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (27): ReLU (inplace)
    (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (29): ReLU (inplace)
    (30): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1))
  (classifier): Sequential (
    (0): Dropout (p = 0.5)
    (1): Linear (25088 -> 4096)
    (2): ReLU (inplace)
    (3): Dropout (p = 0.5)
    (4): Linear (4096 -> 4096)
    (5): ReLU (inplace)
    (6): Linear (4096 -> 1000)

Now you could, as mentioned by Kashyap, use the state_dict method to get the weights of the different layers. But this listing of the layers would perhaps give more direction in creating a helper function to get that Keras-like model summary!
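A minimal sketch of the state_dict route: state_dict() maps each fully qualified parameter (or buffer) name to its tensor, so printing the shapes gives a rough per-layer weight listing. The small nn.Sequential model below is only an illustrative stand-in for your own module.

```python
import torch.nn as nn

# Illustrative model; substitute your own nn.Module
model = nn.Sequential(
    nn.Linear(1024, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)

# state_dict() maps names such as "0.weight" to tensors (parameters and buffers)
shapes = {name: tuple(t.shape) for name, t in model.state_dict().items()}
for name, shape in shapes.items():
    print(f"{name:<12} {shape}")
```

The ReLU at index 1 has no parameters, so it does not appear in the listing.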

I prefer this simple snippet instead -

net = model
print(net)
params = [param.shape for param in net.parameters()]
total_params = 0

# Print Model Summary
# Assumes every layer is a Linear with a bias, so parameters come in (weight, bias) pairs
for i in range(0, len(params), 2):
    weight, bias = params[i], params[i + 1]
    param = weight[0] * weight[1] + bias[0]
    total_params += param
    print(f"Layer {i // 2 + 1} ->  Weights: {weight[0]} x {weight[1]}\tBias: {bias[0]}\tParameters: {param}")
print("\nTotal Parameters: ", total_params)

This prints everything I need -

  (hLayer1): Linear(in_features=1024, out_features=256, bias=True)
  (hLayer2): Linear(in_features=256, out_features=128, bias=True)
  (hLayer3): Linear(in_features=128, out_features=64, bias=True)
  (outLayer): Linear(in_features=64, out_features=10, bias=True)
Layer 1 ->  Weights: 256 x 1024     Bias:  256  Parameters:  262400
Layer 2 ->  Weights: 128 x 256      Bias:  128  Parameters:  32896
Layer 3 ->  Weights: 64 x 128       Bias:  64   Parameters:  8256
Layer 4 ->  Weights: 10 x 64        Bias:  10   Parameters:  650

Total Parameters:  304202

For a complex model, or for more in-depth statistics about the model:

Install torchstat

pip install torchstat

Get stats

from torchstat import stat
import torchvision.models as models

model = models.vgg19()
stat(model, (3, 224, 224))

Output -

        module name  input shape output shape       params memory(MB)              MAdd             Flops   MemRead(B)  MemWrite(B) duration[%]    MemR+W(B)
0        features.0    3 224 224   64 224 224       1792.0      12.25     173,408,256.0      89,915,392.0     609280.0   12845056.0      21.56%   13454336.0
1        features.1   64 224 224   64 224 224          0.0      12.25       3,211,264.0       3,211,264.0   12845056.0   12845056.0       0.92%   25690112.0
2        features.2   64 224 224   64 224 224      36928.0      12.25   3,699,376,128.0   1,852,899,328.0   12992768.0   12845056.0       4.74%   25837824.0
3        features.3   64 224 224   64 224 224          0.0      12.25       3,211,264.0       3,211,264.0   12845056.0   12845056.0       0.92%   25690112.0
4        features.4   64 224 224   64 112 112          0.0       3.06       2,408,448.0       3,211,264.0   12845056.0    3211264.0       1.22%   16056320.0
5        features.5   64 112 112  128 112 112      73856.0       6.12   1,849,688,064.0     926,449,664.0    3506688.0    6422528.0       4.71%    9929216.0
6        features.6  128 112 112  128 112 112          0.0       6.12       1,605,632.0       1,605,632.0    6422528.0    6422528.0       0.94%   12845056.0
7        features.7  128 112 112  128 112 112     147584.0       6.12   3,699,376,128.0   1,851,293,696.0    7012864.0    6422528.0       4.36%   13435392.0
8        features.8  128 112 112  128 112 112          0.0       6.12       1,605,632.0       1,605,632.0    6422528.0    6422528.0       0.91%   12845056.0
9        features.9  128 112 112  128  56  56          0.0       1.53       1,204,224.0       1,605,632.0    6422528.0    1605632.0       1.51%    8028160.0
10      features.10  128  56  56  256  56  56     295168.0       3.06   1,849,688,064.0     925,646,848.0    2786304.0    3211264.0       3.57%    5997568.0
11      features.11  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       0.90%    6422528.0
12      features.12  256  56  56  256  56  56     590080.0       3.06   3,699,376,128.0   1,850,490,880.0    5571584.0    3211264.0       4.30%    8782848.0
13      features.13  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       0.90%    6422528.0
14      features.14  256  56  56  256  56  56     590080.0       3.06   3,699,376,128.0   1,850,490,880.0    5571584.0    3211264.0       4.38%    8782848.0
15      features.15  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       0.94%    6422528.0
16      features.16  256  56  56  256  56  56     590080.0       3.06   3,699,376,128.0   1,850,490,880.0    5571584.0    3211264.0       4.33%    8782848.0
17      features.17  256  56  56  256  56  56          0.0       3.06         802,816.0         802,816.0    3211264.0    3211264.0       0.90%    6422528.0
18      features.18  256  56  56  256  28  28          0.0       0.77         602,112.0         802,816.0    3211264.0     802816.0       1.44%    4014080.0
19      features.19  256  28  28  512  28  28    1180160.0       1.53   1,849,688,064.0     925,245,440.0    5523456.0    1605632.0       3.60%    7129088.0
20      features.20  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       0.92%    3211264.0
21      features.21  512  28  28  512  28  28    2359808.0       1.53   3,699,376,128.0   1,850,089,472.0   11044864.0    1605632.0       4.45%   12650496.0
22      features.22  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       0.94%    3211264.0
23      features.23  512  28  28  512  28  28    2359808.0       1.53   3,699,376,128.0   1,850,089,472.0   11044864.0    1605632.0       4.39%   12650496.0
24      features.24  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       0.90%    3211264.0
25      features.25  512  28  28  512  28  28    2359808.0       1.53   3,699,376,128.0   1,850,089,472.0   11044864.0    1605632.0       4.34%   12650496.0
26      features.26  512  28  28  512  28  28          0.0       1.53         401,408.0         401,408.0    1605632.0    1605632.0       0.90%    3211264.0
27      features.27  512  28  28  512  14  14          0.0       0.38         301,056.0         401,408.0    1605632.0     401408.0       0.96%    2007040.0
28      features.28  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       0.99%   10242048.0
29      features.29  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.00%     802816.0
30      features.30  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       0.11%   10242048.0
31      features.31  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.00%     802816.0
32      features.32  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       0.11%   10242048.0
33      features.33  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.00%     802816.0
34      features.34  512  14  14  512  14  14    2359808.0       0.38     924,844,032.0     462,522,368.0    9840640.0     401408.0       0.11%   10242048.0
35      features.35  512  14  14  512  14  14          0.0       0.38         100,352.0         100,352.0     401408.0     401408.0       0.00%     802816.0
36      features.36  512  14  14  512   7   7          0.0       0.10          75,264.0         100,352.0     401408.0     100352.0       0.01%     501760.0
37          avgpool  512   7   7  512   7   7          0.0       0.10               0.0               0.0          0.0          0.0       0.49%          0.0
38     classifier.0        25088         4096  102764544.0       0.02     205,516,800.0     102,760,448.0  411158528.0      16384.0      11.27%  411174912.0
39     classifier.1         4096         4096          0.0       0.02           4,096.0           4,096.0      16384.0      16384.0       0.00%      32768.0
40     classifier.2         4096         4096          0.0       0.02               0.0               0.0          0.0          0.0       0.01%          0.0
41     classifier.3         4096         4096   16781312.0       0.02      33,550,336.0      16,777,216.0   67141632.0      16384.0       1.08%   67158016.0
42     classifier.4         4096         4096          0.0       0.02           4,096.0           4,096.0      16384.0      16384.0       0.00%      32768.0
43     classifier.5         4096         4096          0.0       0.02               0.0               0.0          0.0          0.0       0.00%          0.0
44     classifier.6         4096         1000    4097000.0       0.00       8,191,000.0       4,096,000.0   16404384.0       4000.0       0.93%   16408384.0
total                                          143667240.0     119.34  39,283,567,128.0  19,667,896,320.0   16404384.0       4000.0     100.00%  825282624.0
Total params: 143,667,240
Total memory: 119.34 MB
Total MAdd: 39.28 GMAdd
Total Flops: 19.67 GFlops
Total MemR+W: 787.05 MB
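As a sanity check on the table, the params, memory(MB) and MAdd columns for the first convolutions can be reproduced by hand. The formulas below are inferred from the numbers in the table (they match features.0 and features.2), not taken from torchstat's source, so treat them as an assumption: parameters = (C_in·k·k + 1)·C_out, memory = float32 output size in MB, and MAdd = 2·C_in·k·k per output element. conv2d_stats is a hypothetical helper.

```python
def conv2d_stats(c_in, c_out, k, h_out, w_out):
    """Hand-computed stats for a Conv2d with bias and float32 outputs (assumed formulas)."""
    params = (c_in * k * k + 1) * c_out      # weights + one bias per output channel
    outputs = c_out * h_out * w_out          # number of output elements
    memory_mb = outputs * 4 / 1024 ** 2      # 4 bytes per float32 output element
    madd = 2 * c_in * k * k * outputs        # one multiply and one add per weight, per output
    return params, memory_mb, madd

# features.0 of VGG19: Conv2d(3, 64, kernel_size=3, padding=1) on a 224x224 input
print(conv2d_stats(3, 64, 3, 224, 224))   # (1792, 12.25, 173408256)
# features.2: Conv2d(64, 64, kernel_size=3, padding=1)
print(conv2d_stats(64, 64, 3, 224, 224))  # (36928, 12.25, 3699376128)
```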

The torchinfo package (formerly torch-summary) produces output analogous to Keras [1] for a given input shape [2]:

from torchinfo import summary

model = ConvNet()  # ConvNet is your own nn.Module
batch_size = 16
summary(model, input_size=(batch_size, 1, 28, 28))

Layer (type:depth-idx)                   Output Shape              Param #
├─Conv2d (conv1): 1-1                    [16, 10, 24, 24]          260
├─Conv2d (conv2): 1-2                    [16, 20, 8, 8]            5,020
├─Dropout2d (conv2_drop): 1-3            [16, 20, 8, 8]            --
├─Linear (fc1): 1-4                      [16, 50]                  16,050
├─Linear (fc2): 1-5                      [16, 10]                  510
Total params: 21,840
Trainable params: 21,840
Non-trainable params: 0
Total mult-adds (M): 7.69
Input size (MB): 0.05
Forward/backward pass size (MB): 0.91
Params size (MB): 0.09
Estimated Total Size (MB): 1.05
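The snippet assumes a ConvNet class that is not shown. A definition consistent with the layer names and parameter counts above is the classic MNIST network from the PyTorch examples; the sketch below is a hypothetical reconstruction under that assumption, including its forward pass.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvNet(nn.Module):
    """Hypothetical reconstruction of the ConvNet used in the torchinfo example."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)   # 1*10*5*5 + 10 = 260 params
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # 10*20*5*5 + 20 = 5,020 params
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)                  # 320*50 + 50 = 16,050 params
        self.fc2 = nn.Linear(50, 10)                   # 50*10 + 10 = 510 params

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))                   # -> [N, 10, 12, 12]
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # -> [N, 20, 4, 4]
        x = x.flatten(1)                                             # -> [N, 320]
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return F.log_softmax(self.fc2(x), dim=1)

model = ConvNet()
total = sum(p.numel() for p in model.parameters())
out_shape = model(torch.zeros(16, 1, 28, 28)).shape
print(total)      # 21840
print(out_shape)  # torch.Size([16, 10])
```

With this class defined, the summary(model, input_size=(batch_size, 1, 28, 28)) call above should report the same total of 21,840 parameters.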


  1. Torchinfo provides information complementary to what is provided by print(your_model) in PyTorch, similar to TensorFlow's model.summary().

  2. Unlike Keras, PyTorch has a dynamic computational graph that can adapt to any compatible input shape across multiple calls, e.g. any sufficiently large image size for a fully convolutional network.

    As such, it cannot present an inherent set of input/output shapes for each layer, because these are input-dependent; this is why the package above requires you to specify the input dimensions.
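A small illustration of the point above: the same fully convolutional module (a toy example, not from any package) produces different output shapes for different input sizes, so there is no single shape to report without a concrete input.

```python
import torch
import torch.nn as nn

# Toy fully convolutional network: no Linear layer ties it to one input size
fcn = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),  # keeps height and width
    nn.ReLU(),
    nn.MaxPool2d(2),                            # halves height and width
    nn.Conv2d(8, 4, kernel_size=1),             # keeps height and width
)

with torch.no_grad():
    shapes = {size: tuple(fcn(torch.zeros(1, 3, size, size)).shape) for size in (32, 64, 100)}

for size, shape in shapes.items():
    print(size, shape)
# 32 (1, 4, 16, 16)
# 64 (1, 4, 32, 32)
# 100 (1, 4, 50, 50)
```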

In order to use torchsummary type:

from torchsummary import summary

Install it first if you don't have it.

pip install torchsummary 

Then you can try it. Note that with the default arguments it only works if the model is on the GPU (e.g. after alexnet.cuda()), because summary uses device='cuda' by default:

from torchsummary import summary
import torchvision.models as models
alexnet = models.alexnet()  # randomly initialized weights
alexnet.cuda()              # summary() defaults to device='cuda'
summary(alexnet, (3, 224, 224))

summary takes the input size without the batch dimension. batch_size defaults to -1, meaning any batch size, and appears as -1 in the reported output shapes.

Calling summary(alexnet, (3, 224, 224), 32) instead reports shapes for a batch size of 32.


Help on function summary in module torchsummary.torchsummary:

summary(model, input_size, batch_size=-1, device='cuda')

        Layer (type)               Output Shape         Param #
            Conv2d-1           [32, 64, 55, 55]          23,296
              ReLU-2           [32, 64, 55, 55]               0
         MaxPool2d-3           [32, 64, 27, 27]               0
            Conv2d-4          [32, 192, 27, 27]         307,392
              ReLU-5          [32, 192, 27, 27]               0
         MaxPool2d-6          [32, 192, 13, 13]               0
            Conv2d-7          [32, 384, 13, 13]         663,936
              ReLU-8          [32, 384, 13, 13]               0
            Conv2d-9          [32, 256, 13, 13]         884,992
             ReLU-10          [32, 256, 13, 13]               0
           Conv2d-11          [32, 256, 13, 13]         590,080
             ReLU-12          [32, 256, 13, 13]               0
        MaxPool2d-13            [32, 256, 6, 6]               0
AdaptiveAvgPool2d-14            [32, 256, 6, 6]               0
          Dropout-15                 [32, 9216]               0
           Linear-16                 [32, 4096]      37,752,832
             ReLU-17                 [32, 4096]               0
          Dropout-18                 [32, 4096]               0
           Linear-19                 [32, 4096]      16,781,312
             ReLU-20                 [32, 4096]               0
           Linear-21                 [32, 1000]       4,097,000
Total params: 61,100,840
Trainable params: 61,100,840
Non-trainable params: 0
Input size (MB): 18.38
Forward/backward pass size (MB): 268.12
Params size (MB): 233.08
Estimated Total Size (MB): 519.58

For comparison, print(alexnet) shows only the layer definitions:

  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
    (2): ReLU(inplace)
    (3): Dropout(p=0.5)
    (4): Linear(in_features=4096, out_features=4096, bias=True)
    (5): ReLU(inplace)
    (6): Linear(in_features=4096, out_features=1000, bias=True)

For a model on the CPU, summary(my_model, (3, 224, 224), device='cpu') avoids the CUDA requirement.

Yes, you can get a Keras-style representation using the pytorch-summary package (imported as torchsummary).

Example for VGG16:

from torchvision import models
from torchsummary import summary

vgg = models.vgg16()
summary(vgg, (3, 224, 224))

        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 224, 224]           1,792
              ReLU-2         [-1, 64, 224, 224]               0
            Conv2d-3         [-1, 64, 224, 224]          36,928
              ReLU-4         [-1, 64, 224, 224]               0
         MaxPool2d-5         [-1, 64, 112, 112]               0
            Conv2d-6        [-1, 128, 112, 112]          73,856
              ReLU-7        [-1, 128, 112, 112]               0
            Conv2d-8        [-1, 128, 112, 112]         147,584
              ReLU-9        [-1, 128, 112, 112]               0
        MaxPool2d-10          [-1, 128, 56, 56]               0
           Conv2d-11          [-1, 256, 56, 56]         295,168
             ReLU-12          [-1, 256, 56, 56]               0
           Conv2d-13          [-1, 256, 56, 56]         590,080
             ReLU-14          [-1, 256, 56, 56]               0
           Conv2d-15          [-1, 256, 56, 56]         590,080
             ReLU-16          [-1, 256, 56, 56]               0
        MaxPool2d-17          [-1, 256, 28, 28]               0
           Conv2d-18          [-1, 512, 28, 28]       1,180,160
             ReLU-19          [-1, 512, 28, 28]               0
           Conv2d-20          [-1, 512, 28, 28]       2,359,808
             ReLU-21          [-1, 512, 28, 28]               0
           Conv2d-22          [-1, 512, 28, 28]       2,359,808
             ReLU-23          [-1, 512, 28, 28]               0
        MaxPool2d-24          [-1, 512, 14, 14]               0
           Conv2d-25          [-1, 512, 14, 14]       2,359,808
             ReLU-26          [-1, 512, 14, 14]               0
           Conv2d-27          [-1, 512, 14, 14]       2,359,808
             ReLU-28          [-1, 512, 14, 14]               0
           Conv2d-29          [-1, 512, 14, 14]       2,359,808
             ReLU-30          [-1, 512, 14, 14]               0
        MaxPool2d-31            [-1, 512, 7, 7]               0
           Linear-32                 [-1, 4096]     102,764,544
             ReLU-33                 [-1, 4096]               0
          Dropout-34                 [-1, 4096]               0
           Linear-35                 [-1, 4096]      16,781,312
             ReLU-36                 [-1, 4096]               0
          Dropout-37                 [-1, 4096]               0
           Linear-38                 [-1, 1000]       4,097,000
Total params: 138,357,544
Trainable params: 138,357,544
Non-trainable params: 0
Input size (MB): 0.57
Forward/backward pass size (MB): 218.59
Params size (MB): 527.79
Estimated Total Size (MB): 746.96

Keras-like model summary using torchsummary:

from torchsummary import summary
summary(model, input_size=(3, 224, 224))
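Packages like this obtain output shapes by running a forward pass on a dummy input and recording what each layer returns. A minimal sketch of that idea using forward hooks (register_forward_hook) follows; keras_like_summary is a hypothetical helper written for illustration, not torchsummary's code. It assumes every leaf layer returns a single tensor and runs on the CPU.

```python
import torch
import torch.nn as nn

def keras_like_summary(model, input_size, batch_size=2):
    """Hypothetical helper: print output shape and parameter count per leaf module."""
    rows, hooks = [], []

    def hook(module, inputs, output):
        # Count only this module's own parameters, not those of children
        n_params = sum(p.numel() for p in module.parameters(recurse=False))
        rows.append((f"{type(module).__name__}-{len(rows) + 1}", tuple(output.shape), n_params))

    # Register hooks on leaf modules only (modules with no children)
    for m in model.modules():
        if not list(m.children()):
            hooks.append(m.register_forward_hook(hook))

    try:
        # Default batch of 2: BatchNorm in training mode needs more than one value per channel
        with torch.no_grad():
            model(torch.zeros(batch_size, *input_size))
    finally:
        for h in hooks:
            h.remove()

    print(f"{'Layer (type)':<20}{'Output Shape':<25}{'Param #':>12}")
    for name, shape, n in rows:
        print(f"{name:<20}{str(list(shape)):<25}{n:>12,}")
    print(f"Total params: {sum(p.numel() for p in model.parameters()):,}")
    return rows

net = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 26 * 26, 10))
rows = keras_like_summary(net, (1, 28, 28))
```

Removing the hooks in a finally block keeps the model clean even if the forward pass fails on a wrong input size.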

You can use torchsummary and specify the device explicitly:

import torch
from torchsummary import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

If you have created a network (Network below stands for your own nn.Module) for the MNIST dataset, the following commands will move it to that device and show its summary:

model = Network().to(device)
summary(model, (1, 28, 28), device=device.type)

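Simply print the model after creating an instance of the model class:

import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, output_dim):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.RNN(embedding_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, text):
        embedded = self.embedding(text)       # [seq_len, batch, embedding_dim]
        output, hidden = self.rnn(embedded)   # hidden: [1, batch, hidden_dim]
        return self.fc(hidden.squeeze(0))

# input_dim, embedding_dim, hidden_dim and output_dim are your own hyperparameters
model = RNN(input_dim, embedding_dim, hidden_dim, output_dim)
print(model)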

Simplest to remember (not as pretty as Keras):

print(model)

This also works:

repr(model)

If you just want the number of parameters:

sum([param.nelement() for param in model.parameters()])

From: Is there similar pytorch function as model.summary() as keras? (forum.PyTorch.org)
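To mirror Keras' Total / Trainable / Non-trainable breakdown, you can split the parameter count on requires_grad. A small sketch; the count_params name and the frozen first layer are illustrative choices, not part of any library:

```python
import torch.nn as nn

def count_params(model):
    """Return (total, trainable, non_trainable) parameter counts."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return total, trainable, total - trainable

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
for p in model[0].parameters():  # freeze the first layer for illustration
    p.requires_grad = False

total, trainable, frozen = count_params(model)
print(f"Total params: {total:,}")          # Total params: 61
print(f"Trainable params: {trainable:,}")  # Trainable params: 6
print(f"Non-trainable params: {frozen:,}") # Non-trainable params: 55
```

Buffers such as BatchNorm running statistics are not parameters in PyTorch, so they are not included in these counts.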

This will show a model's weights and parameters (but not output shape).

from torch.nn.modules.module import _addindent
import torch

def torch_summarize(model, show_weights=True, show_parameters=True):
    """Summarizes torch model by showing trainable parameters and weights."""
    tmpstr = model.__class__.__name__ + ' (\n'
    for key, module in model._modules.items():
        # if it contains layers, call it recursively to get params and weights
        if list(module.children()):
            modstr = torch_summarize(module, show_weights, show_parameters)
        else:
            modstr = module.__repr__()
        modstr = _addindent(modstr, 2)

        params = sum(p.numel() for p in module.parameters())
        weights = tuple(tuple(p.size()) for p in module.parameters())

        tmpstr += '  (' + key + '): ' + modstr
        if show_weights:
            tmpstr += ', weights={}'.format(weights)
        if show_parameters:
            tmpstr += ', parameters={}'.format(params)
        tmpstr += '\n'

    tmpstr = tmpstr + ')'
    return tmpstr

# Test
import torchvision.models as models
model = models.alexnet()
print(torch_summarize(model))

# Output (from an older PyTorch version; newer versions format layers slightly differently and also list avgpool)
# AlexNet (
#   (features): Sequential (
#     (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2)), weights=((64, 3, 11, 11), (64,)), parameters=23296
#     (1): ReLU (inplace), weights=(), parameters=0
#     (2): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)), weights=(), parameters=0
#     (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)), weights=((192, 64, 5, 5), (192,)), parameters=307392
#     (4): ReLU (inplace), weights=(), parameters=0
#     (5): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)), weights=(), parameters=0
#     (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), weights=((384, 192, 3, 3), (384,)), parameters=663936
#     (7): ReLU (inplace), weights=(), parameters=0
#     (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), weights=((256, 384, 3, 3), (256,)), parameters=884992
#     (9): ReLU (inplace), weights=(), parameters=0
#     (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)), weights=((256, 256, 3, 3), (256,)), parameters=590080
#     (11): ReLU (inplace), weights=(), parameters=0
#     (12): MaxPool2d (size=(3, 3), stride=(2, 2), dilation=(1, 1)), weights=(), parameters=0
#   ), weights=((64, 3, 11, 11), (64,), (192, 64, 5, 5), (192,), (384, 192, 3, 3), (384,), (256, 384, 3, 3), (256,), (256, 256, 3, 3), (256,)), parameters=2469696
#   (classifier): Sequential (
#     (0): Dropout (p = 0.5), weights=(), parameters=0
#     (1): Linear (9216 -> 4096), weights=((4096, 9216), (4096,)), parameters=37752832
#     (2): ReLU (inplace), weights=(), parameters=0
#     (3): Dropout (p = 0.5), weights=(), parameters=0
#     (4): Linear (4096 -> 4096), weights=((4096, 4096), (4096,)), parameters=16781312
#     (5): ReLU (inplace), weights=(), parameters=0
#     (6): Linear (4096 -> 1000), weights=((1000, 4096), (1000,)), parameters=4097000
#   ), weights=((4096, 9216), (4096,), (4096, 4096), (4096,), (1000, 4096), (1000,)), parameters=58631144
# )

Edit: isaykatsman has a PyTorch PR to add a model.summary() that is exactly like Keras: https://github.com/pytorch/pytorch/pull/3043/files

