I am attempting to build a model that accepts multiple video frames as input and produces a label as output (a.k.a. video classification). I have seen code similar to the snippet below in several places for performing this task. I am confused, however, because in the line 'out, hidden = self.lstm(x.unsqueeze(0))', out will ultimately hold only the output for the last frame once the for loop completes, so the x returned at the end of the forward pass would be based solely on the last frame, yes? If so, what makes this architecture different from processing the last frame alone? And is a CNN-LSTM model an appropriate architecture for this type of problem in the first place?
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet101

class CNNLSTM(nn.Module):
    def __init__(self, num_classes=2):
        super(CNNLSTM, self).__init__()
        # Pretrained ResNet-101 backbone with its classifier replaced by a 300-d projection
        self.resnet = resnet101(pretrained=True)
        self.resnet.fc = nn.Sequential(nn.Linear(self.resnet.fc.in_features, 300))
        self.lstm = nn.LSTM(input_size=300, hidden_size=256, num_layers=3)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x_3d):
        hidden = None
        for t in range(x_3d.size(1)):
            with torch.no_grad():
                x = self.resnet(x_3d[:, t])
            # Note: hidden is never passed back into the LSTM here,
            # which is the source of my confusion
            out, hidden = self.lstm(x.unsqueeze(0))
        x = self.fc1(out.squeeze())
        x = F.relu(x)
        x = self.fc2(x)
        return x
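To make the concern concrete: when no hidden state is passed, nn.LSTM initializes (h_0, c_0) to zeros on every call, so each loop iteration is an independent one-step sequence. A minimal sketch demonstrating this (the feature size matches the code above, but the frame count, batch size, and variable names are illustrative, not from the original post):

import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=300, hidden_size=256, num_layers=3)
frames = torch.randn(10, 1, 300)  # 10 frames, batch of 1, 300-d features

# Loop WITHOUT threading the hidden state, as in the code above
for t in range(frames.size(0)):
    out_loop, _ = lstm(frames[t].unsqueeze(0))

# Feeding only the last frame produces the identical output,
# because every call restarted from a zero state
out_last, _ = lstm(frames[-1].unsqueeze(0))
print(torch.allclose(out_loop, out_last))  # True: earlier frames had no effect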
I was able to get some help on another forum. Below is an architecture that resolves the issue I was right to worry about: the hidden state is now threaded through the loop, and the final prediction is made from the last hidden state rather than from the last frame's output alone.
class CNNLSTM(nn.Module):
    def __init__(self, num_classes=2):
        super(CNNLSTM, self).__init__()
        self.resnet = resnet101(pretrained=True)
        self.resnet.fc = nn.Sequential(nn.Linear(self.resnet.fc.in_features, 300))
        self.lstm = nn.LSTM(input_size=300, hidden_size=256, num_layers=3)
        self.fc1 = nn.Linear(256, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x_3d):
        hidden = None
        # Iterate over each frame of a video; x_3d is batch * frames * channels * height * width
        for t in range(x_3d.size(1)):
            with torch.no_grad():
                x = self.resnet(x_3d[:, t])
            # Pass the latent representation of the frame through the LSTM and update the
            # hidden state; the hidden state keeps a record of information from prior frames
            out, hidden = self.lstm(x.unsqueeze(0), hidden)
        # Classify from the last layer's hidden state
        # (hidden is a tuple containing both the hidden and cell states)
        x = self.fc1(hidden[0][-1])
        x = F.relu(x)
        x = self.fc2(x)
        return x
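For reference, a quick smoke test of the fixed model on random data; the batch size, frame count, and image resolution below are arbitrary choices on my part, picked only to match the shapes the forward pass expects:

model = CNNLSTM(num_classes=2)
model.eval()

# Dummy clip: batch of 4 videos, 16 frames each, 3 x 224 x 224 RGB frames
clip = torch.randn(4, 16, 3, 224, 224)

with torch.no_grad():
    logits = model(clip)

print(logits.shape)  # torch.Size([4, 2]): one score per class for each video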