First, yes, a similar question has been asked before, but I still haven't been able to solve my specific problem.
I've got a pandas DataFrame and a list of labels from a KMeans clustering I performed on it. The DataFrame is 4-dimensional and therefore doesn't plot in a very friendly manner, so I've been trying to write a function that plots an N x N grid of subplots, each one a projection of one dimension against another (very much like a correlation plot), colored by the cluster labels. Something kind of (but not exactly) like this: [example image not included]
The problem I'm facing is that the subplots render, and the labels, axes, and everything else look great, but no data is ever plotted! I've dug through several of the matplotlib examples, and my code is modeled after an example from the matplotlib gallery, but I still can't figure out my issue. I've even tried plotting one subplot at a time, and nothing renders. Here is the output I get: [screenshot of the empty subplot grid not included]
And more importantly, here is my code:
```python
import random

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_clusters(data=None, labels=None, seed=500, size="m"):
    if data is None or labels is None:
        raise Exception('null data')
    elif not isinstance(data, (pd.DataFrame, np.ndarray)):
        raise Exception('data must be a dataframe or matrix')
    elif len(data) < 1:
        raise Exception('empty data')
    elif not all(isinstance(item, (int, np.integer)) for item in labels):
        raise Exception('labels must be list of ints')
    elif not isinstance(size, str) or size.lower() not in ['s', 'm', 'l']:
        raise Exception('size must be a string in the list: ["s","m","l"]')

    ## Copy data, get dims
    plt.style.use('ggplot')
    dat = np.copy(data if isinstance(data, np.ndarray) else data.to_numpy())
    dims = dat.shape[1]
    names = data.columns.values[:] if isinstance(data, pd.DataFrame) else ['dim' + str(i + 1) for i in range(dims)]

    ## Get all the colors, create clusters based on the label
    all_colors = list(matplotlib.colors.cnames.items())
    if seed is not None:
        random.seed(seed)
    random.shuffle(all_colors)
    colors = [all_colors[label][1] for label in labels]  ## Get colors assigned by label factor levels

    ## Set up axes
    multi = 1 if size.lower() == 's' else 1.5 if size.lower() == 'm' else 2
    fig, axes = plt.subplots(figsize=(8 * multi, 6 * multi), nrows=dims, ncols=dims)  ## Must be NxN

    ## Now loop
    idx_ct = 0
    columns = [[r[d] for r in dat] for d in range(dims)]
    y_min, y_max = np.min(dat), np.max(dat)
    for i in range(dims):
        for j in range(dims):
            ## This is the call that never draws any data
            axes[i, j].plot(x=np.array(columns[i]), y=np.array(columns[j]), color=colors)
            axes[i, j].set_ylim([y_min, y_max])
            axes[i, j].set_xlim([y_min, y_max])
            axes[i, j].margins(0)

            ## Set the labels on the y-axis
            if j == 0:  ## Only the left-most col gets the label
                axes[i, j].set_ylabel(names[i])
            if j == dims - 1:
                axes[i, j].get_yaxis().tick_right()
            else:
                axes[i, j].get_yaxis().set_ticklabels([])

            ## Set the labels on the X-axis
            if i == 0:  ## Only the top-most row gets the label
                axes[i, j].set_title(names[j])
            if i < dims - 1:
                axes[i, j].get_xaxis().set_ticklabels([])

            idx_ct += 1

    fig.tight_layout()
    plt.show()


## Plot the clusters
plot_clusters(data=topic_maps_n, labels=labels, size='l')
```
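The color-assignment step in the function shuffles matplotlib's named CSS colors and picks one per label, so a cluster can land on a very pale name such as `"white"` or `"snow"`, and two clusters can get nearly identical shades. One possible alternative, sketched here under the assumption of matplotlib 3.5+ (for the `matplotlib.colormaps` registry), maps each integer label through a qualitative colormap such as `tab10`. The helper name `label_colors` is hypothetical, and labels wrap around after ten clusters.

```python
import matplotlib
import numpy as np


def label_colors(labels, cmap_name="tab10"):
    """Hypothetical helper: return one RGBA row per point, chosen by cluster label."""
    cmap = matplotlib.colormaps[cmap_name]  # registry lookup, matplotlib >= 3.5
    # Integer input indexes a ListedColormap's color table directly;
    # the modulo wraps labels beyond the number of available colors.
    return cmap(np.asarray(labels) % cmap.N)


colors = label_colors([0, 0, 1, 0, 1])
print(colors.shape)  # (5, 4)
```

The resulting `(n, 4)` RGBA array can be passed to `Axes.scatter` as `c=`, and the same label always gets the same color without needing a random seed.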
Note that `data` can be either a pandas DataFrame or a NumPy matrix, and `labels` is a list of the same length as the DataFrame (it consists only of integers). An example of what `data` may look like: [table image not included]

And for the sake of reproducibility:
```python
import numpy as np
import pandas as pd

data = pd.DataFrame.from_records(
    np.array([[2.44593742e-01, 4.18387124e-02, 1.56175780e-02, 5.15885742e-04],
              [3.38941458e-01, 8.61882075e-02, 2.51219235e-02, 1.29532576e-03],
              [6.79218190e-02, 2.14741500e-02, 4.51219203e-03, 1.53073947e-06],
              [5.24470045e-01, 1.65668947e-01, 2.11256296e-02, 1.03752391e-04],
              [5.93903485e-01, 1.48081357e-01, 5.18316207e-02, 4.03474064e-02]]),
    columns=['Topic1Magnitude', 'Topic2Magnitude', 'Topic3Magnitude', 'Topic4Magnitude'])
```
And `labels` could simply (arbitrarily) be:

```python
labels = [0, 0, 1, 0, 1]
```
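For comparison, pandas already ships a scatter-plot matrix, `pandas.plotting.scatter_matrix`, which draws every pairwise projection and forwards extra keyword arguments to each off-diagonal `Axes.scatter` call. A minimal sketch on the sample data above, assuming the `tab10` colormap is acceptable for the cluster colors (at most ten clusters). It puts histograms on the diagonal instead of plotting a variable against itself, so it is close to, but not the same as, the grid described in the question.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; remove this line and call plt.show() for a window
import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix

data = pd.DataFrame(
    [[2.44593742e-01, 4.18387124e-02, 1.56175780e-02, 5.15885742e-04],
     [3.38941458e-01, 8.61882075e-02, 2.51219235e-02, 1.29532576e-03],
     [6.79218190e-02, 2.14741500e-02, 4.51219203e-03, 1.53073947e-06],
     [5.24470045e-01, 1.65668947e-01, 2.11256296e-02, 1.03752391e-04],
     [5.93903485e-01, 1.48081357e-01, 5.18316207e-02, 4.03474064e-02]],
    columns=['Topic1Magnitude', 'Topic2Magnitude', 'Topic3Magnitude', 'Topic4Magnitude'])
labels = np.array([0, 0, 1, 0, 1])

# c, cmap, vmin and vmax are passed through to every off-diagonal scatter;
# vmin/vmax pin label k to the k-th tab10 color in every panel.
axes = scatter_matrix(data, c=labels, cmap="tab10", vmin=0, vmax=9,
                      diagonal="hist", figsize=(12, 9))
```

`scatter_matrix` returns the N x N array of Axes, so individual panels can still be adjusted afterwards.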
If anyone can help me identify where I'm going wrong, I would be very appreciative.
In case anyone else ever faces this, I solved the issue. To get the data to display, I had to change `axes[i,j].plot(...)` to `axes[i,j].scatter(...)`, which I could've sworn I'd tried before: [resulting figure not included]
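Why the swap helps: `Axes.plot` takes its data only as positional arguments, so `plot(x=..., y=...)` never receives any points (depending on the matplotlib version it either silently draws nothing or raises a `TypeError`), and its `color` argument is a single color for the whole line rather than one per point. `Axes.scatter` has named `x` and `y` parameters and accepts per-point colors through `c`. Below is a minimal sketch of the grid with that change, assuming the `tab10` colormap for cluster colors (at most ten clusters). It deliberately puts column *j* on the x-axis and row *i* on the y-axis so the points match the row labels and column titles, and it leaves out the randomized named colors and right-hand tick labels for brevity.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; remove this line and call plt.show() for a window
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot_clusters(data, labels, size="m"):
    """Sketch: N x N grid of pairwise projections, points colored by cluster label."""
    dat = np.asarray(data, dtype=float)
    dims = dat.shape[1]
    names = (list(data.columns) if isinstance(data, pd.DataFrame)
             else ["dim" + str(i + 1) for i in range(dims)])
    multi = {"s": 1, "m": 1.5, "l": 2}[size.lower()]
    fig, axes = plt.subplots(dims, dims, figsize=(8 * multi, 6 * multi), squeeze=False)
    lo, hi = dat.min(), dat.max()  # shared limits so every panel uses the same scale
    for i in range(dims):          # row i: variable on the y-axis
        for j in range(dims):      # column j: variable on the x-axis
            ax = axes[i, j]
            # scatter() takes x and y by name and one color per point via c;
            # vmin/vmax pin label k to the k-th tab10 color in every panel.
            ax.scatter(x=dat[:, j], y=dat[:, i], c=labels, cmap="tab10", vmin=0, vmax=9)
            ax.set_xlim(lo, hi)
            ax.set_ylim(lo, hi)
            if j == 0:
                ax.set_ylabel(names[i])  # only the left-most column gets a y label
            else:
                ax.tick_params(labelleft=False)
            if i == 0:
                ax.set_title(names[j])   # only the top row gets a title
            if i < dims - 1:
                ax.tick_params(labelbottom=False)
    fig.tight_layout()
    return fig, axes
```

With the sample `data` and `labels` from the question, `plot_clusters(data, labels, size="l")` should produce a 4 x 4 grid in which every panel holds five colored points.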