So I have images in the format (height, width, channel), and the original images are RGB. I load them in grayscale:
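import os
import cv2
from tqdm import tqdm
dist_one, dist_two = [], []
for r, d, file in tqdm(os.walk(path)):
    for i in tqdm(file):
        # join with the walked directory r (not path) so subfolders work; result shape is (H, W)
        img = cv2.imread(os.path.join(r, i), cv2.IMREAD_GRAYSCALE)
        if i[0:2] == "01":
            dist_one.append(img)
        else:
            dist_two.append(img)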
Say the images have shape (187, 187). I add a channel axis with:
g = np.expand_dims(dist_one[0], axis=0)
But when I try to plot the image, it fails:
TypeError                                 Traceback (most recent call last)
in ()
      1 import matplotlib.pyplot as plt
      2
----> 3 plt.imshow(dist_one[0],cmap='gray')
      4 plt.show()
      5
frames
/usr/local/lib/python3.6/dist-packages/matplotlib/image.py in set_data(self, A)
    688                 or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
    689             raise TypeError("Invalid shape {} for image data"
--> 690                             .format(self._A.shape))
    691
    692         if self._A.ndim == 3:

TypeError: Invalid shape (1, 187, 187) for image data
But it works when the channel is put last:
g = np.expand_dims(dist_one[0], axis=-1)
What's the reason for this?
I need the channel first for PyTorch. Or am I supposed to train the model with broken images?
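The reason shows up in the traceback: `imshow` validates the array shape before drawing anything. The sketch below is a simplified, hypothetical re-implementation of that check (`check_imshow_shape` is a made-up name, not part of matplotlib's API), assuming a recent matplotlib that drops a trailing size-1 channel axis, which is what the working `axis=-1` case suggests:

```python
import numpy as np


def check_imshow_shape(a: np.ndarray) -> np.ndarray:
    """Simplified, hypothetical mirror of the shape check in the traceback.

    Not matplotlib's actual code: it only reproduces the shape rules.
    """
    # A trailing channel axis of size 1 is treated as a single-channel
    # (scalar) image and dropped, so (H, W, 1) becomes (H, W).
    if a.ndim == 3 and a.shape[-1] == 1:
        a = a[:, :, 0]
    # Accepted layouts: (H, W), (H, W, 3) or (H, W, 4), i.e. channels LAST.
    if not (a.ndim == 2 or (a.ndim == 3 and a.shape[-1] in (3, 4))):
        raise TypeError(f"Invalid shape {a.shape} for image data")
    return a


img = np.zeros((187, 187), dtype=np.uint8)
print(check_imshow_shape(img[..., np.newaxis]).shape)  # (187, 187)
try:
    check_imshow_shape(img[np.newaxis, ...])  # shape (1, 187, 187)
except TypeError as err:
    print(err)  # Invalid shape (1, 187, 187) for image data
```

A leading axis of size 1 makes the last dimension 187, which is neither 1, 3 nor 4, so the check rejects it even though the pixel values are identical.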
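PyTorch expects [C, H, W], whereas matplotlib's `imshow` (which is where the error is thrown) expects [H, W] or [H, W, C] with C equal to 3 or 4; a trailing channel of size 1 is evidently accepted too, as your `axis=-1` case shows. The image data itself is not broken: `np.expand_dims` only adds an axis of size 1 and leaves the pixels unchanged. To fix the plotting issue, I'd suggest plotting the `g` obtained with `g = np.expand_dims(dist_one[0], axis=-1)`. Before sending it to PyTorch, you can do one of two things. You can use `g = torch.tensor(g).permute(2, 0, 1)`, which results in a [C, H, W] tensor, or you can use the PyTorch Dataset and DataLoader setup, which also handles other things like batching for you. The channel-first conversion still has to happen somewhere in that setup, for example in the Dataset's `__getitem__` or via a transform such as torchvision's `ToTensor`, which turns an H×W or H×W×C array into a C×H×W tensor.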
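To make the layouts concrete without needing PyTorch installed, here is a NumPy-only sketch. `np.transpose(g, (2, 0, 1))` performs the same axis reordering as `permute(2, 0, 1)`, and the synthetic `images` list is a hypothetical stand-in for `dist_one`:

```python
import numpy as np

# Hypothetical stand-ins for the grayscale images in dist_one:
# three synthetic 187x187 uint8 arrays.
images = [np.full((187, 187), k, dtype=np.uint8) for k in range(3)]

g = np.expand_dims(images[0], axis=-1)  # (187, 187, 1): channel last, plots fine
chw = np.transpose(g, (2, 0, 1))        # (1, 187, 187): same reordering as permute(2, 0, 1)

# The pixels are untouched; only the axis order changed.
assert np.array_equal(chw[0], g[..., 0])

# Going back for plotting: move the channel axis to the end again.
assert np.array_equal(np.moveaxis(chw, 0, -1), g)

# A batch in PyTorch's [N, C, H, W] layout, roughly what a DataLoader with
# the default collate builds by stacking per-image [C, H, W] items.
batch = np.stack([img[np.newaxis] for img in images])
print(g.shape, chw.shape, batch.shape)  # (187, 187, 1) (1, 187, 187) (3, 1, 187, 187)
```

With PyTorch available, `torch.from_numpy(images[0]).unsqueeze(0)` would give the same [1, H, W] layout directly from the 2-D grayscale array, skipping the intermediate channel-last step.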