VGB

Reputation: 497

How to plot a figure with a varying number of subplots depending on the input?

Hi,

I am giving a list of data as input: list1 = [T(x,y,z), U(x,y,z), V(x,y,z), ..., n(x,y,z)]

Each of T, U and V has three dimensions. The final figure of subplots should be laid out like this:

T(x,y)      T(x,z)      T(y,z)

U(x,y)      U(x,z)      U(y,z)

V(x,y)      V(x,z)      V(y,z)

How can I write a function that produces such a figure as its output?

I am trying something like this:

nrow=len(list)
ncol=3

fig = plt.figure(figsize=(9, 6))
For i in list:
    for k in (ncol*nrow):
        sub2 = fig.add_subplot(nrow,ncol,1)
        plt.scatter(list[i](x),list[i](y))
        sub3 = fig.add_subplot(nrow,ncol,2)
        plt.scatter(list[i](x),list[i](z))
        sub3 = fig.add_subplot(nrow,ncol,3)
        plt.scatter(list[i](y),list[i](z))
        

I don't know how to find a general solution for this. It seems so silly, though.

Upvotes: 0

Views: 184

Answers (1)

JohanC

Reputation: 80279

Here is a possible interpretation of what you seem to be trying to do:

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

list1 = [pd.DataFrame(np.random.rand(20, 3), columns=['x', 'y', 'z']) for _ in range(np.random.randint(2, 6))]

nrow = len(list1)
ncol = 3

fig = plt.figure(figsize=(9, 6))
for row, list_i in enumerate(list1):
    ax1 = fig.add_subplot(nrow, ncol, row * 3 + 1)
    ax1.scatter(list_i['x'], list_i['y'])
    ax1.set_xlabel('x')
    ax1.set_ylabel('y')
    ax2 = fig.add_subplot(nrow, ncol, row * 3 + 2)
    ax2.scatter(list_i['x'], list_i['z'])
    ax2.set_xlabel('x')
    ax2.set_ylabel('z')
    ax3 = fig.add_subplot(nrow, ncol, row * 3 + 3)
    ax3.scatter(list_i['y'], list_i['z'])
    ax3.set_xlabel('y')
    ax3.set_ylabel('z')
plt.tight_layout()
plt.show()
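The position passed to add_subplot is 1-based and counts panels row by row, which is why the code above uses row * 3 + 1 through row * 3 + 3. A small sketch of that arithmetic, where subplot_index is a hypothetical helper name used only for illustration and is not part of matplotlib:

```python
def subplot_index(row, col, ncol):
    """Return the 1-based add_subplot position for a 0-based (row, col) cell."""
    return row * ncol + col + 1

# Positions for a grid with 2 rows and 3 columns, listed row by row.
positions = [[subplot_index(r, c, 3) for c in range(3)] for r in range(2)]
print(positions)  # [[1, 2, 3], [4, 5, 6]]
```

Writing the position as row * ncol + col + 1 instead of hard-coding 3 keeps the indexing correct if the number of columns ever changes.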

[Figure: grid of subplots]

Using the object-oriented API and more "pythonic" loops, the same plot could be written as follows (squeeze=False keeps axs two-dimensional even when there is only one row):

fig, axs = plt.subplots(ncols=ncol, nrows=nrow, figsize=(9, 6), squeeze=False)
for row_axs, list_i in zip(axs, list1):
    for ax, (first, second) in zip(row_axs, [('x', 'y'), ('x', 'z'), ('y', 'z')]):
        ax.scatter(list_i[first], list_i[second])
        ax.set_xlabel(first)
        ax.set_ylabel(second)
plt.tight_layout()
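Since the question asks for a function, the loop above could be wrapped roughly like this. This is only a sketch under assumptions: plot_pairs and its parameters are hypothetical names, and each list element is assumed to be a DataFrame with columns 'x', 'y' and 'z', as in the example data above.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def plot_pairs(frames, pairs=(('x', 'y'), ('x', 'z'), ('y', 'z')), figsize=(9, 6)):
    """Draw one row of scatter plots per DataFrame and return (fig, axs).

    Assumes `frames` is a non-empty list of DataFrames that contain
    every column named in `pairs`.
    """
    # squeeze=False keeps axs 2D even for a single row or column.
    fig, axs = plt.subplots(nrows=len(frames), ncols=len(pairs),
                            figsize=figsize, squeeze=False)
    for row_axs, frame in zip(axs, frames):
        for ax, (first, second) in zip(row_axs, pairs):
            ax.scatter(frame[first], frame[second])
            ax.set_xlabel(first)
            ax.set_ylabel(second)
    fig.tight_layout()
    return fig, axs

# Example input: four random data sets, made up purely for illustration.
rng = np.random.default_rng(0)
frames = [pd.DataFrame(rng.random((20, 3)), columns=['x', 'y', 'z'])
          for _ in range(4)]
fig, axs = plot_pairs(frames)
```

The number of rows follows the length of the input list, so the same call works for any number of data sets. Call plt.show() or fig.savefig(...) afterwards to display or save the figure.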

Upvotes: 1
