Reputation: 4820
I have the following model:
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyModel(nn.Module):
    # The question's original model. It reproduces the reported error: see the
    # BUG note on self.flatten below.
    def __init__(self, input_size, num_classes):
        # input_size and num_classes are accepted but not used anywhere below.
        super(MyModel, self).__init__()
        # Conv1d: 1 input channel -> 16 output channels, kernel 3, stride 2, no bias.
        self.layer_1 = nn.Conv1d(1, 16, 3, bias=False, stride=2)
        self.activation_1 = F.relu
        self.adap = nn.AdaptiveAvgPool1d(1)  # defined but not used in forward
        # BUG: the trailing comma makes this a 1-tuple containing the Flatten
        # module rather than the module itself, so self.flatten(x) in forward
        # raises "TypeError: 'tuple' object is not callable".
        self.flatten = nn.Flatten (),
        self.layer_2 = torch.nn.Linear(2249, 500)  # not used in forward
        self.activation_2 = F.relu
        self.layer_3 = torch.nn.Linear(500, 2)  # not used in forward
        pass
    def forward(self, x, labels=None):
        # labels is accepted but unused.
        # The leading dimension is hard-coded to 256, so rows are only kept
        # intact when the input holds exactly 256 samples.
        x = x.reshape(256, 1, -1)
        x = self.layer_1(x)
        x = self.activation_1(x)
        x = self.flatten(x)
        return x
When running torchinfo:
# Requires `import torchinfo` (third-party; not imported in the snippet above).
# The input size (256, 4500) matches the batch of 256 hard-coded in forward's reshape.
model = MyModel(input_size=4500, num_classes=2)
torchinfo.summary(model, (256, 4500))
I'm getting the following error:
Input In [101], in MyModel.forward(self, x, labels)
     30 x = self.activation_1(x)
---> 31 x = self.flatten(x)
     32 return x
TypeError: 'tuple' object is not callable
Upvotes: 0
Views: 1089
Reputation: 366
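You have a trailing comma (`,`) at the end of the `flatten` line. It turns `self.flatten` into a one-element tuple instead of an `nn.Flatten` module, and a tuple cannot be called.
Please remove it: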
import torch
import torch.nn as nn
import torch.nn.functional as F
class MyModel(nn.Module):
    """1-D convolutional model over rows of shape ``(batch, length)``.

    ``forward`` currently returns the flattened, ReLU-activated output of
    ``layer_1``: shape ``(batch, 16 * ((length - 3) // 2 + 1))``.

    Args:
        input_size: Expected row length (currently not used to size layers).
        num_classes: Output width of ``layer_3``.
    """

    def __init__(self, input_size, num_classes):
        super().__init__()
        # Conv1d: 1 input channel -> 16 output channels, kernel 3, stride 2, no bias.
        self.layer_1 = nn.Conv1d(1, 16, 3, bias=False, stride=2)
        self.activation_1 = F.relu
        self.adap = nn.AdaptiveAvgPool1d(1)  # defined but not used in forward
        # No trailing comma here: ``nn.Flatten(),`` would store a 1-tuple, and
        # calling it raises "TypeError: 'tuple' object is not callable".
        self.flatten = nn.Flatten()
        # NOTE(review): 2249 equals layer_1's output length for input_size=4500,
        # but self.flatten yields 16 * 2249 features; adjust before wiring
        # layer_2 into forward.
        self.layer_2 = nn.Linear(2249, 500)
        self.activation_2 = F.relu
        # Use the requested number of classes instead of a hard-coded 2.
        self.layer_3 = nn.Linear(500, num_classes)

    def forward(self, x, labels=None):
        """Return flattened conv features for ``x`` (``labels`` is unused)."""
        # Take the batch size from the input rather than hard-coding 256, so any
        # batch size works and samples are never mixed across rows.
        x = x.reshape(x.shape[0], 1, -1)
        x = self.activation_1(self.layer_1(x))
        return self.flatten(x)
Upvotes: 2