Reputation: 27
I have the following class:
class MyKMeans:
    def __init__(self, max_iter = 300):
        self.max_iter = max_iter
        # Directly access
        self.centroids = None
        self.clusters = None

    def fit(self, X, k):
        """
        """
        # each point is assigned to a cluster
        clusters = np.zeros(X.shape[0])
        # select k random centroids
        random_idxs = np.random.choice(len(X), size=k, replace=False)
        centroids = X[random_idxs, :]
        # iterate until no change occurs in centroids
        while True:
            # for each point
            for i, point in enumerate(X):
                min_d = float('inf')
                # find the closest centroid to the point
                for idx, centroid in enumerate(centroids):
                    d = euclidean_dist(centroid, point)
                    if d < min_d:
                        min_d = d
                        clusters[i] = idx
            # update the new centroids by averaging the points in each cluster
            new_centroids = pd.DataFrame(X).groupby(by=clusters).mean().values
            # if the centroids didn't change, then stop
            if np.count_nonzero(centroids-new_centroids) == 0:
                break
            # otherwise, update the centroids
            else:
                centroids = new_centroids
        self.centroids = centroids
        self.clusters = clusters
and run it using
k = 4
kmeans = MyKMeans()
kmeans.fit(X, k)
centroids, clusters = kmeans.centroids, kmeans.clusters
However, this usually takes about 5 seconds to run. On the other hand, if I move the method to a new function,
def fit(X, k):
    """
    """
    # each point is assigned to a cluster
    clusters = np.zeros(X.shape[0])
    # select k random centroids
    random_idxs = np.random.choice(len(X), size=k, replace=False)
    centroids = X[random_idxs, :]
    # iterate until no change occurs in centroids
    while True:
        # for each point
        for i, point in enumerate(X):
            min_d = float('inf')
            # find the closest centroid to the point
            for idx, centroid in enumerate(centroids):
                d = euclidean_dist(centroid, point)
                if d < min_d:
                    min_d = d
                    clusters[i] = idx
        # update the new centroids by averaging the points in each cluster
        new_centroids = pd.DataFrame(X).groupby(by=clusters).mean().values
        # if the centroids didn't change, then stop
        if np.count_nonzero(centroids-new_centroids) == 0:
            break
        # otherwise, update the centroids
        else:
            centroids = new_centroids
        return centroids, clusters
and get the same variables by calling centroids, clusters = fit(X, k), the runtime is around 0.5-1 second, which is a big difference.
Is there a reason why simply having a class method instead of a function causes such a big difference in runtime, and is there any way to improve the runtime while still being able to use the class?
Upvotes: 1
Views: 346
Reputation: 80031
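The return statement in your non-class version is inside the while loop, so the function returns before the centroids have converged. The two versions are not doing the same amount of work, which is why the function only appears faster.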
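To illustrate, here is a minimal sketch of the function with the return dedented to function level, so it only runs once the loop has finished. It is not your exact code: the `assign` helper, the `seed` and `max_iter` parameters, and the returned `n_iter` counter are my own additions for illustration, and I have swapped the nested Python loops and the pandas groupby for NumPy broadcasting, which is usually where the real speedup comes from.

```python
import numpy as np

def assign(X, centroids):
    # Hypothetical helper: index of the nearest centroid for every point,
    # using Euclidean distance computed with broadcasting.
    dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)
    return dists.argmin(axis=1)

def fit(X, k, max_iter=300, seed=0):
    # seed and max_iter are assumptions added for reproducibility and safety
    rng = np.random.default_rng(seed)
    centroids = X[rng.choice(len(X), size=k, replace=False)]
    n_iter = 0
    while n_iter < max_iter:
        n_iter += 1
        clusters = assign(X, centroids)
        # average the points in each cluster; keep the old centroid if a cluster is empty
        new_centroids = np.array([
            X[clusters == j].mean(axis=0) if np.any(clusters == j) else centroids[j]
            for j in range(k)
        ])
        # stop once the centroids no longer change
        if np.array_equal(centroids, new_centroids):
            break
        centroids = new_centroids
    # return sits at function level, so it runs only after the loop ends
    return centroids, clusters, n_iter
```

The same body works as a method: store `centroids` and `clusters` on `self` after the loop instead of returning them. Comparing `n_iter` between the two placements of the return is a quick way to confirm that the early return stops after a single pass.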
Upvotes: 2