Reputation: 87
I have an input sequence of shape sequence_len x C x H x W = [10, 3, 16, 16] (assuming batch size = 1): 10 images stacked together in a torch tensor. I want to pass this to an MLP and obtain the next 10 images as predictions. The MLP has one hidden layer with 32 units. If I flatten the input from dimension 1 onward, I get a tensor of shape [10, 768].
My current code looks like this:
import torch.nn as nn

class MLP3(nn.Module):
    def __init__(self, ip_layers):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(ip_layers, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 10)
        )

    def forward(self, x):
        # Forward pass through the stacked layers
        return self.layers(x)
However, I am unable to pass the entire tensor through it, and I am not sure how to obtain 10 outputs from the MLP. Any help will be highly appreciated. TIA
Upvotes: 0
Views: 408
Reputation: 1888
To get 10 outputs at the end, you need a multi-output model: a shared base model followed by one output head per predicted image. Define the base_model in your __init__ function, change your dataset so that it returns all 10 targets, update the forward function to return 10 predictions, and compute a loss for every output. Here is some boilerplate:
import torch

class SDataset(torch.utils.data.Dataset):
    def __init__(self):
        ...

    def __getitem__(self, idx):
        ...
        # Up to here everything is the same as before.
        # The part below is what you have to change:
        # return the next 10 images as targets.
        outimg1 = self.image_seq[idx + 1]
        outimg2 = self.image_seq[idx + 2]
        # ... and so on up to ...
        outimg10 = self.image_seq[idx + 10]
        return X, (outimg1, outimg2, ..., outimg10)
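The dataset boilerplate above can be made concrete as a sliding window over the frame sequence. This is only a minimal sketch under assumptions that are not in the original post: the class name SlidingWindowDataset, the parameters n_in and n_out, and the choice of 10 input frames followed by the next 10 target frames are all hypothetical. It returns the targets as one stacked tensor instead of a tuple; y[k] then plays the role of outimg(k+1).

```python
import torch
from torch.utils.data import Dataset


class SlidingWindowDataset(Dataset):
    """Hypothetical dataset: n_in input frames -> the next n_out frames."""

    def __init__(self, image_seq, n_in=10, n_out=10):
        # image_seq: tensor of shape [T, C, H, W]
        self.image_seq = image_seq
        self.n_in = n_in
        self.n_out = n_out

    def __len__(self):
        # Number of windows that fit entirely inside the sequence.
        return self.image_seq.shape[0] - self.n_in - self.n_out + 1

    def __getitem__(self, idx):
        split = idx + self.n_in
        x = self.image_seq[idx:split]                 # [n_in, C, H, W]
        y = self.image_seq[split:split + self.n_out]  # [n_out, C, H, W]
        return x, y


# Random frames stand in for real data (assumption for illustration).
frames = torch.randn(30, 3, 16, 16)
dataset = SlidingWindowDataset(frames)
x, y = dataset[0]
print(len(dataset), x.shape, y.shape)
```

Wrapping this in a torch.utils.data.DataLoader adds the batch dimension, so each batch would have shape [B, 10, 3, 16, 16] for both inputs and targets.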
The model
class Predictor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.base_model = somebasemodel  # placeholder: any shared feature extractor
        self.out1 = ...  # takes the base_model features and outputs an image
        self.out2 = ...  # same for the second predicted image
        # ... and so on up to ...
        self.out10 = ...

    def forward(self, x):
        # Shared features, then one head per predicted image
        features = self.base_model(x)
        res1 = self.out1(features)
        # ... and so on up to ...
        res10 = self.out10(features)
        return res1, res2, ..., res10
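One possible way to fill in the placeholders for the MLP case from the question is sketched below. Everything specific here is an assumption rather than part of the original answer: the class name MultiHeadMLP, flattening the whole 10-frame sequence into a single vector of 10 * 768 = 7680 values, one nn.Linear head per output image, and keeping the heads in an nn.ModuleList instead of ten separate attributes. The hidden size of 32 follows the question.

```python
import math

import torch
import torch.nn as nn


class MultiHeadMLP(nn.Module):
    """Hypothetical multi-output MLP: shared base plus one head per frame."""

    def __init__(self, n_in=10, n_out=10, frame_shape=(3, 16, 16), hidden=32):
        super().__init__()
        self.frame_shape = frame_shape
        frame_dim = math.prod(frame_shape)  # 3 * 16 * 16 = 768
        # Shared base model: flatten [B, n_in, C, H, W] to [B, n_in * 768].
        self.base_model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(n_in * frame_dim, hidden),
            nn.ReLU(),
        )
        # One linear head per predicted image.
        self.heads = nn.ModuleList(
            [nn.Linear(hidden, frame_dim) for _ in range(n_out)]
        )

    def forward(self, x):
        features = self.base_model(x)
        # Reshape every head output back to an image of shape [B, C, H, W].
        return tuple(
            head(features).view(-1, *self.frame_shape) for head in self.heads
        )


model = MultiHeadMLP()
x = torch.randn(1, 10, 3, 16, 16)  # batch size 1, as in the question
preds = model(x)
print(len(preds), preds[0].shape)
```

Using nn.ModuleList keeps all heads registered as parameters of the model while avoiding ten hand-written attributes; the returned tuple can still be indexed as preds[0] to preds[9].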
The training step
inputs, outputs = batch
preds = model(inputs)
optimizer.zero_grad()
# One loss per output, summed into a single scalar:
# criterion(preds[0], outputs[0]) + ... + criterion(preds[9], outputs[9])
total_loss = sum(criterion(pred, target) for pred, target in zip(preds, outputs))
total_loss.backward()
optimizer.step()
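For reference, the training step can be run end to end with a small stand-in model. This is a sketch under assumptions, not the original author's setup: TinyMultiHead is a hypothetical model that returns a tuple of 10 outputs, the inputs are already flattened to 7680 values, the targets are random tensors, and nn.MSELoss and Adam are arbitrary choices for the criterion and optimizer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TinyMultiHead(nn.Module):
    """Hypothetical stand-in: shared base plus 10 linear heads."""

    def __init__(self, in_dim=7680, hidden=32, out_dim=768, n_out=10):
        super().__init__()
        self.base_model = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.heads = nn.ModuleList(
            [nn.Linear(hidden, out_dim) for _ in range(n_out)]
        )

    def forward(self, x):
        features = self.base_model(x)
        return tuple(head(features) for head in self.heads)


model = TinyMultiHead()
criterion = nn.MSELoss()  # assumed loss for image regression
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Random stand-in batch: 4 flattened input sequences and 10 target images each.
inputs = torch.randn(4, 7680)
outputs = tuple(torch.randn(4, 768) for _ in range(10))


def train_step(inputs, outputs):
    optimizer.zero_grad()
    preds = model(inputs)
    # Sum the per-head losses so a single backward pass covers all heads.
    total_loss = sum(criterion(p, t) for p, t in zip(preds, outputs))
    total_loss.backward()
    optimizer.step()
    return total_loss.item()


losses = [train_step(inputs, outputs) for _ in range(3)]
print(losses)
```

Summing the losses weights every predicted frame equally; a weighted sum would be a reasonable variation if later frames matter less.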
That covers most of it.
Upvotes: 0