Zynk

Reputation: 3025

Gaussian Mixture Model fit in Python with sklearn is too slow - Any alternative?

I need to fit Gaussian Mixture Models to an RGB image, so the dataset is quite big. This needs to run in real time (from a webcam feed). I first coded this in Matlab and achieved a running time of 0.5 seconds for a 1729 × 866 image. The images for the final application will be smaller, so the timing will be faster.

However, I need to implement this with Python and OpenCV for the final application (it has to run on an embedded board). I translated all my code and used sklearn.mixture.GMM to replace fitgmdist in Matlab. The line that constructs the GMM object takes only 7.7e-05 seconds, but the line that fits the model takes 19 seconds. I have tried other covariance types, such as 'diag' or 'spherical'; the time drops a little, but the results are worse and the time is still nowhere near good enough.

I was wondering if there is any other library I can use, or if it would be worth it to translate the functions from Matlab to Python.

Here is my example:

import cv2
import numpy as np
from sklearn.mixture import GMM

im = cv2.imread('Boat.jpg')

h, w, _ = im.shape  # Height and width of the image

# Extract the Blue, Green and Red channels (OpenCV loads images as BGR)
imB = im[:, :, 0]
imG = im[:, :, 1]
imR = im[:, :, 2]

# Reshape the Blue, Green and Red channels into single-row vectors
imB_V = np.reshape(imB, [1, h * w])
imG_V = np.reshape(imG, [1, h * w])
imR_V = np.reshape(imR, [1, h * w])

# Combine the 3 single-row vectors into a 3-row matrix
im_V = np.vstack((imR_V, imG_V, imB_V))

# Fit the bimodal GMM (one row per pixel, one column per channel)
nmodes = 2
GMModel = GMM(n_components=nmodes, covariance_type='full', verbose=0, tol=1e-3)
GMModel = GMModel.fit(np.transpose(im_V))

Thank you very much for your help

Upvotes: 2

Views: 4734

Answers (1)

Michael D

Reputation: 1767

You can try fitting with a diagonal or spherical covariance matrix instead of a full one: covariance_type='diag' or covariance_type='spherical'.

I believe it will be much faster.
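
A rough way to check this on your own machine is sketched below. It is only an illustration under a few assumptions: it uses GaussianMixture, which replaced GMM in current scikit-learn releases, and it fits synthetic pixel data rather than your image, so the helper name fit_gmm, the array pixels, and the sample sizes are all made up for the example. The timings will depend on your data and hardware.

```python
import time

import numpy as np
from sklearn.mixture import GaussianMixture

# Synthetic stand-in for a flattened image: N pixels x 3 colour channels.
# (Assumption: two well-separated colour clusters, not real image data.)
rng = np.random.default_rng(0)
pixels = np.vstack([
    rng.normal(loc=60.0, scale=10.0, size=(20000, 3)),
    rng.normal(loc=180.0, scale=15.0, size=(20000, 3)),
])


def fit_gmm(data, covariance_type, n_components=2):
    """Fit a GMM and return the model together with the elapsed fit time."""
    model = GaussianMixture(
        n_components=n_components,
        covariance_type=covariance_type,
        tol=1e-3,
        random_state=0,
    )
    start = time.perf_counter()
    model.fit(data)
    return model, time.perf_counter() - start


# Fit once per covariance type and report how long each fit took.
results = {}
for cov_type in ("full", "diag", "spherical"):
    model, elapsed = fit_gmm(pixels, cov_type)
    results[cov_type] = model
    print(f"{cov_type:>9}: {elapsed:.3f} s, converged={model.converged_}")
```

To try it on a real image, pixels could be replaced with the (h * w, 3) array built in the question, converted to float.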

Upvotes: 2
