GKC

Reputation: 479

Representing cluster centers in Plotly Express's px.scatter_3d()

I am trying to visualize a k-means clustering (the clusters and their centroids) with Plotly Express, as highlighted here. Following this example, I am able to plot the different clusters. However, I can't find any information on how to also show the cluster centroids on the same plot with Plotly Express.

Dataset:

column_a,column_b,column_c,column_d
5.1,3.5,1.4,0.2
4.9,3.0,1.4,0.2
4.7,3.2,1.3,0.2
4.6,3.1,1.5,0.2
5.0,3.6,1.4,0.2
5.4,3.9,1.7,0.4
4.6,3.4,1.4,0.3
5.0,3.4,1.5,0.2
4.4,2.9,1.4,0.2
4.9,3.1,1.5,0.1
5.4,3.7,1.5,0.2
4.8,3.4,1.6,0.2
4.8,3.0,1.4,0.1
4.3,3.0,1.1,0.1
5.8,4.0,1.2,0.2
5.7,4.4,1.5,0.4
5.4,3.9,1.3,0.4
5.1,3.5,1.4,0.3
5.7,3.8,1.7,0.3
5.1,3.8,1.5,0.3
5.4,3.4,1.7,0.2
5.1,3.7,1.5,0.4
4.6,3.6,1.0,0.2
5.1,3.3,1.7,0.5
4.8,3.4,1.9,0.2
5.0,3.0,1.6,0.2
5.0,3.4,1.6,0.4
5.2,3.5,1.5,0.2
5.2,3.4,1.4,0.2
4.7,3.2,1.6,0.2
4.8,3.1,1.6,0.2
5.4,3.4,1.5,0.4
5.2,4.1,1.5,0.1
5.5,4.2,1.4,0.2
4.9,3.1,1.5,0.1
5.0,3.2,1.2,0.2
5.5,3.5,1.3,0.2
4.9,3.1,1.5,0.1
4.4,3.0,1.3,0.2
5.1,3.4,1.5,0.2
5.0,3.5,1.3,0.3
4.5,2.3,1.3,0.3
4.4,3.2,1.3,0.2
5.0,3.5,1.6,0.6
5.1,3.8,1.9,0.4
4.8,3.0,1.4,0.3
5.1,3.8,1.6,0.2
4.6,3.2,1.4,0.2
5.3,3.7,1.5,0.2
5.0,3.3,1.4,0.2
7.0,3.2,4.7,1.4
6.4,3.2,4.5,1.5
6.9,3.1,4.9,1.5
5.5,2.3,4.0,1.3
6.5,2.8,4.6,1.5
5.7,2.8,4.5,1.3
6.3,3.3,4.7,1.6
4.9,2.4,3.3,1.0
6.6,2.9,4.6,1.3
5.2,2.7,3.9,1.4
5.0,2.0,3.5,1.0
5.9,3.0,4.2,1.5
6.0,2.2,4.0,1.0
6.1,2.9,4.7,1.4
5.6,2.9,3.6,1.3
6.7,3.1,4.4,1.4
5.6,3.0,4.5,1.5
5.8,2.7,4.1,1.0
6.2,2.2,4.5,1.5
5.6,2.5,3.9,1.1
5.9,3.2,4.8,1.8
6.1,2.8,4.0,1.3
6.3,2.5,4.9,1.5
6.1,2.8,4.7,1.2
6.4,2.9,4.3,1.3
6.6,3.0,4.4,1.4
6.8,2.8,4.8,1.4
6.7,3.0,5.0,1.7
6.0,2.9,4.5,1.5
5.7,2.6,3.5,1.0
5.5,2.4,3.8,1.1
5.5,2.4,3.7,1.0
5.8,2.7,3.9,1.2
6.0,2.7,5.1,1.6
5.4,3.0,4.5,1.5
6.0,3.4,4.5,1.6
6.7,3.1,4.7,1.5
6.3,2.3,4.4,1.3
5.6,3.0,4.1,1.3
5.5,2.5,4.0,1.3
5.5,2.6,4.4,1.2
6.1,3.0,4.6,1.4
5.8,2.6,4.0,1.2
5.0,2.3,3.3,1.0
5.6,2.7,4.2,1.3
5.7,3.0,4.2,1.2
5.7,2.9,4.2,1.3
6.2,2.9,4.3,1.3
5.1,2.5,3.0,1.1
5.7,2.8,4.1,1.3
6.3,3.3,6.0,2.5
5.8,2.7,5.1,1.9
7.1,3.0,5.9,2.1
6.3,2.9,5.6,1.8
6.5,3.0,5.8,2.2
7.6,3.0,6.6,2.1
4.9,2.5,4.5,1.7
7.3,2.9,6.3,1.8
6.7,2.5,5.8,1.8
7.2,3.6,6.1,2.5
6.5,3.2,5.1,2.0
6.4,2.7,5.3,1.9
6.8,3.0,5.5,2.1
5.7,2.5,5.0,2.0
5.8,2.8,5.1,2.4
6.4,3.2,5.3,2.3
6.5,3.0,5.5,1.8
7.7,3.8,6.7,2.2
7.7,2.6,6.9,2.3
6.0,2.2,5.0,1.5
6.9,3.2,5.7,2.3
5.6,2.8,4.9,2.0
7.7,2.8,6.7,2.0
6.3,2.7,4.9,1.8
6.7,3.3,5.7,2.1
7.2,3.2,6.0,1.8
6.2,2.8,4.8,1.8
6.1,3.0,4.9,1.8
6.4,2.8,5.6,2.1
7.2,3.0,5.8,1.6
7.4,2.8,6.1,1.9
7.9,3.8,6.4,2.0
6.4,2.8,5.6,2.2
6.3,2.8,5.1,1.5
6.1,2.6,5.6,1.4
7.7,3.0,6.1,2.3
6.3,3.4,5.6,2.4
6.4,3.1,5.5,1.8
6.0,3.0,4.8,1.8
6.9,3.1,5.4,2.1
6.7,3.1,5.6,2.4
6.9,3.1,5.1,2.3
5.8,2.7,5.1,1.9
6.8,3.2,5.9,2.3
6.7,3.3,5.7,2.5
6.7,3.0,5.2,2.3
6.3,2.5,5.0,1.9
6.5,3.0,5.2,2.0
6.2,3.4,5.4,2.3
5.9,3.0,5.1,1.8

So far this is what I have tried:

import numpy as np
import pandas as pd
from sklearn.cluster import KMeans
import plotly.express as px
import streamlit as st

df = pd.read_csv('iris.csv')
df_columns = ['column_a', 'column_b', 'column_c']

# Fit k-means on the three features that will be plotted
kmeans = KMeans(n_clusters=3, init='k-means++', max_iter=200)
km = kmeans.fit(df[df_columns])
centroids = km.cluster_centers_  # one row per cluster, columns match df_columns
cluster_labels = km.labels_
df['cluster'] = pd.Series(cluster_labels, index=df.index)

# px.scatter_3d must be told which columns go on the x, y and z axes
fig = px.scatter_3d(df, x='column_a', y='column_b', z='column_c',
                    color=cluster_labels, labels={'color': 'cluster'})
# add code for the centroids
st.plotly_chart(fig)

With Matplotlib I can easily get the result I want, as shown below. However, I need to use scatter_3d from Plotly Express.

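import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib versions

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(df.iloc[:, 0].values, df.iloc[:, 1].values, df.iloc[:, 2].values, c=cluster_labels, cmap='viridis')
ax.scatter(centroids[:, 0], centroids[:, 1], centroids[:, 2], s=20, c='black', marker='*')
plt.show()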

How can I represent the cluster centers using Plotly Express's px.scatter_3d()?

Upvotes: 2

Views: 1391

Answers (1)

Yolina

Reputation: 1

I had the same problem, and the only solution I found was to add the centroids to the Plotly Express figure as a separate Scatter3d trace from plotly.graph_objects:

import plotly.graph_objects as go

For a 3D plot you can use:

fig.add_trace(go.Scatter3d(x=centroids[:,0], y=centroids[:,1], z=centroids[:,2],  text="Centroid", mode='markers', marker=dict(size=10, color='black')))

And for a 2D plot:

fig.add_trace(go.Scatter(x=centroids[:,0], y=centroids[:,1],  text="Centroid",mode='markers', marker=dict(size=10, color='black')))

Upvotes: 0
