durbachit

Reputation: 4876

How can I get axis of different length but same scale for subplots in matplotlib?

If I have, for example, these 3 datasets to plot:

a = np.arange(0,10,1)
b = np.arange(2,6,1)
c = np.arange(5,10,1)

then, obviously, I don't want them to share the same axis, because some of them would end up as a big empty graph with a few data points somewhere in the corner. Ideally I would have subplots of different sizes, but all with the same scale (i.e. a step of 1 has the same length on every subplot). I know I could do this manually by setting the figure size, but for a larger number of datasets, or for less convenient numbers, that would be a real pain. Searching through other questions and the documentation, I couldn't find anything that would, for example, fix the distance between ticks on an axis. Please note that I am not asking about the aspect ratio: the aspect ratio can differ, I just need the same scale on an axis. (See the image below, which hopefully illustrates my problem. Note: no, I don't need the scale bar in my actual plot; it is only there so you can see that the scale is the same.) Thanks.

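[Updated image: the desired layout, with subplots of different sizes drawn at the same scale and a scale bar for reference]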

Upvotes: 2

Views: 3420

Answers (2)

durbachit

Reputation: 4876

After doing a lot of research, it looks like there really isn't any simple command to do this. However, if you first think about the x and y ranges of each subplot, their ratios, and the layout, GridSpec will do the job.

For our example, the layout is as presented in the question, i.e. the biggest plot on top and the two smaller ones next to each other underneath. To keep it simple, the y range is the same for all of them (if it weren't, we would use the same calculation as for x). Knowing this layout, we can create a grid in which one cell corresponds to one data unit. The vertical span is 20 (two rows of plots, each with a y-range of 10), and we want some space between the rows for axis labels, legends etc., so we add an extra 5. The first plot has an x-range of 10. The second and third plots have ranges of 4 and 5, which is 9 in total, and we also want some space between them, so we add an extra 3. So the horizontal grid spans 12. Hence we create a 25 x 12 grid (25 rows, 12 columns) and fit our plots into it as follows:

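import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec

# the three datasets from the question
a = np.arange(0, 10, 1)
b = np.arange(2, 6, 1)
c = np.arange(5, 10, 1)

## GRIDSPEC INTRO - creating a grid to distribute the subplots with same scales and different sizes.
fig = plt.figure()
# 25 rows x 12 columns; wspace=0 and hspace=0 remove the padding between cells,
# so a subplot spanning n cells is exactly n cell widths (or heights) long
gs = gridspec.GridSpec(25, 12, wspace=0, hspace=0)
## SUBPLOTS of GRIDSPEC
# the first (big) plot: 10 rows x 10 columns for x- and y-ranges of 10
ax1 = fig.add_subplot(gs[:10, :10])
ax1.plot(a, a, 'o')  # plot some data here
ax1.set_xlim(0, 10)
ax1.set_ylim(0, 10)
# the second plot, below the big one on the left: 4 columns for an x-range of 4
ax2 = fig.add_subplot(gs[15:, :4])
ax2.plot(b, b, 'o')  # plot some data here
ax2.set_xlim(2, 6)
ax2.set_ylim(0, 10)
# the third plot, below the big one on the right: 5 columns for an x-range of 5
ax3 = fig.add_subplot(gs[15:, 7:])
ax3.plot(c, c, 'o')  # plot some data here
ax3.set_xlim(5, 10)
ax3.set_ylim(0, 10)
plt.show()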
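The grid-size arithmetic can be wrapped in a helper so the grid does not have to be worked out by hand for other ranges. This is a minimal sketch, not part of the original answer: `same_scale_layout` is a hypothetical name, and it assumes whole-number axis ranges (one grid cell per data unit), a single plot on top, any number of plots side by side below, and one y-range shared by all plots. It passes `wspace=0` and `hspace=0` to `GridSpec` because otherwise the padding between cells is counted inside any subplot that spans several cells, and the scale would no longer be identical across subplots of different widths.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import gridspec


def same_scale_layout(top_xlim, bottom_xlims, ylim, row_gap=5, col_gap=3):
    """Hypothetical helper: one plot on top, several side by side below,
    laid out on a GridSpec where one grid cell equals one data unit.

    Assumes every axis range (xmax - xmin, ymax - ymin) is a whole number.
    """
    y_range = int(ylim[1] - ylim[0])
    top_range = int(top_xlim[1] - top_xlim[0])
    bottom_ranges = [int(hi - lo) for lo, hi in bottom_xlims]

    # rows: two rows of plots plus the gap between them
    n_rows = 2 * y_range + row_gap
    # columns: the wider of the top plot and the bottom row (ranges + gaps)
    bottom_width = sum(bottom_ranges) + col_gap * (len(bottom_ranges) - 1)
    n_cols = max(top_range, bottom_width)

    fig = plt.figure()
    # no padding between cells, so n cells span exactly n cell widths
    gs = gridspec.GridSpec(n_rows, n_cols, wspace=0, hspace=0)

    top = fig.add_subplot(gs[:y_range, :top_range])
    top.set_xlim(*top_xlim)
    top.set_ylim(*ylim)

    bottom = []
    col = 0
    for (lo, hi), width in zip(bottom_xlims, bottom_ranges):
        ax = fig.add_subplot(gs[y_range + row_gap:, col:col + width])
        ax.set_xlim(lo, hi)
        ax.set_ylim(*ylim)
        bottom.append(ax)
        col += width + col_gap  # skip the gap before the next plot
    return fig, gs, top, bottom


# the example from the question: x-ranges 10, 4 and 5, shared y-range 10
a = np.arange(0, 10, 1)
b = np.arange(2, 6, 1)
c = np.arange(5, 10, 1)
fig, gs, top, bottom = same_scale_layout((0, 10), [(2, 6), (5, 10)], (0, 10))
top.plot(a, a, 'o')
bottom[0].plot(b, b, 'o')
bottom[1].plot(c, c, 'o')
print(gs.get_geometry())  # (25, 12)
plt.show()
```

With the question's ranges and the default gaps this reproduces the 25 x 12 grid, with the bottom plots in columns 0 to 3 and 7 to 11.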

Upvotes: 2

Madelyne Velasco Mite

Reputation: 368

Well, after an hour:

__author__ = 'madevelasco'

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


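def group(x, max_val, min_val, iters):
    """Return the index (0 .. iters-1) of the equal-width x-range that x falls into."""
    total_range = max_val - min_val
    for i in range(0, iters):
        lower = min_val + i*total_range/iters
        upper = min_val + (i+1)*total_range/iters
        # ranges are (lower, upper]; the minimum itself belongs to the first group
        if lower < x <= upper or (i == 0 and x == min_val):
            return i
    return iters - 1  # guard against floating-point rounding at the maximum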



def newPlots():
    df = pd.DataFrame(np.random.randn(100, 2), columns=['x', 'y'])
    df.plot(x = 'x', y = 'y', kind = 'scatter', alpha=0.5)
    plt.show()

    ##Sort by the column you want
    df.sort_values(['x'], ascending=[False], inplace=True)
    result = df.reset_index(drop=True).copy()

    #Number of groups you want
    iterations = 3
    max_range = df['x'].max()
    min_range = df['x'].min()
    total_range = max_range - min_range

    result['group'] = result.apply(lambda x: group(x['x'], max_range, min_range, iterations ), axis=1)
    print(result)

    for x in range (0, iterations):

        lower = min_range + (x)*total_range/iterations
        upper = min_range + (1+x)*total_range/iterations
        new = result[result['group'] == x]
        new.plot(x = 'x', y = 'y', kind = 'scatter', alpha=0.3)
        axes = plt.gca()
        axes.set_xlim([lower, upper])
        axes.set_ylim([df['y'].min(),df['y'].max()])
        plt.show()

if __name__ == '__main__':
    newPlots()
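The manual binning in `group()` can also be sketched with `pandas.cut`, which splits a column into equal-width bins when given an integer `bins` (with `labels=False` it returns the bin number instead of an interval). This is an alternative, not the code above: it draws the groups side by side in one figure with `plt.subplots(..., sharey=True)`, so the y-axis is fixed across subplots, and because the subplots are equally wide and each spans an equally wide x-interval, one x unit has the same length in all of them. `pandas.cut` extends the range by 0.1% so that the minimum and maximum both fall into a bin.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame(np.random.randn(100, 2), columns=['x', 'y'])
iterations = 3  # number of groups, as in newPlots()

# equal-width bins over the x range; labels=False gives the bin index 0 .. iterations-1
df['group'] = pd.cut(df['x'], bins=iterations, labels=False)
# the same equal-width edges, used as the x-limits of each subplot
edges = np.linspace(df['x'].min(), df['x'].max(), iterations + 1)

# one figure with equally wide subplots and a shared y-axis
fig, axes = plt.subplots(1, iterations, sharey=True, figsize=(9, 3))
for i, ax in enumerate(axes):
    part = df[df['group'] == i]
    ax.scatter(part['x'], part['y'], alpha=0.3)
    ax.set_xlim(edges[i], edges[i + 1])
    ax.set_title('group %d' % i)
plt.show()
```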

I used pandas to do this. Honestly, the point of a visualization is usually to have all the data in one graph rather than split up like this. To preserve the sense of your data, and for readability, one of the axes should be fixed; I did this in an edit (every subplot uses the y-limits of the full dataset).

[Images: scatter plot of all points ("General"), followed by the three subplots ("Sub plots": first, second and third subplot)]

Upvotes: 0
