shakedzy

How to get the data array from scatter_matrix

I'm using pandas' scatter_matrix, and I was wondering how I can get the 2D data plotted in each subplot of the scatter matrix. Also, how do I tell which AxesSubplot in the output corresponds to which plot in the output figure?


Answers (1)

scatter_matrix is a convenience function of pandas, from the pandas.plotting submodule. While the documentation is scarce (and the docstring is only a bit more helpful), the example makes it quite straightforward to understand how it works. Consider the example in the documentation:

import numpy as np  # only needed for the example input
import pandas as pd
import matplotlib.pyplot as plt
from pandas.plotting import scatter_matrix

df = pd.DataFrame(np.random.randn(1000, 4), columns=['a', 'b', 'c', 'd'])
# returns a 4x4 NumPy array of Axes; diagonal='kde' requires SciPy
axs = scatter_matrix(df, alpha=0.2, figsize=(6, 6), diagonal='kde')
plt.show()

[Figure: example scatter_matrix output]

Note the labels on the bottom and left axes: they indicate which columns of the input dataframe are plotted against one another in a given row/column. In the first column of plots the x axis corresponds to df.a, in the second row of plots the y axis corresponds to df.b, and so on (the diagonal plots show either densities or histograms of the respective columns). Consequently, transposed elements in the plot matrix correspond to swapping the x and y data, i.e. reflecting the plot about the x=y line. A close look at the figure above confirms that this is indeed the case.
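If you'd rather not rely on reading the figure, the same mapping can be checked programmatically. The sketch below assumes that scatter_matrix sets the x label of axs[i,j] to df.columns[j] and the y label to df.columns[i] (as current pandas versions do), and that the scatter points are the first collection on each off-diagonal axes. Treat both as implementation details rather than documented API.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix

rng = np.random.default_rng(0)
df = pd.DataFrame(rng.standard_normal((200, 4)), columns=["a", "b", "c", "d"])
axs = scatter_matrix(df, alpha=0.2, figsize=(6, 6))  # default diagonal='hist'
n = axs.shape[0]

# Labels visible in the figure: x on the bottom row, y on the left column
bottom_xlabels = [axs[n - 1, j].get_xlabel() for j in range(n)]
left_ylabels = [axs[i, 0].get_ylabel() for i in range(n)]
print(bottom_xlabels)  # ['a', 'b', 'c', 'd']
print(left_ylabels)    # ['a', 'b', 'c', 'd']

# Transposed subplots hold the same points with x and y swapped
# (assumption: the scatter PathCollection is collections[0])
xy_01 = np.asarray(axs[0, 1].collections[0].get_offsets())  # x=b, y=a
xy_10 = np.asarray(axs[1, 0].collections[0].get_offsets())  # x=a, y=b
print(np.allclose(xy_01, xy_10[:, ::-1]))  # True
```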

In other words, you don't need to recover the data from the individual axes, since you have direct control of your input data. In the off-diagonal axes axs[i,j] the x data are given by df[df.columns[j]] and the y data by df[df.columns[i]] (note that scatter_matrix only plots the numeric columns of df, and skips rows where either of the two values is missing). Here's a quick kludge to help visualize the order:

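axs = scatter_matrix(df, alpha=0.2, figsize=(6,6), diagonal='kde')
for i in range(axs.shape[0]):
    for j in range(axs.shape[1]):
        if i == j:
            continue  # diagonal axes show densities, not scatter plots
        # write the data sources in the middle of each off-diagonal subplot
        axs[i,j].text(0.5, 0.5, 'x: {}, y: {}'.format(df.columns[j], df.columns[i]),
                      transform=axs[i,j].transAxes, ha='center', va='center')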

[Figure: example scatter_matrix output with annotations showing the x and y sources for each off-diagonal plot]
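To get the 2D array for a given off-diagonal subplot straight from df, a small helper along these lines could be used. subplot_data is a hypothetical name, and its select_dtypes(include="number") filter only approximates how pandas picks the plotted columns (boolean columns, for instance, may be handled differently). The comparison against get_offsets() just confirms that the helper returns the same points that were drawn.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix

def subplot_data(frame, i, j):
    """Hypothetical helper: (x, y) pairs drawn in off-diagonal subplot axs[i, j].

    Assumes only numeric columns are plotted and that rows with a missing
    value in either column are skipped.
    """
    if i == j:
        raise ValueError("diagonal subplots show a histogram or KDE, not a scatter")
    numeric = frame.select_dtypes(include="number")
    xcol, ycol = numeric.columns[j], numeric.columns[i]
    return numeric[[xcol, ycol]].dropna().to_numpy()

rng = np.random.default_rng(0)
df = pd.DataFrame(rng.standard_normal((1000, 4)), columns=["a", "b", "c", "d"])
df.loc[::50, "c"] = np.nan  # 20 missing values in column 'c'
df["label"] = "x"           # non-numeric column, ignored by scatter_matrix

axs = scatter_matrix(df, alpha=0.2, figsize=(6, 6))

data = subplot_data(df, 2, 1)  # x from 'b', y from 'c'
drawn = np.asarray(axs[2, 1].collections[0].get_offsets())
print(data.shape, np.allclose(data, drawn))
```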

So while it would be possible to dig into the internals of each AxesSubplot object and extract the data from there, it's much simpler to use the respective columns of df directly. The one exception is the diagonal: in the case of a kernel density plot (i.e. when diagonal='kde' is passed to scatter_matrix) you don't have direct access to the density curve computed from the data. In that case you can extract the line from the diagonal AxesSubplot:

import matplotlib.pyplot as plt
index = 0  # which diagonal subplot to extract, here axs[0,0]
# the KDE curve is the first (and only) line on a diagonal axes
xdat, ydat = axs[index,index].get_lines()[0].get_data()
plt.figure()
plt.plot(xdat, ydat, '-')
plt.xlabel(df.columns[index])
plt.ylabel('density')
plt.show()
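Alternatively, the density curve can be recomputed from the column itself. The sketch below assumes that pandas builds the diagonal KDE with scipy.stats.gaussian_kde using its default bandwidth, which is what recent versions appear to do. For the default diagonal='hist' there are no lines on the diagonal at all: the bars are Rectangle patches, and their heights should match np.histogram as long as matplotlib's default of 10 bins (rcParams['hist.bins']) is in effect.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix
from scipy.stats import gaussian_kde  # diagonal='kde' already requires SciPy

rng = np.random.default_rng(0)
df = pd.DataFrame(rng.standard_normal((1000, 4)), columns=["a", "b", "c", "d"])
index = 0
col = df[df.columns[index]].dropna()

# KDE diagonal: re-evaluate the density at the x positions pandas used
# (assumption: default gaussian_kde bandwidth, no extra options)
axs_kde = scatter_matrix(df, alpha=0.2, figsize=(6, 6), diagonal="kde")
xdat, ydat = axs_kde[index, index].get_lines()[0].get_data()
recomputed = gaussian_kde(col.to_numpy())(xdat)
print(np.allclose(ydat, recomputed))

# Histogram diagonal (the default): bars are Rectangle patches, not lines
axs_hist = scatter_matrix(df, alpha=0.2, figsize=(6, 6), diagonal="hist")
bars = axs_hist[index, index].patches
heights = np.array([bar.get_height() for bar in bars])  # counts per bin
lefts = np.array([bar.get_x() for bar in bars])         # left bin edges
counts, edges = np.histogram(col, bins=10)  # 10 = matplotlib's default bin count
print(np.allclose(heights, counts), np.allclose(lefts, edges[:-1]))
```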
