Ling

Reputation: 359

How to train Pytorch CNN with two or more inputs

I have a big image, and multiple events in the image can affect the classification. I am thinking of splitting the big image into small chunks, extracting features from each chunk, and concatenating the outputs for the prediction.

My code looks like this:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

train_load_1 = DataLoader(dataset=train_dataset_1, batch_size=100, shuffle=False)
train_load_2 = DataLoader(dataset=train_dataset_2, batch_size=100, shuffle=False)
train_load_3 = DataLoader(dataset=train_dataset_3, batch_size=100, shuffle=False)

# shuffle=False here as well, otherwise the three test loaders fall out of step
test_load_1 = DataLoader(dataset=test_dataset_1, batch_size=100, shuffle=False)
test_load_2 = DataLoader(dataset=test_dataset_2, batch_size=100, shuffle=False)
test_load_3 = DataLoader(dataset=test_dataset_3, batch_size=100, shuffle=False)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv = nn.Conv2d( ... )  # set up your layer here
        self.fc1 = nn.Linear( ... )  # set up first FC layer
        self.fc2 = nn.Linear( ... )  # set up the other FC layer

    def forward(self, x1, x2, x3):
        # the same conv layer (shared weights) extracts features from each chunk
        o1 = self.conv(x1)
        o2 = self.conv(x2)
        o3 = self.conv(x3)
        # flatten each feature map to (batch, features) and concatenate them
        combined = torch.cat((o1.view(x1.size(0), -1),
                              o2.view(x2.size(0), -1),
                              o3.view(x3.size(0), -1)), dim=1)
        out = self.fc1(combined)
        out = self.fc2(out)
        return F.softmax(out, dim=1)

model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01)

for epoch in range(epochs):
    model.train()

    for batch_idx, (inputs, labels) in enumerate(train_load_1):
        # I am stuck here: how do I enumerate all three train loaders so that
        # input_1, input_2 and input_3 are passed into the model and share the same label?
        # Note that I set shuffle=False in the train loaders to make sure that
        # train_load_1, train_load_2 and train_load_3 return the same labels.
        pass
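The following is a minimal runnable sketch of this setup, under stated assumptions. The layer sizes, the 1x16x16 chunk shape, the two classes and the synthetic `TensorDataset`s are hypothetical stand-ins for the elided `...` arguments and the real data. The sketch iterates the three loaders in lockstep with Python's built-in `zip`, which is one straightforward way to get one batch from each loader per step. It also checks that the labels agree.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical data: 8 images, each already split into 3 chunks of 1x16x16,
# stored as three index-aligned datasets that share the same labels.
num_samples, num_classes = 8, 2
labels = torch.randint(0, num_classes, (num_samples,))
chunks = [torch.randn(num_samples, 1, 16, 16) for _ in range(3)]
train_dataset_1, train_dataset_2, train_dataset_3 = [
    TensorDataset(c, labels) for c in chunks
]

train_load_1 = DataLoader(train_dataset_1, batch_size=4, shuffle=False)
train_load_2 = DataLoader(train_dataset_2, batch_size=4, shuffle=False)
train_load_3 = DataLoader(train_dataset_3, batch_size=4, shuffle=False)


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Assumed sizes: 1 input channel, 4 feature maps, 3x3 kernel, padding 1
        self.conv = nn.Conv2d(1, 4, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(3 * 4 * 16 * 16, 32)
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x1, x2, x3):
        # The same conv layer (shared weights) is applied to every chunk
        feats = [F.relu(self.conv(x)).flatten(start_dim=1) for x in (x1, x2, x3)]
        combined = torch.cat(feats, dim=1)
        out = F.relu(self.fc1(combined))
        return self.fc2(out)  # raw logits; CrossEntropyLoss applies log-softmax


model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

model.train()
for epoch in range(2):
    # zip yields one batch from each loader per step; with shuffle=False
    # the batches come from the same sample indices, so the labels line up
    for (x1, y1), (x2, y2), (x3, y3) in zip(train_load_1, train_load_2, train_load_3):
        assert torch.equal(y1, y2) and torch.equal(y1, y3)
        optimizer.zero_grad()
        loss = loss_fn(model(x1, x2, x3), y1)
        loss.backward()
        optimizer.step()

final_loss = loss.item()
```

`zip` stops at the shortest loader, so the three datasets should have the same length. The alignment only holds while every loader uses `shuffle=False`. To shuffle, the three datasets can be combined into one dataset that is served by a single loader. The sketch returns raw logits instead of applying `F.softmax`, because `nn.CrossEntropyLoss` already applies log-softmax internally.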

Thank you for your help!

Upvotes: 2

Views: 3509

Answers (2)

Srujan2k21

Reputation: 306

To get the image parts into that format, you can loop over the images and append each image's parts, together with its label, to a list or a NumPy array.

def make_parts(full_image):
    # some code
    # returns a list of image parts after converting them into torch tensors
    return [TorchTensor_of_part1, TorchTensor_of_part2, TorchTensor_of_part3]

list_of_parts_and_labels = []
for image,label in zip(full_img_data, labels):
    parts = make_parts(image)
    list_of_parts_and_labels.append([parts, torch.tensor(label)])
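The `# some code` part depends on how you want to split the image. The example below is hypothetical. It assumes each image is a C x H x W tensor whose width is divisible by the number of parts, and uses `torch.chunk` to cut the image into equal vertical strips. The random `full_img_data` and `labels` stand in for real data.

```python
import torch


def make_parts(full_image, num_parts=3):
    # Hypothetical helper: split a C x H x W tensor into equal strips along the width
    if full_image.shape[-1] % num_parts != 0:
        raise ValueError("image width must be divisible by num_parts")
    return list(torch.chunk(full_image, num_parts, dim=-1))


# Hypothetical data: 5 RGB images of 32x48 pixels with integer labels
full_img_data = torch.randn(5, 3, 32, 48)
labels = [0, 1, 0, 1, 1]

list_of_parts_and_labels = []
for image, label in zip(full_img_data, labels):
    parts = make_parts(image)
    list_of_parts_and_labels.append([parts, torch.tensor(label)])
```

Every part must have the same shape across images, so that the DataLoader can stack the parts into batches.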

To load them into a DataLoader, assuming you already have your image parts and labels in the format above:

train_loader = torch.utils.data.DataLoader(list_of_parts_and_labels,
               shuffle = True, batch_size = BATCH_SIZE)

Then use it like this:

for parts, label in train_loader:
    optimizer.zero_grad()
    out = model(*parts)  # unpacks the list of parts into forward(x1, x2, x3)
    loss = loss_fn(out, label)
    loss.backward()
    optimizer.step()
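To see what the DataLoader actually yields for this format, here is a small self-contained sketch. It uses random tensors and a hypothetical `ThreePartModel`; neither comes from the answer. PyTorch's default collation turns the list of three parts in each sample into a list of three batched tensors, which is why `model(*parts)` works.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

torch.manual_seed(0)
BATCH_SIZE = 4

# Hypothetical samples in the [parts, label] format: 3 parts of 3x8x8 each
list_of_parts_and_labels = [
    [[torch.randn(3, 8, 8) for _ in range(3)], torch.tensor(i % 2)]
    for i in range(10)
]
train_loader = DataLoader(list_of_parts_and_labels, shuffle=True, batch_size=BATCH_SIZE)


class ThreePartModel(nn.Module):
    # Hypothetical model: flatten each part, concatenate, one linear layer
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3 * 3 * 8 * 8, 2)

    def forward(self, p1, p2, p3):
        x = torch.cat([p.flatten(start_dim=1) for p in (p1, p2, p3)], dim=1)
        return self.fc(x)


model = ThreePartModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

batch_shapes = []
for parts, label in train_loader:
    # each sample's list of 3 tensors is collated into a list of 3 batched
    # tensors, each of shape (batch, 3, 8, 8); labels become shape (batch,)
    batch_shapes.append([tuple(p.shape) for p in parts])
    optimizer.zero_grad()
    out = model(*parts)
    loss = loss_fn(out, label)
    loss.backward()
    optimizer.step()
```

Shuffling is safe here because each sample keeps its three parts and its label together.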

Upvotes: 0

Srujan2k21

Reputation: 306

Instead of using three separate DataLoader objects, you can use a single one in which each data point contains the three parts of the image together with its label.

Like this:

dataLoader = [[img1_part1, img1_part2, img1_part3, label1], [img2_part1, img2_part2, img2_part3, label2], ...]

You can then use it in the training loop like this:

for img in dataLoader:
    part1, part2, part3, label = img
    optimizer.zero_grad()  # clear gradients left over from the previous step
    out = model(part1, part2, part3)
    loss = loss_fn(out, label)
    loss.backward()
    optimizer.step()
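Iterating a plain Python list like this feeds the model one sample at a time, without batching or shuffling. The question already has the parts in three index-aligned datasets. In that case a small wrapper `Dataset` can return the parts together, so that a single `DataLoader` handles batching. With one loader, even `shuffle=True` keeps the parts of each image together. `ThreePartDataset` and the synthetic data below are hypothetical.

```python
import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


class ThreePartDataset(Dataset):
    # Hypothetical wrapper: combines three index-aligned datasets that each
    # return (part, label) into one dataset returning (part1, part2, part3, label)
    def __init__(self, ds1, ds2, ds3):
        assert len(ds1) == len(ds2) == len(ds3)
        self.datasets = (ds1, ds2, ds3)

    def __len__(self):
        return len(self.datasets[0])

    def __getitem__(self, idx):
        (p1, label), (p2, _), (p3, _) = (ds[idx] for ds in self.datasets)
        return p1, p2, p3, label


# Hypothetical aligned datasets: 6 samples, parts of 1x4x4, shared labels
labels = torch.arange(6) % 2
parts = [torch.randn(6, 1, 4, 4) for _ in range(3)]
combined = ThreePartDataset(*(TensorDataset(p, labels) for p in parts))

# One loader: shuffling permutes whole samples, so the parts stay together
loader = DataLoader(combined, batch_size=4, shuffle=True)
part1, part2, part3, label = next(iter(loader))
```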

Upvotes: 4
