Reputation: 11340
I have a tensor with 9 channels, shape [9, 224, 224], which is the result of a prediction. How can I convert it to 3 channels so that I can display it as an image?
predicted = predicted.cpu()
label = predicted[0]
print(label.shape)
torch.Size([9, 224, 224])
Upvotes: 0
Views: 338
Reputation: 2510
I'm assuming that your (9, 224, 224) data is a set of semantic segmentation maps, one per class. There are two possible ways to display it:
# Variant 1: color each pixel by its most probable class
import cv2
import numpy as np

# normalized probabilities that sum to 1 across the classes (dim 0)
prediction = prediction.softmax(dim=0).cpu().numpy()
# most probable class for each pixel, shape (224, 224)
labels = prediction.argmax(axis=0)
# random color palette that maps class_idx to a 3-channel color
palette = np.random.randint(0, 256, (prediction.shape[0], 3), dtype=np.uint8)
color_mask = np.zeros((*labels.shape, 3), np.uint8)
# paint every pixel of each class with that class's color
for idx, color in enumerate(palette):
    color_mask[labels == idx] = color
cv2.imshow('color_mask', color_mask)
cv2.waitKey()
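The per-class loop above can probably also be written as a single palette lookup. Below is a minimal NumPy-only sketch of that idea. The function name `colorize_labels`, the `seed` argument, and the random stand-in for the model output are all assumptions made for illustration, not part of the original answer. Since softmax does not change which class has the highest score, the sketch takes the argmax directly.

```python
import numpy as np


def colorize_labels(scores, palette=None, seed=0):
    """Map a (C, H, W) array of class scores to an (H, W, 3) uint8 image.

    `colorize_labels` is a hypothetical helper, not a library function.
    """
    num_classes = scores.shape[0]
    if palette is None:
        # assumption: a random palette is good enough for a quick look
        rng = np.random.default_rng(seed)
        palette = rng.integers(0, 256, size=(num_classes, 3), dtype=np.uint8)
    # most probable class per pixel, shape (H, W)
    labels = scores.argmax(axis=0)
    # fancy indexing looks up one palette row per pixel, giving (H, W, 3)
    return palette[labels]


# stand-in for the real prediction (assumed shape: 9 classes, 224 x 224)
scores = np.random.default_rng(1).random((9, 224, 224)).astype(np.float32)
color_mask = colorize_labels(scores)
print(color_mask.shape, color_mask.dtype)  # (224, 224, 3) uint8
```

With a real tensor you would pass something like `prediction.cpu().numpy()` instead of the random array, and the result can be handed to `cv2.imshow` in the same way as `color_mask` above.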
Example of visualization:
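# Variant 2: show the per-class probability maps side by side
# prediction = torch.sigmoid(prediction)  # only if the model outputs logits
# prediction = prediction.cpu().numpy()   # only if it is still a torch tensor
# convert 0-1 probability maps into 0-255
prediction = (prediction * 255).astype(np.uint8)
# stack the 9 maps horizontally into one (224, 9 * 224) grayscale image
prediction = np.hstack(prediction)
cv2.imshow('probability_maps', prediction)
cv2.waitKey()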
Image taken from https://www.publicdomainpictures.net/en/view-image.php?image=24076
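Stacking 9 maps horizontally gives a very wide strip (224 x 2016 pixels), so a grid layout may be easier to read. The following is a rough sketch of that alternative, assuming the maps are already a NumPy array with values in [0, 1]. The helper `tile_maps` and its `cols` parameter are hypothetical names introduced here.

```python
import numpy as np


def tile_maps(prob_maps, cols=3):
    """Arrange (C, H, W) probability maps into one uint8 grid image.

    `tile_maps` is a hypothetical helper; values are assumed to lie in [0, 1].
    """
    c, h, w = prob_maps.shape
    rows = -(-c // cols)  # ceiling division
    # pad with empty (black) maps so the count fills the grid exactly
    padded = np.zeros((rows * cols, h, w), dtype=prob_maps.dtype)
    padded[:c] = prob_maps
    # (rows, cols, H, W) -> (rows, H, cols, W) -> (rows * H, cols * W)
    grid = padded.reshape(rows, cols, h, w).transpose(0, 2, 1, 3)
    grid = grid.reshape(rows * h, cols * w)
    # clip defensively, then scale 0-1 probabilities to 0-255
    return (np.clip(grid, 0.0, 1.0) * 255).astype(np.uint8)


# stand-in for real probability maps (assumed: 9 classes, 224 x 224)
maps = np.random.default_rng(0).random((9, 224, 224))
grid = tile_maps(maps)
print(grid.shape)  # (672, 672)
```

Map `i` ends up in grid row `i // cols` and column `i % cols`, so with 9 classes and `cols=3` the result is a 3 x 3 mosaic that can be displayed as a single grayscale image.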
Upvotes: 2