Cepera


PyTorch: File-specific action for each image in the batch

I have a dataset of images, each of which has an additional attribute, "channel_no". Each image should be processed by a different nn layer depending on its channel_no:

images with channel_no=1 must be processed by layer1
images with channel_no=2 must be processed by layer2
images with channel_no=3 must be processed by layer3
and so on.

The problem is that when a batch contains more than one image, forward() receives a single tensor holding the whole batch, and the images in it can have different channel_no values. So it is not clear how to process each image separately.
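To see what forward() actually receives, here is a minimal sketch assuming a map-style Dataset whose items are (image, channel_no) pairs; `ToyDataset` and its random 1x4x4 images are hypothetical. With the default collate function, DataLoader stacks the images into one tensor and turns the integer channel_no values into a 1-D tensor with one entry per image, so channel_no[0] only describes the first image in the batch.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical dataset: each item is an (image, channel_no) pair."""

    def __init__(self):
        self.channel_nos = [1, 2, 3, 1, 2, 3]

    def __len__(self):
        return len(self.channel_nos)

    def __getitem__(self, idx):
        image = torch.randn(1, 4, 4)  # placeholder image data
        return image, self.channel_nos[idx]

# Default collate: images are stacked, ints become a 1-D int64 tensor
loader = DataLoader(ToyDataset(), batch_size=4, shuffle=False)
images, channel_no = next(iter(loader))
print(images.shape)  # torch.Size([4, 1, 4, 4])
print(channel_no)    # tensor([1, 2, 3, 1])
```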

Here is the code for the case where the batch contains only one image:

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, hidden_sizes, output_size):
        super(Net, self).__init__()
        self.hidden_sizes = hidden_sizes

        # One layer per channel_no value
        self.layer1 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.layer2 = nn.Linear(hidden_sizes[0], hidden_sizes[1])
        self.layer3 = nn.Linear(hidden_sizes[0], hidden_sizes[1])

        self.outp = nn.Linear(hidden_sizes[1], output_size)

    def forward(self, x, channel_no):
        # Extract channel_no from the batch; only correct when the batch holds one image
        channel_no = int(channel_no[0])

        x = x.view(-1, self.hidden_sizes[0])

        if channel_no == 1:
            x = F.relu(self.layer1(x))
        elif channel_no == 2:
            x = F.relu(self.layer2(x))
        elif channel_no == 3:
            x = F.relu(self.layer3(x))

        x = torch.sigmoid(self.outp(x))

        return x

Is it possible to process each image separately with a batch size greater than 1?


Answers (1)

user1389840


To process images separately, you probably need separate tensors. I'm not sure whether there is a fast way to do it, but you could split the tensor along the batch dimension to get individual image tensors, then iterate through them and group them by channel number. Then stack each group of images with the same channel number into a new tensor and pass that tensor through the layer for that channel number.
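A minimal sketch of the grouping idea, assuming channel_no arrives as a 1-D tensor with one entry per image (as PyTorch's default collate produces) and that its values run from 1 to the number of layers. Rather than splitting the batch into individual tensors and sorting them, it uses a boolean mask per channel number to gather each group into a new tensor (x[mask]), runs that group through its own layer, and writes the results back into the original batch order. The class name `GroupedNet`, the `branches` attribute, and the sizes in the usage lines are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GroupedNet(nn.Module):
    """Hypothetical model: routes each image to the layer for its channel_no."""

    def __init__(self, in_features, hidden, output_size, num_channels=3):
        super().__init__()
        self.hidden = hidden
        # branches[k] handles images with channel_no == k + 1
        self.branches = nn.ModuleList(
            nn.Linear(in_features, hidden) for _ in range(num_channels)
        )
        self.outp = nn.Linear(hidden, output_size)

    def forward(self, x, channel_no):
        x = x.reshape(x.shape[0], -1)  # flatten each image: (batch, in_features)
        h = x.new_zeros(x.shape[0], self.hidden)
        for k, branch in enumerate(self.branches):
            mask = channel_no == k + 1  # which images belong to this branch
            if mask.any():
                # Process the whole group at once, then scatter back in batch order
                h[mask] = F.relu(branch(x[mask]))
        return torch.sigmoid(self.outp(h))

torch.manual_seed(0)
net = GroupedNet(in_features=16, hidden=8, output_size=1)
images = torch.randn(4, 1, 4, 4)
channel_no = torch.tensor([1, 3, 2, 1])
out = net(images, channel_no)
print(out.shape)  # torch.Size([4, 1])
```

The loop runs once per channel number rather than once per image, so each layer still receives a batch. A plain per-image loop would also work but calls the layers one image at a time. In this sketch, an image whose channel_no matches no branch keeps a zero hidden vector; raising an error instead may be preferable if that case should never occur.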

