user836026

Reputation: 11340

Convert a 9-channel torch tensor to a 3-channel (or 1-channel) image to display it

I have a tensor with 9 channels, shape [9, 224, 224], which is the result of a prediction. How can I convert it to a 3-channel image so that I can display it?

predicted = predicted.cpu()
label = predicted[0]
print(label.shape)

torch.Size([9, 224, 224])

Upvotes: 0

Views: 338

Answers (1)

u1234x1234

Reputation: 2510

I'm assuming that your (9, 224, 224) data are semantic segmentation maps. There are two possible variants:

  1. You have multi-class predictions (each pixel belongs to exactly one of the 9 classes):
import cv2
import numpy as np

# prediction: the (9, H, W) tensor from the question, e.g. predicted[0]
# find normalized probabilities that sum up to 1 across the classes
prediction = prediction.softmax(dim=0).detach().cpu().numpy()

# find the most probable class for each pixel -> shape (H, W)
labels = prediction.argmax(axis=0)

# create a color palette that maps class_idx to (R, G, B)
palette = np.random.randint(0, 256, (prediction.shape[0], 3), np.uint8)
color_mask = np.zeros((*labels.shape, 3), np.uint8)
# map each label to its (RGB) color
for idx, color in enumerate(palette):
    color_mask[labels == idx] = color

cv2.imshow('color_mask', color_mask)
cv2.waitKey()

Example of visualization: [two images omitted]
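The per-class loop above can also be written as a single NumPy fancy-indexing lookup, and since the question also mentions a 1-channel image, the class indices can instead be spread over evenly spaced gray levels. A minimal NumPy-only sketch, assuming `labels` is the (H, W) integer array produced by `argmax` above; the helper names `labels_to_color` and `labels_to_gray` are hypothetical, and random data stands in for a real prediction.

```python
import numpy as np


def labels_to_color(labels: np.ndarray, palette: np.ndarray) -> np.ndarray:
    """Map an (H, W) array of class indices to an (H, W, 3) uint8 image.

    palette[k] is the color for class k; indexing the palette with the
    whole label array performs the per-pixel lookup in one step.
    """
    return palette[labels]


def labels_to_gray(labels: np.ndarray, num_classes: int) -> np.ndarray:
    """Map class indices 0..num_classes-1 to gray levels spread over 0..255."""
    return (labels * 255 // max(num_classes - 1, 1)).astype(np.uint8)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    num_classes = 9
    # stand-in for prediction.argmax(axis=0) on a (9, 224, 224) prediction
    labels = rng.integers(0, num_classes, size=(224, 224))
    palette = rng.integers(0, 256, size=(num_classes, 3), dtype=np.uint8)

    color_mask = labels_to_color(labels, palette)
    gray_mask = labels_to_gray(labels, num_classes)
    print(color_mask.shape, color_mask.dtype)  # (224, 224, 3) uint8
    print(gray_mask.shape, gray_mask.dtype)    # (224, 224) uint8
```

Note that `cv2.imshow` reads the last axis as BGR, while `matplotlib.pyplot.imshow` reads it as RGB; with a random palette this only swaps the displayed colors.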

  2. You have multi-label predictions (a pixel can belong to several classes at once). In that case you have 9 independent prediction masks:
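import numpy as np

# prediction = torch.sigmoid(prediction)  # in the case of logits
# move to the CPU and convert to a NumPy array of shape (9, H, W)
prediction = prediction.detach().cpu().numpy()
# convert 0-1 probability maps into 0-255
prediction = (prediction * 255).astype(np.uint8)
# stack the probability maps horizontally -> one (H, 9 * W) grayscale image
prediction = np.hstack(prediction)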

Example: [image omitted]

Image taken from https://www.publicdomainpictures.net/en/view-image.php?image=24076
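Stacking all 9 maps with `np.hstack` produces one long 224 × 2016 strip. If a more compact view is preferred, the maps can be tiled into a 3 × 3 grid instead. A minimal NumPy sketch, assuming `maps` is the (9, H, W) uint8 array obtained after the `* 255` conversion (i.e. before `hstack`); `tile_maps` is a hypothetical helper name and random data stands in for real probability maps.

```python
import numpy as np


def tile_maps(maps: np.ndarray, rows: int, cols: int) -> np.ndarray:
    """Arrange an (N, H, W) stack of maps into a (rows * H, cols * W) grid.

    Unused cells (when N < rows * cols) are left black.
    """
    n, h, w = maps.shape
    if n > rows * cols:
        raise ValueError("grid is too small for the number of maps")
    grid = np.zeros((rows * cols, h, w), dtype=maps.dtype)
    grid[:n] = maps
    # (rows, cols, H, W) -> (rows, H, cols, W) -> (rows * H, cols * W)
    return (
        grid.reshape(rows, cols, h, w)
        .transpose(0, 2, 1, 3)
        .reshape(rows * h, cols * w)
    )


if __name__ == "__main__":
    # stand-in for the 9 uint8 probability maps of shape (224, 224)
    rng = np.random.default_rng(0)
    maps = rng.integers(0, 256, size=(9, 224, 224), dtype=np.uint8)
    grid = tile_maps(maps, rows=3, cols=3)
    print(grid.shape)  # (672, 672)
```

With `rows=1, cols=9` the result is identical to `np.hstack(maps)`, so the grid is just a different arrangement of the same pixels.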

Upvotes: 2
