I am trying to implement an augmentation function for my images and masks. I have defined the augmentations like below:
if config.AUG == "PRIMEAugmentation":
    augmentations = [autocontrast, equalize, posterize, rotate, solarize,
                     shear_x, shear_y, translate_x, translate_y]
and the augmentation class is defined like below:
import numpy as np
import torch
from torch.distributions import Dirichlet, Beta

class PRIMEAugmentation:
    def __init__(self, mixture_width=3, mixture_depth=-1):
        self.mixture_width = mixture_width
        self.mixture_depth = mixture_depth

    def __call__(self, x, mask):
        x = torch.from_numpy(x).to(torch.float32)
        mask = torch.from_numpy(mask)
        # Dirichlet weights per sample and a Beta mixing coefficient
        ws = Dirichlet(torch.ones(self.mixture_width)).sample((x.shape[0],))
        m = Beta(torch.ones(1), torch.ones(1)).sample().expand(x.shape[0], 1, 1, 1)
        x_aug = torch.zeros_like(x).to(torch.float32)
        mask_aug = torch.zeros_like(mask).to(torch.float32)
        for i in range(self.mixture_width):
            x_i = x.clone()
            mask_i = mask.clone()
            for d in range(self.mixture_depth):
                op = torch.randint(len(self.augmentations), size=(x.shape[0],)).tolist()
                x_i, mask_i = self.augmentations[op](x_i, mask_i)
            print("ws[:, i] shape:", ws[:, i].shape)
            print("x_i shape:", x_i.shape)
            print("mask_i shape:", mask_i.shape)
            x_aug += ws[:, i][:, None, None] * x_i.to(torch.float32)
            mask_aug += ws[:, i][:, None] * mask_i.to(torch.float32)
        mixed = (1 - m) * x + m * x_aug.sum(dim=1)
        mixed_mask = (1 - m) * mask + m * mask_aug.sum(dim=1)
        return mixed.numpy().astype(np.uint8), mixed_mask.numpy().astype(np.uint8)
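For reference, the mixing I am trying to reproduce takes Dirichlet weights over the augmentation chains and then blends the result with the original image using a Beta coefficient. Here is a minimal sketch of just that arithmetic on dummy batched tensors (all sizes here are made up for illustration):

import torch
from torch.distributions import Dirichlet, Beta

# Dummy batch in (N, H, W, C) layout; sizes are assumptions for illustration.
n, h, w, c, width = 2, 8, 8, 3, 3
x = torch.rand(n, h, w, c)

ws = Dirichlet(torch.ones(width)).sample((n,))      # (N, width) weights per image
m = Beta(torch.ones(1), torch.ones(1)).sample().expand(n, 1, 1, 1)

x_aug = torch.zeros_like(x)
for i in range(width):
    x_i = torch.rand_like(x)                        # stand-in for an augmented copy
    x_aug += ws[:, i][:, None, None, None] * x_i    # broadcast weight over H, W, C
mixed = (1 - m) * x + m * x_aug
print(mixed.shape)                                  # torch.Size([2, 8, 8, 3])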
and I have called it in the following way:
augmenter_PRIMEAugmentation = aug_lib_new.PRIMEAugmentation()

import os

def image_mask_transformation(image, mask, img_trans, aug_trans=False):
    transformed = img_trans(image=image, mask=mask)
    image = transformed["image"]
    mask = transformed["mask"]
    if aug_trans in augmenter_list:
        image, mask = eval('augmenter_' + aug_trans)(image, mask)
but I am getting an error:
raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
ValueError: pic should be 2/3 dimensional. Got 4 dimensions.
full error trace:
Traceback (most recent call last):
  File "/home/Crack-PRIME4/main.py", line 422, in <module>
    train_logs = train_step(model, optimizer, criteria, trainLoader, accumulation_steps, scaler, epoch, epochs)
  File "/home/Crack-PRIME4/main.py", line 240, in train_step
    for idx, data in enumerate(bar):
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/tqdm/std.py", line 1182, in __iter__
    for obj in iterable:
  File "/home//anaconda3/envs/myenv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1345, in _next_data
    return self._process_data(data)
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1371, in _process_data
    data.reraise()
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/_utils.py", line 694, in reraise
    raise exception
ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torch/utils/data/_utils/fetch.py", line 51, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/home/Crack-PRIME4/tool/dataset.py", line 215, in __getitem__
    image_store, mask_store = image_mask_transformation(image, mask, self.img_trans, self.aug_trans)
  File "/home/Crack-PRIME4/tool/dataset.py", line 188, in image_mask_transformation
    final_image = transforms.ToTensor()(image)
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torchvision/transforms/transforms.py", line 97, in __call__
    return F.to_tensor(pic)
  File "/home/anaconda3/envs/myenv/lib/python3.9/site-packages/torchvision/transforms/functional.py", line 105, in to_tensor
    raise ValueError('pic should be 2/3 dimensional. Got {} dimensions.'.format(pic.ndim))
ValueError: pic should be 2/3 dimensional. Got 4 dimensions.
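For what it's worth, torchvision's ToTensor only accepts 2- or 3-dimensional inputs, so the same error can be reproduced with any 4D array; a minimal repro (the array shape is just an assumption for illustration):

import numpy as np
from torchvision import transforms

# ToTensor rejects anything that is not (H, W) or (H, W, C),
# so a batched (N, H, W, C) array triggers the same ValueError.
pic = np.zeros((1, 224, 224, 3), dtype=np.uint8)
transforms.ToTensor()(pic)  # ValueError: pic should be 2/3 dimensional. Got 4 dimensions.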
image_mask_transformation is called from:
class SegmentationDataset(Dataset):
    def __init__(self, imagePaths, maskPaths, img_trans, aug_trans=False, baug=1):
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.img_trans = img_trans
        self.aug_trans = aug_trans
        self.baug = baug

    def __len__(self):
        # Number of images
        return len(self.imagePaths)

    def __getitem__(self, idx):
        imagePath = self.imagePaths[idx]
        image = cv2.imread(imagePath)
        # OpenCV loads an image in the BGR format,
        # which we convert to the RGB format
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Read the mask as grayscale
        mask = cv2.imread(self.maskPaths[idx], 0)
        # mask = transforms.ToPILImage()(mask)
        image_store, mask_store = image_mask_transformation(image, mask, self.img_trans, self.aug_trans)
        return image_store, mask_store
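For reference, each sample returned by the dataset is a single image, not a batch: cv2.imread gives an (H, W, 3) array for the RGB image and (H, W) for the grayscale mask. A quick shape check (the paths are placeholders):

import cv2

image = cv2.cvtColor(cv2.imread("sample.png"), cv2.COLOR_BGR2RGB)  # placeholder path
mask = cv2.imread("sample_mask.png", 0)                            # placeholder path
print(image.shape)  # (H, W, 3) -- a single image, no batch axis
print(mask.shape)   # (H, W)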