Julianno Sambatti

Reputation: 119

How does PyTorch handle labels when loading image/mask files for image segmentation?

I am starting an image segmentation project using PyTorch. I have a small dataset in a folder with 2 subfolders: "images" for the images and "masks" for the corresponding masks. Images and masks are .png files with 3 channels and 256x256 pixels. Because this is image segmentation, the labelling has to be done pixel by pixel. I am working with only 2 classes at the moment for simplicity. So far, I have achieved the following:

I was able to load my files into classes "images" or "masks" by

import torchvision

root_dir = "./images_masks"
train_ds_untransf = torchvision.datasets.ImageFolder(root=root_dir)
train_ds_untransf.classes
Out[621]:
['images', 'masks']  

and transform the data into tensors

from torchvision import transforms
train_trans = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.ImageFolder(root=root_dir, transform=train_trans)

Each tensor in this "train_dataset" has the following shape:

train_dataset[1][0].shape
torch.Size([3, 256, 256])

Now I need to feed the loaded data into the CNN model, and I have explored the PyTorch DataLoader for this:

from torch.utils.data import DataLoader

train_dl = DataLoader(train_dataset, batch_size=2, shuffle=False, num_workers=4)

I use the following code to check the resulting tensor's shape

for x, y in train_dl:
    print(x.shape)
    print(y.shape)
    print(y)

and get

torch.Size([2, 3, 256, 256])
torch.Size([2])
tensor([0, 0])
torch.Size([2, 3, 256, 256])
torch.Size([2])
tensor([0, 1])
...

The shapes seem correct. However, the first problem is that some batches contain two files from the same folder, as shown by "y" tensors such as [0, 0]. I would expect every batch to be [1, 0], with 1 representing an image and 0 representing a mask.

The second problem is that, although the documentation is clear about the case where each label applies to an entire image, it is not clear how to label at the pixel level, and I am certain the labels I am getting are not correct.

What would be an alternative to correctly label this dataset?

thank you

Upvotes: 1

Views: 5245

Answers (1)

Shai

Reputation: 114876

The class torchvision.datasets.ImageFolder is designed for image classification problems, not for segmentation. It therefore expects a single integer label per image, and that label is determined by the subfolder in which the image is stored. So, as far as your dataloader is concerned, you have two classes of images, "images" and "masks", and your net would try to distinguish between them.

What you actually need is a different Dataset implementation whose __getitem__ returns an image together with its corresponding mask. You can see examples of such classes here.

Additionally, it is a bit odd that your binary pixel-wise labels are stored as 3-channel images. Segmentation masks are usually stored as single-channel images.
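If the masks have to stay as 3-channel files, they can be collapsed after loading. The helper below is a sketch with a made-up name (rgb_mask_to_class_index), and it assumes background pixels are exactly zero in all three channels. It also shows the target layout that nn.CrossEntropyLoss accepts for per-pixel labels: logits of shape [N, C, H, W] and an int64 target of shape [N, H, W]. The random logits only stand in for a network output.

```python
import torch
import torch.nn as nn


def rgb_mask_to_class_index(mask_rgb):
    """Collapse a [3, H, W] mask into a [H, W] int64 map of class indices (0 or 1).

    Assumption: background pixels are exactly zero in every channel.
    """
    return (mask_rgb.sum(dim=0) > 0).long()


# Toy 4x4 mask: the top two rows are foreground in all three channels.
mask_rgb = torch.zeros(3, 4, 4)
mask_rgb[:, :2, :] = 1.0
target = rgb_mask_to_class_index(mask_rgb)  # shape [4, 4]

# Stand-in for a network output: 2 class scores per pixel, batch of 1.
logits = torch.randn(1, 2, 4, 4)
loss = nn.CrossEntropyLoss()(logits, target.unsqueeze(0))
```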

Upvotes: 4
