Shawn Zhang

Reputation: 21

PyTorch problem with custom Dataset class

First, I made a custom dataset to load images from my dataframe (which contains each image's file path and its corresponding integer label):

import torch
from skimage import io

class Dataset(torch.utils.data.Dataset):

    def __init__(self, dataframe, transform=None):
        self.frame = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        filename = self.frame.iloc[idx, 0]
        # io.imread gives (H, W, C); transpose to the (C, H, W) layout PyTorch expects
        image = torch.from_numpy(io.imread(filename).transpose((2, 0, 1))).float()
        label = self.frame.iloc[idx, 1]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample
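For reference, the transpose in __getitem__ assumes io.imread returns an RGB array shaped (height, width, channels); transpose((2, 0, 1)) moves the channel axis first, giving the (C, H, W) layout that torchvision models expect. A minimal sketch, using a random array as a stand-in for a real image file (the 224×224 size is an arbitrary assumption):

```python
import numpy as np
import torch

# Hypothetical stand-in for io.imread(filename) on an RGB image: shape (H, W, C)
fake_image = np.random.randint(0, 256, size=(224, 224, 3), dtype=np.uint8)

# Same conversion as in __getitem__: (H, W, C) -> (C, H, W), then cast to float32
image = torch.from_numpy(fake_image.transpose((2, 0, 1))).float()
print(image.shape)  # torch.Size([3, 224, 224])

# A grayscale file loads as a 2-D array, so the 3-axis transpose fails on it
gray = np.zeros((224, 224), dtype=np.uint8)
try:
    gray.transpose((2, 0, 1))
except ValueError as err:
    print("grayscale image:", err)
```

If the dataset may contain grayscale or RGBA files, checking image.ndim and the channel count before converting avoids this failure.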

Then, I use a pre-existing model architecture like so:

import torch.nn as nn
import torch.optim as optim
from torchvision import models

model = models.densenet161()
num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, 10)  # where 10 is my number of classes

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
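nn.CrossEntropyLoss compares logits of shape (N, C) against a target of N class indices with dtype long, so the first dimension of both must agree. A small sketch with random logits, assuming 10 classes and a batch of 4, shows the matching case and a mismatch that triggers the batch-size error described in the question:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
num_classes = 10  # matches the 10-way classifier above

# Batch of 4: logits are (N, C), targets are (N,) integer class indices
logits = torch.randn(4, num_classes)
targets = torch.tensor([2, 0, 9, 5], dtype=torch.long)
loss = criterion(logits, targets)
print(loss.shape)  # torch.Size([]) (a scalar averaged over the batch)

# A single image (N=1) paired with a target holding 2 elements: batch sizes disagree
single_logits = torch.randn(1, num_classes)
bad_targets = torch.tensor([2, 7], dtype=torch.long)
try:
    criterion(single_logits, bad_targets)
except (ValueError, RuntimeError) as err:  # catch both, to be safe across versions
    print(err)
```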

Finally, for training, I do the following:

model.train()  # switch to train mode
        
for epoch in range(5):
    for i, sample in enumerate(train_set):  # where train_set is an instance of my Dataset class
        optimizer.zero_grad()
        image, label = sample['image'].unsqueeze(0), torch.Tensor(sample['label']).long()
        output = model(image)

        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

However, I get an error at loss = criterion(output, label): ValueError: Expected input batch_size (1) to match target batch_size (2). Can someone teach me how to use a custom dataset properly, especially for loading data in batches? Also, why am I getting this ValueError? Thank you!
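For batching, the usual approach is to wrap the Dataset in torch.utils.data.DataLoader. Its default collate function stacks the per-sample dicts into batched tensors: images become (B, C, H, W) and integer labels become a (B,) long tensor, so no unsqueeze or torch.Tensor call is needed. A minimal sketch, using a hypothetical in-memory ToyDataset and a tiny stand-in model instead of the image files and DenseNet:

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in that returns the same dict layout as the Dataset above."""

    def __init__(self, n_samples=20, num_classes=10):
        self.images = torch.randn(n_samples, 3, 32, 32)
        # Plain Python ints, like the values read from the dataframe
        self.labels = torch.randint(0, num_classes, (n_samples,)).tolist()

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {'image': self.images[idx], 'label': self.labels[idx]}


loader = DataLoader(ToyDataset(), batch_size=8, shuffle=True)

# Hypothetical tiny model standing in for densenet161 with a 10-way classifier
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

model.train()
for epoch in range(2):
    for batch in loader:
        # default_collate turns the list of dicts into one dict of batched tensors
        images, labels = batch['image'], batch['label']  # (B, 3, 32, 32) and (B,)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

print(images.shape, labels.dtype)
```

With the real dataset, the same loop would iterate over DataLoader(train_set, batch_size=..., shuffle=True). Note that the default collate function can only stack images that all have the same size, so variable-size images would need a resize transform first.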

Upvotes: 0

Views: 162

Answers (1)

Ritchie Tan

Reputation: 1

Please check the following lines:

label = self.frame.iloc[idx, 1] in the dataset definition: print it to double-check whether it returns a single int or two.

image, label = sample['image'].unsqueeze(0), torch.Tensor(sample['label']).long() in the training code: check the shape of the resulting label tensor.
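Following up on the shape check: iloc[idx, 1] returns a single integer, but the legacy constructor torch.Tensor(n) treats an integer argument as a size, not a value, so a label of 2 becomes an uninitialized tensor with two elements. That is consistent with "target batch_size (2)" in the error. A small sketch with a hypothetical one-row dataframe (the file name is made up; the label is converted to a plain int first, since pandas returns a NumPy integer):

```python
import pandas as pd
import torch

# Hypothetical dataframe with the same layout: file path, integer label
frame = pd.DataFrame({'path': ['img_0.png'], 'label': [2]})
label = frame.iloc[0, 1]
print(label)  # 2 (a single scalar, not two ints)

# Integer argument is read as a size: 2 uninitialized values, not the value 2
wrong = torch.Tensor(int(label)).long()
print(wrong.shape)  # torch.Size([2])

# Build the target from the value instead; shape (1,) matches a batch of one image
right = torch.tensor([int(label)], dtype=torch.long)
print(right.shape)  # torch.Size([1])
```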

Upvotes: 0
