
Reputation: 7

ValueError: pic should be 2/3 dimensional. Got 4 dimensions

I am trying to implement an augmentation function to my images and masks, I have defined the augmentations like below:

if config.AUG == "PRIMEAugmentation":
    augmentations = [autocontrast, equalize, posterize, rotate, solarize, shear_x, shear_y, translate_x, translate_y]

and the function is like below:

import torch
from torch.distributions import Dirichlet, Beta

class PRIMEAugmentation:
    def __init__(self, mixture_width=3, mixture_depth=-1):
        self.mixture_width = mixture_width
        self.mixture_depth = mixture_depth

    def __call__(self, x, mask):
        x = torch.from_numpy(x).to(torch.float32)
        mask = torch.from_numpy(mask)
        ws = Dirichlet(torch.ones(self.mixture_width)).sample((x.shape[0],))
        m = Beta(torch.ones(1), torch.ones(1)).sample().expand(x.shape[0], 1, 1, 1)

        x_aug = torch.zeros_like(x).to(torch.float32)
        mask_aug = torch.zeros_like(mask).to(torch.float32)
        for i in range(self.mixture_width):
            x_i = x.clone()
            mask_i = mask.clone()
            for d in range(self.mixture_depth):
                op = torch.randint(len(self.augmentations), size=(x.shape[0],)).tolist()
                x_i, mask_i = self.augmentations[op](x_i, mask_i)
            print("ws[:, i] shape:", ws[:, i].shape)
            print("x_i shape:", x_i.shape)
            print("mask_i shape:", mask_i.shape)
            x_aug += ws[:, i][:, None, None] *
            mask_aug += ws[:, i][:, None] *

        mixed = (1 - m) * x + m * x_aug.sum(dim=1)
        mixed_mask = (1 - m) * mask + m * mask_aug.sum(dim=1)
        return mixed.numpy().astype(np.uint8), mixed_mask.numpy().astype(np.uint8)

and I have called it like the following way:

augmenter_PRIMEAugmentation = aug_lib_new.PRIMEAugmentation()

import os

def image_mask_transformation(image,mask,img_trans,aug_trans=False):
    transformed = img_trans(image=image, mask=mask)
    image = transformed["image"]
    mask = transformed["mask"]

    if aug_trans in augmenter_list:
        image,mask = eval('augmenter_'+aug_trans)(image, mask)

but I am getting an error:

    raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
ValueError: pic should be 2/3 dimensional. Got 4 dimensions.

full error trace:

Traceback (most recent call last):
  File "/home/Crack-PRIME4/", line 422, in <module>
    train_logs = train_step(model, optimizer, criteria, trainLoader, accumulation_steps, scaler, epoch, epochs)
  File "/home/Crack-PRIME4/", line 240, in train_step
    for idx, data in enumerate(bar):
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/tqdm/", line 1182, in __iter__
    for obj in iterable:
  File "/home//anaconda3/envs/myenv/lib/python3.9/site-packages/torch/utils/data/", line 630, in __next__
    data = self._next_data()
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/utils/data/", line 1345, in _next_data
    return self._process_data(data)
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/utils/data/", line 1371, in _process_data
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/", line 694, in reraise
    raise exception
ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/utils/data/_utils/", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/utils/data/_utils/", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/utils/data/_utils/", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/Crack-PRIME4/tool/", line 215, in __getitem__
    image_store, mask_store = image_mask_transformation(image, mask, self.img_trans, self.aug_trans)
  File "/home/Crack-PRIME4/tool/", line 188, in image_mask_transformation
    final_image = transforms.ToTensor()(image)
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torchvision/transforms/", line 97, in __call__
    return F.to_tensor(pic)
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torchvision/transforms/", line 105, in to_tensor
    raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
ValueError: pic should be 2/3 dimensional. Got 4 dimensions.

image-mask transformation is called from:

class SegmentationDataset(Dataset):
    def __init__(self, imagePaths, maskPaths, img_trans, aug_trans = False, baug = 1):
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.img_trans = img_trans
        self.aug_trans = aug_trans
        self.baug = baug

    def __len__(self):
        # Number of images
        return len(self.imagePaths)

    def __getitem__(self, idx):
        imagePath = self.imagePaths[idx]

        image = cv2.imread(imagePath)
        # OpenCV loads an image in the BGR format,
        # which we convert to the RGB format
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Read image as grayscale
        mask = cv2.imread(self.maskPaths[idx], 0)
        ######mask = transforms.ToPILImage()(mask)

        image_store, mask_store = image_mask_transformation(image, mask, self.img_trans, self.aug_trans)
        return image_store, mask_store

Upvotes: 0

Views: 338

Answers (0)

Related Questions