Barden

Reputation: 1140

message does not fit sklearn k-means convergence implementation

I want to re-use the convergence criterion of scikit-learn's KMeans in my TensorFlow-based k-means implementation, so I need to understand it. While reading the code, I made an observation that I would love to have explained:

KMeans converges with this message:

Iteration 45, inertia 6.993125 center shift 2.610457e-03 within tolerance 8.374284e-06

The implementation in https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/cluster/_k_means.py (line 442ff, function _kmeans_single_lloyd) is as follows:

center_shift_total = squared_norm(centers_old - centers)
if center_shift_total <= tol:
    if verbose:
        print("Converged at iteration %d: "
              "center shift %e within tolerance %e"
              % (i, center_shift_total, tol))
    break

The message should be printed only if center_shift_total is less than or equal to tol. As the output shows, this is not the case in my run of KMeans: center_shift_total (2.610457e-03) is in fact much larger than tol (8.374284e-06).
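For reference, the check quoted above can be sketched in plain NumPy. This is only a sketch, assuming squared_norm is the squared Frobenius norm of the difference between the old and new centers; lloyd_converged and the sample arrays are hypothetical names and values made up for illustration, not part of scikit-learn.

```python
import numpy as np


def lloyd_converged(centers_old, centers, tol):
    """Hypothetical helper mirroring the quoted check (not scikit-learn API)."""
    # Assumption: squared_norm(a) is equivalent to np.sum(a ** 2).
    center_shift_total = np.sum((centers_old - centers) ** 2)
    return center_shift_total <= tol


# Made-up centers: two clusters in 2-D, each shifted slightly.
centers_old = np.array([[0.0, 0.0], [1.0, 1.0]])
centers = np.array([[0.001, 0.0], [1.0, 1.002]])

print(lloyd_converged(centers_old, centers, tol=1e-4))
print(lloyd_converged(centers_old, centers, tol=1e-6))
```

Under this reading, the value compared against tol is already a squared quantity, so the printed center shift and tol would be directly comparable.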

How can this happen (or what am I overlooking)? I noted that the "Converged at iteration" part is missing as well, but the observed message definitely makes no sense to me.

Upvotes: 1

Views: 278

Answers (1)

iliar

Reputation: 952

I found it. The message comes from the file _k_means_elkan.pyx, line 243 (in 0.23.1). In the master branch it would be line 245.

        if verbose:
            print('Iteration %i, inertia %s'
                    % (iteration, np.sum((X_ - centers_[labels]) ** 2 *
                                         sample_weight[:,np.newaxis])))
        center_shift_total = np.sum(center_shift)
        if center_shift_total ** 2 < tol:
            if verbose:
                print("center shift %e within tolerance %e"
                      % (center_shift_total, tol))
            break

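It seems that this code compares the square of center_shift_total against tol but prints the unsquared value, while the code inside k_means_.py compares center_shift_total against tol without squaring it. With your numbers, (2.610457e-03)^2 is about 6.81e-06, which is below the tolerance of 8.374284e-06, so the check passes even though the printed center shift looks much larger than tol.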
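A quick way to confirm this is to plug the two numbers from the log line in the question into both comparisons. This is a minimal sketch; the variable names elkan_converged and naive_converged are just illustrative.

```python
# Values copied from the log line in the question.
center_shift_total = 2.610457e-03  # value printed by the Elkan code (not squared)
tol = 8.374284e-06

# The Elkan snippet quoted above compares the square against tol.
elkan_converged = center_shift_total ** 2 < tol

# Comparing the printed value directly, as the message suggests, fails.
naive_converged = center_shift_total <= tol

print(f"squared shift: {center_shift_total ** 2:e}, tol: {tol:e}")
print("Elkan check passes:", elkan_converged)
print("Direct comparison passes:", naive_converged)
```

So to reproduce the Elkan behavior elsewhere, the sum of the per-center shifts has to be squared before it is compared with tol.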

Upvotes: 1
