TayTay

Matplotlib subplots created in loop do not display data

First, yes, a similar question has been asked before, but I still haven't been able to solve my specific problem.

I have a pandas DataFrame and a list of labels from a KMeans clustering operation I performed on it. The DataFrame is 4-dimensional and doesn't plot in a friendly manner, so I've been trying to write a function that draws an N x N grid of subplots, each one a projection of two dimensions at a time onto each other (much like a correlation plot), colored by the cluster labels. Something kind of (but not exactly) like this:

Iris data plot

The problem I am facing is that the subplots render, and the labels, axes and everything else look great, but no data is ever plotted! I've dug through several of the matplotlib examples, and my code is modeled after an example from the matplotlib gallery, but I still can't figure out my issue. I've even tried plotting just one subplot at a time, and nothing renders. Here is the output I get:

My improper output

And more importantly, here is my code:

import pandas as pd
import numpy as np
import matplotlib, random
import matplotlib.pyplot as plt
def plot_clusters(data=None, labels=None, seed=500, size="m"):
    if data is None or labels is None:
        raise Exception('null data')
    elif not isinstance(data, (pd.DataFrame, np.ndarray)):
        raise Exception('data must be a dataframe or matrix')
    elif len(data) < 1:
        raise Exception('empty data')
    elif not all(isinstance(item, (int, np.int, np.int0, np.int8, np.int16, np.int32, np.int64)) for item in labels):
        raise Exception('labels must be list of ints')
    elif not isinstance(size, str) or size.lower() not in ['s','m','l']:
        raise Exception('size must be a string in the list: ["s","m","l"]')

    ## Copy data, get dims
    plt.style.use('ggplot')
    dat = np.copy(data if isinstance(data, np.ndarray) else data.as_matrix())
    dims = dat.shape[1]
    names = data.columns.values[:] if isinstance(data, pd.DataFrame) else ['dim'+str(i+1) for i in range(dims)]

    ## Get all the colors, create clusters based on the label
    all_colors = [(n,h) for n,h in matplotlib.colors.cnames.iteritems()]
    if not seed is None:
        random.seed(seed)
    random.shuffle(all_colors)
    colors = [all_colors[label][1] for label in labels] ## Get colors assigned by label factor levels

    ## Set up axes
    multi = 1 if size.lower() == 's' else 1.5 if size.lower() == 'm' else 2
    fig, axes = plt.subplots(figsize=reduce(lambda x,y: tuple([x*multi,y*multi]),(8,6)), nrows=dims, ncols=dims) ## Must be NxN

    ## Now loop
    idx_ct = 0
    columns = [[r[d] for r in dat] for d in range(dims)]
    y_min, y_max = np.min(dat), np.max(dat)
    for i in range(dims):
        for j in range(dims):
            axes[i,j].plot(x=np.array(columns[i]), y=np.array(columns[j]), color=colors)
            axes[i,j].set_ylim([y_min, y_max])
            axes[i,j].set_xlim([y_min, y_max])
            axes[i,j].margins(0)

            ## Set the labels on the y-axis
            if j == 0: ## Only the left-most col gets the label
                axes[i,j].set_ylabel(names[i])
            if j == dims-1:
                axes[i,j].get_yaxis().tick_right()
            else:
                axes[i,j].get_yaxis().set_ticklabels([])


            ## Set the labels on the X-axis
            if i == 0: ## Only the top-most row gets the label
                axes[i,j].set_title(names[j])
            if i < dims-1:
                axes[i,j].get_xaxis().set_ticklabels([])

            idx_ct += 1

    fig.tight_layout()
    plt.show()

## Plot the clusters
plot_clusters(data=topic_maps_n, labels=labels, size='l')

Note that data can be either a pandas DataFrame or a numpy matrix, and labels is a list of integers with the same length as the DataFrame. An example of what data may look like:

Example of data

And for the sake of reproducibility:

data = pd.DataFrame.from_records(np.array([[2.44593742e-01, 4.18387124e-02, 1.56175780e-02, 5.15885742e-04],
             [3.38941458e-01, 8.61882075e-02, 2.51219235e-02, 1.29532576e-03],
             [6.79218190e-02, 2.14741500e-02, 4.51219203e-03, 1.53073947e-06],
             [5.24470045e-01, 1.65668947e-01, 2.11256296e-02, 1.03752391e-04],
             [5.93903485e-01, 1.48081357e-01, 5.18316207e-02, 4.03474064e-02]]), 
             columns=['Topic1Magnitude','Topic2Magnitude','Topic3Magnitude','Topic4Magnitude'])

And labels could simply (arbitrarily) be:

labels = [0,0,1,0,1]

If anyone can help me identify where I'm going wrong, I would be very appreciative.


Answers (1)

TayTay

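In case anyone else ever runs into this: I solved the issue. To get the data to display, I had to change axes[i,j].plot(...) to axes[i,j].scatter(...), which I could've sworn I'd tried before. Unlike plot, scatter accepts x and y as keyword arguments and takes a list of per-point colors. With that change, the points render: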

Proper output
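
For anyone who wants something to run, here is a minimal Python 3 sketch of the same grid built on scatter. It is not the exact function from the question: the helper name plot_cluster_grid and the PALETTE list are made up for illustration, the validation and styling are left out, and the data is the sample from the question. As far as I can tell, the original call drew nothing because plot expects its data as positional arguments, so x=... and y=... never produced a line (recent Matplotlib versions raise a TypeError for this instead of silently drawing nothing).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; drop this line for interactive use
import matplotlib.pyplot as plt
import numpy as np

# Assumed palette: one named color per cluster label, reused cyclically
PALETTE = ["tab:blue", "tab:orange", "tab:green", "tab:red", "tab:purple"]

def plot_cluster_grid(dat, labels, names=None):
    """Hypothetical helper: N x N pairwise scatter plots colored by label."""
    dat = np.asarray(dat, dtype=float)
    dims = dat.shape[1]
    if names is None:
        names = ["dim" + str(d + 1) for d in range(dims)]
    colors = [PALETTE[label % len(PALETTE)] for label in labels]
    lo, hi = dat.min(), dat.max()

    fig, axes = plt.subplots(nrows=dims, ncols=dims, figsize=(12, 9), squeeze=False)
    for i in range(dims):
        for j in range(dims):
            ax = axes[i, j]
            # Column j on x and column i on y, so titles and y-labels match the data.
            # For a single color, ax.plot(x, y, "o") with positional data also works.
            ax.scatter(dat[:, j], dat[:, i], c=colors)
            ax.set_xlim(lo, hi)
            ax.set_ylim(lo, hi)
            if i == 0:  # only the top row gets a title
                ax.set_title(names[j])
            if j == 0:  # only the left-most column gets a y-label
                ax.set_ylabel(names[i])
    fig.tight_layout()
    return fig, axes

data = np.array([[2.44593742e-01, 4.18387124e-02, 1.56175780e-02, 5.15885742e-04],
                 [3.38941458e-01, 8.61882075e-02, 2.51219235e-02, 1.29532576e-03],
                 [6.79218190e-02, 2.14741500e-02, 4.51219203e-03, 1.53073947e-06],
                 [5.24470045e-01, 1.65668947e-01, 2.11256296e-02, 1.03752391e-04],
                 [5.93903485e-01, 1.48081357e-01, 5.18316207e-02, 4.03474064e-02]])
labels = [0, 0, 1, 0, 1]
names = ["Topic1Magnitude", "Topic2Magnitude", "Topic3Magnitude", "Topic4Magnitude"]

fig, axes = plot_cluster_grid(data, labels, names=names)
```

Call plt.show() (with an interactive backend) or fig.savefig(...) afterwards to see the result. Passing the DataFrame's values instead of a plain array should work the same way, since the helper converts its input with np.asarray.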
