JayJona

Reputation: 502

How can I efficiently plot a distance matrix using seaborn?

So I have a dataset of roughly 11,000 records with 4 features, all of which are discrete or continuous. I perform clustering using K-means, then I add a "cluster" column to the dataframe using kmeans.labels_. Now I want to plot the distance matrix, so I used pdist from scipy, but the matrix is not plotted.

Here is my code.

import gc

import seaborn as sns
from scipy.spatial.distance import pdist
from scipy.spatial.distance import squareform

# distance matrix
def distance_matrix(df_labeled, metric="euclidean"):
    df_labeled.sort_values(by=['cluster'], inplace=True)
    dist = pdist(df_labeled, metric)
    dist = squareform(dist)    
    sns.heatmap(dist, cmap="mako")
    print(dist)
    del dist
    gc.collect()

distance_matrix(finalDf)

Output:

[[ 0.          2.71373462  3.84599479 ...  7.59910903  8.10265588
   8.27195104]
 [ 2.71373462  0.          2.94410672 ...  7.90444283  8.28225031
   8.48094661]
 [ 3.84599479  2.94410672  0.         ...  9.78706347 10.42014451
  10.61261498]
 ...
 [ 7.59910903  7.90444283  9.78706347 ...  0.          1.27795469
   1.44711258]
 [ 8.10265588  8.28225031 10.42014451 ...  1.27795469  0.
   0.52333107]
 [ 8.27195104  8.48094661 10.61261498 ...  1.44711258  0.52333107
   0.        ]]

I get the following graph: [image not available]

As you can see, the plot is empty. I also have to free up some RAM, because Google Colab crashes.

How can I solve the problem?


Answers (2)

Ali

Reputation: 20

To plot a distance matrix using seaborn, you can use the seaborn.heatmap() function.
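A minimal sketch of that call, on made-up data (200 random points with 4 features, not the asker's dataset). It builds the square matrix with scipy and hands it to seaborn.heatmap; turning off the tick labels avoids drawing one label per row, which matters once the matrix has thousands of rows. Outside Colab's inline backend, the figure may not appear until plt.show() is called.

```python
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from scipy.spatial.distance import pdist, squareform

# Illustrative data only: 200 random points with 4 features.
rng = np.random.default_rng(0)
points = rng.normal(size=(200, 4))

# pdist returns the condensed upper triangle; squareform expands it to 200 x 200.
dist = squareform(pdist(points, "euclidean"))

fig, ax = plt.subplots(figsize=(6, 5))
# Tick labels are switched off so seaborn does not try to label every row/column.
sns.heatmap(dist, cmap="mako", square=True, xticklabels=False, yticklabels=False, ax=ax)
ax.set_title("Pairwise Euclidean distances")
plt.show()
```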


J_H

Reputation: 20450

The original question was well phrased but was not a reprex (a minimal, reproducible example). Its code, at least the part we can see, appears to work fine.
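For illustration, a self-contained reprex of the question's pipeline might look like the sketch below. It assumes synthetic data (300 random rows, hypothetical column names f1..f4) in place of finalDf, and uses scikit-learn's KMeans with its labels_ attribute as the question describes. Two choices differ from the question's function and are assumptions on my part: the frame is sorted into a new object instead of in place, and the cluster column is dropped before pdist, since otherwise the label itself is treated as a fourth... rather, a fifth feature in the distances.

```python
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import KMeans

# Synthetic stand-in for finalDf: sizes and column names are made up.
rng = np.random.default_rng(42)
features = pd.DataFrame(rng.normal(size=(300, 4)), columns=["f1", "f2", "f3", "f4"])

# Cluster, then attach the labels as a "cluster" column, as in the question.
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(features)
labeled = features.assign(cluster=kmeans.labels_)


def sorted_distance_matrix(df_labeled, metric="euclidean", label_col="cluster"):
    """Return (sorted frame, square distance matrix) ordered by cluster label.

    The label column is excluded from the distance computation, and the
    caller's frame is left unmodified (no inplace sort).
    """
    ordered = df_labeled.sort_values(by=label_col, kind="stable")
    dist = squareform(pdist(ordered.drop(columns=label_col), metric))
    return ordered, dist


ordered, dist = sorted_distance_matrix(labeled)
```

Because rows are grouped by cluster, a heatmap of dist should show the clusters as blocks along the diagonal.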

Here is a demo of producing a heatmap for another dataset that also has 11,000 rows.

from scipy.spatial.distance import pdist, squareform
from uszipcode import SearchEngine
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


def distance_matrix(df: pd.DataFrame, metric="euclidean"):
    df = df[["zipcode", "lat", "lng", "population_density"]]
    df = df.sort_values(by=["zipcode"])
    print(df)
    dist = pdist(df, metric)
    dist = squareform(dist)
    sns.heatmap(dist, cmap="mako")
    print(dist)
    plt.show()


def get_df() -> pd.DataFrame:
    zips = SearchEngine().by_population_density(lower=100, returns=11_000)
    df = pd.DataFrame(z.to_dict() for z in zips)
    df["zipcode"] = df.zipcode.astype(int)
    return df


distance_matrix(get_df())

It consumes at least ten GiB of RAM under macOS 12.6.2, using CPython 3.10.8, matplotlib 3.6.2, scipy 1.9.3, and seaborn 0.12.1.

It displays this: [image: heatmap]
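Some rough arithmetic on where that memory goes. For n = 11,000 rows, the condensed vector from pdist holds n(n-1)/2 = 60,494,500 float64 values (about 0.45 GiB), and the square matrix from squareform holds 121,000,000 (about 0.90 GiB); both exist at once while squareform runs. So the matrix alone is under 1 GiB, and the rest of the multi-GiB footprint presumably comes from matplotlib rendering 121 million heatmap cells. That split is an estimate, not a measurement. One common workaround, which neither answer proposes and which is only an assumption here, is to plot a per-cluster sample: the hypothetical helper below keeps at most per_cluster rows from each cluster, which preserves the block structure while shrinking the matrix. The demo data is synthetic.

```python
import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform


def matrix_bytes(n_rows: int, itemsize: int = 8) -> tuple[int, int]:
    """Bytes for the condensed (pdist) and square (squareform) forms."""
    condensed = n_rows * (n_rows - 1) // 2 * itemsize
    square = n_rows * n_rows * itemsize
    return condensed, square


def subsample_per_cluster(df, label_col="cluster", per_cluster=500, seed=0):
    """Hypothetical helper: keep at most per_cluster rows per cluster, sorted by label."""
    parts = [
        group.sample(n=min(len(group), per_cluster), random_state=seed)
        for _, group in df.groupby(label_col)
    ]
    return pd.concat(parts).sort_values(label_col)


# Synthetic demo: 11,000 rows, 4 features, 4 made-up cluster labels.
rng = np.random.default_rng(0)
demo = pd.DataFrame(rng.normal(size=(11_000, 4)), columns=["f1", "f2", "f3", "f4"])
demo["cluster"] = rng.integers(0, 4, size=len(demo))

small = subsample_per_cluster(demo, per_cluster=500)
# Distances on the feature columns only; 2,000 x 2,000 float64 is about 32 MB.
dist = squareform(pdist(small.drop(columns="cluster"), "euclidean"))
```

The resulting 2,000 x 2,000 matrix can be passed to sns.heatmap as in the demo above; per_cluster trades detail for memory and rendering time.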
