Mustafa Uğur Baskın

Reputation: 90

How to cluster PyTorch predictions

I'm trying to detect road lanes in road images. So far, I've trained a model that predicts where the lanes are, but many of its predictions are scattered. I'd like to cluster the PyTorch predictions the model produces for these images. Each dot in the image below is a point where the model predicts a lane might be.

Predictions shape: [1, 1, 80, 120]

Here's the image of predictions:

[Image: model predictions]

Here's what I want to achieve (I edited the image by hand and deleted the scattered dots):

[Image: predictions with the scattered dots removed]

As you can see, I deleted the scattered dots (predictions) from the image. I want nearby dots to be clustered together. How can I achieve this? I tried KNN (K Nearest Neighbors), but it didn't work.

Upvotes: 0

Views: 139

Answers (1)

u1234x1234

Reputation: 2520

If you only want to remove the isolated dots, you can postprocess your mask with a morphological operation such as opening (erosion followed by dilation).

The resulting mask without dots:

[Image: postprocessed mask]

Code:

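import cv2
import numpy as np

# Load the predicted mask as a single-channel image.
mask = cv2.imread('road_mask.jpg', cv2.IMREAD_GRAYSCALE)
if mask is None:
    raise FileNotFoundError('road_mask.jpg could not be read')
# cv2.resize takes (width, height), so this matches the 80x120 prediction.
mask = cv2.resize(mask, (120, 80))

# Erode to wipe out isolated dots, then dilate to restore what survives.
mask = cv2.erode(mask, np.ones((2, 2), np.uint8))
mask = cv2.dilate(mask, np.ones((3, 3), np.uint8))
# Binarize: anything brighter than 10 becomes white (255).
mask = ((mask > 10) * 255).astype(np.uint8)

cv2.imwrite("postprocessed_mask.png", mask)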

Upvotes: 1
