mahmood

Reputation: 24675

Plotting multiple dataframes in one chart

In the following code, each iteration reads a dataframe from a dictionary and plots it. My intention is to see all plots in one chart, but I see multiple charts in separate windows.

import matplotlib.pyplot as plt


def plot(my_dict):
    for key in my_dict:
        df = my_dict[key]
        df.plot.line(x='c', y='i')
    plt.show()

I have seen some tutorials about this, e.g. here, but they seem to work only when calling df.plot() with different columns of the same dataframe. In my code, I am plotting different dataframes. Any idea how to fix the code?

P.S.: I am running the code from a Linux terminal.
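A minimal sketch of why the loop opens separate windows, assuming pandas' default matplotlib backend: when DataFrame.plot is called without an ax argument, it creates a new figure on every call. The sample dictionary is hypothetical, and the non-interactive Agg backend is used only so the figures can be counted without opening any windows.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend: figures are created but no windows open
import matplotlib.pyplot as plt
import pandas as pd

# Hypothetical sample data with the same 'c' and 'i' columns as the question
my_dict = {
    "a": pd.DataFrame({"c": [1, 2, 3], "i": [4, 5, 6]}),
    "b": pd.DataFrame({"c": [1, 2, 3], "i": [6, 5, 4]}),
}

for key, df in my_dict.items():
    # No ax argument, so pandas creates a brand-new figure for this call
    df.plot.line(x="c", y="i")

# One figure per dataframe, which plt.show() would display as separate windows
print(len(plt.get_fignums()))
```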


Upvotes: 1

Views: 318

Answers (3)

Parfait

Reputation: 107567

Consider concatenating all the data so that you plot a single data frame once. Specifically, merge horizontally with pandas.concat on c (the shared x-axis variable) by setting it as the index, rename i (the y-axis and legend series) to each dict key, and then call DataFrame.plot only once:

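import matplotlib.pyplot as plt
import pandas as pd


def plot(my_dict):
    # One column per dict key, aligned on the shared x values in 'c'
    graph_df = pd.concat(
        [
            df[['c', 'i']].rename({'i': k}, axis=1).set_index('c')
            for k, df in my_dict.items()
        ],
        axis=1
    )

    # A single plot call draws every column as its own line on one chart
    graph_df.plot(kind="line")

    plt.show()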
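To see how the concatenation aligns the data, here is a small sketch with two hypothetical dataframes whose c values only partly overlap. pd.concat(..., axis=1) takes the union of the c indexes, so each column gets NaN wherever its dataframe has no row for that c value.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch
import matplotlib.pyplot as plt
import pandas as pd

# Two hypothetical dataframes whose 'c' values only partly overlap
my_dict = {
    "a": pd.DataFrame({"c": [1, 2, 3], "i": [10, 20, 30]}),
    "b": pd.DataFrame({"c": [2, 3, 4], "i": [5, 15, 25]}),
}

graph_df = pd.concat(
    [
        df[["c", "i"]].rename({"i": k}, axis=1).set_index("c")
        for k, df in my_dict.items()
    ],
    axis=1,
)
# Index is the union of both 'c' sets; 'a' is NaN at c=4 and 'b' is NaN at c=1
print(graph_df)

# One line per column, all on the same Axes
ax = graph_df.plot(kind="line")
```

matplotlib does not draw line segments through NaN values, so when the c values overlap sparsely a series can appear broken. The approach also assumes c is unique within each dataframe; duplicate index values can make the aligned concat fail.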

Upvotes: 1

ttveen

Reputation: 41

According to these docs, you can pass a matplotlib Axes object to df.plot.line() (which passes it on to df.plot()) through the ax parameter. I think something like this might work:

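import matplotlib.pyplot as plt


def plot(my_dict, axes_obj):
    # Draw every dataframe onto the same Axes instead of a new figure each time
    for key, df in my_dict.items():
        # label=key makes the legend show the dictionary keys
        df.plot.line(x='c', y='i', ax=axes_obj, label=key)
    plt.show()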

There are several ways to obtain an axes object, for example:

fig = plt.figure()
axes = fig.add_subplot(1, 1, 1)

or to get the current axes:

plt.gca()
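Putting this together, here is a sketch that creates the axes with plt.subplots() (another common way to get an Axes object) and passes it in. The sample data is hypothetical, and the label=key argument is an addition not in the answer's original code, so that the legend shows the dictionary keys rather than repeating i. plt.show() is left out so the sketch runs under the non-interactive Agg backend.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def plot(my_dict, axes_obj):
    """Draw every dataframe in my_dict onto the same Axes."""
    for key, df in my_dict.items():
        # label=key names each line after its dictionary key
        df.plot.line(x="c", y="i", ax=axes_obj, label=key)


# Hypothetical sample data
rng = np.random.default_rng(0)
d = {k: pd.DataFrame({"c": np.arange(10), "i": rng.integers(0, 100, 10)}) for k in "abc"}

fig, axes = plt.subplots()  # one figure with a single Axes
plot(d, axes)
# In an interactive session, plt.show() would now display one combined chart
```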

Upvotes: 0

Ingwersen_erik

Reputation: 2263

Pandas uses matplotlib as its default plotting backend, so one possible solution is to call matplotlib.pyplot.plot directly on each dataframe's columns. Here's a possible implementation:

from __future__ import annotations
import matplotlib.pyplot as plt


def plot(
    my_dict: dict,
    xcol: str,
    ycol: str,
    grid: bool = True,
    labels: bool | tuple | list = True,
    figsize: tuple | None = None,
    legend: bool | list | tuple | None = True,
    title: str | bool | None = True,
):
    """Plot a dictionary of dataframes.

    Parameters
    ----------
    my_dict : dict
        Dictionary of dataframes.
    xcol : str
        Name of the column to use as x-axis.
    ycol : str
        Name of the column to use as y-axis.
    grid : bool, default=True
        If True, show grid.
    labels : bool | tuple | list, default=True
        If True, use `xcol` and `ycol` as the x- and y-axis labels.
        If tuple or list, use `labels[0]` as x-label and `labels[1]` as y-label.
    figsize : tuple | None, optional
        Size of the figure. First value from tuple represents the
        width, and second value the height of the plot. Defaults to (10, 10)
    legend : bool | list | tuple | None, optional
        If True, use keys of `my_dict` as legend.
        If list or tuple, use list or tuple as legend.
    title : str | bool | None, optional
        If True, build the title from `xcol` and `ycol`.
        If you specify a string, use the specified value instead.
        Set it to None, or False if you don't want to display the
        plot's title.
    """
    if figsize is None:
        figsize = (10, 10)

    plt.figure(figsize=figsize)
    # Draw one line per dataframe on the same axes; dict order matches the legend order
    for _df in my_dict.values():
        plt.plot(_df[xcol], _df[ycol])

    if legend is True:
        plt.legend(my_dict.keys())
    elif isinstance(legend, (list, tuple)):
        plt.legend(legend)

    if grid:
        plt.grid(True)

    if labels is True:
        plt.xlabel(xcol, fontsize=16)
        plt.ylabel(ycol, fontsize=16)
    elif isinstance(labels, (list, tuple)):
        plt.xlabel(labels[0], fontsize=16)
        plt.ylabel(labels[1], fontsize=16)

    if title is True:
        plt.title(f'${xcol} \\times {ycol}$', fontsize=20)
    elif isinstance(title, str):
        plt.title(title, fontsize=20)
    
    # Axis limits spanning all dataframes, using the requested columns
    min_x = min(_df[xcol].min() for _df in my_dict.values())
    min_y = min(_df[ycol].min() for _df in my_dict.values())
    max_x = max(_df[xcol].max() for _df in my_dict.values())
    max_y = max(_df[ycol].max() for _df in my_dict.values())

    plt.xlim(min_x, max_x * 1.01)
    plt.ylim(min_y, max_y * 1.01)
    plt.show()

Example


import numpy as np
import pandas as pd


d = {
    char: pd.DataFrame(
        {"c": np.random.randint(0, 100, 20), "i": np.random.randint(0, 100, 20)}
    )
    for char in "abcdef"
}

plot(d, 'c', 'i')

Output:

[Image: the resulting chart, not preserved]
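In the example above, c is random and unsorted within each dataframe, so a line plot connects the points in row order and zig-zags back and forth. If a left-to-right line per key is wanted, one option (not part of the original answer) is to sort each dataframe by the x column before plotting. A sketch using the object-oriented Axes API, with a seeded random generator in place of np.random.randint:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Random data like the answer's example: 'c' is unsorted within each frame
rng = np.random.default_rng(42)
d = {
    char: pd.DataFrame({"c": rng.integers(0, 100, 20), "i": rng.integers(0, 100, 20)})
    for char in "abcdef"
}

fig, ax = plt.subplots(figsize=(10, 10))
for key, df in d.items():
    # Sort by the x column so each line runs left to right instead of zig-zagging
    df_sorted = df.sort_values("c")
    ax.plot(df_sorted["c"].to_numpy(), df_sorted["i"].to_numpy(), label=key)

ax.set_xlabel("c")
ax.set_ylabel("i")
ax.legend()
ax.grid(True)
```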

Simplified Version

If you want a stripped-down version of the plot function, you could write it like this:

def plot2(my_dict: dict):
    """Plot a dictionary of dataframes.

    Parameters
    ----------
    my_dict : dict
        Dictionary of dataframes.
    """
    # Draw one line per dataframe on the same axes
    for _df in my_dict.values():
        plt.plot(_df['c'], _df['i'])
    plt.show()

Example


import numpy as np
import pandas as pd


d = {
    char: pd.DataFrame(
        {"c": np.random.randint(0, 100, 20), "i": np.random.randint(0, 100, 20)}
    )
    for char in "abcdef"
}

plot2(d)

Output:

[Image: the resulting chart, not preserved]
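A note on the single-call style: plt.plot accepts several lines as a flat sequence of positional arguments, plt.plot(x1, y1, x2, y2, ...). A nested list such as [[x1, y1], [x2, y2], ...] unpacks into 2-D arrays instead, and matplotlib then draws one line per column rather than one line per dataframe. A sketch that flattens the pairs with itertools.chain, using hypothetical sample data:

```python
from itertools import chain

import matplotlib
matplotlib.use("Agg")  # non-interactive backend for this sketch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Hypothetical sample data: six dataframes with 'c' and 'i' columns
d = {k: pd.DataFrame({"c": np.arange(20), "i": np.arange(20) * n}) for n, k in enumerate("abcdef")}

fig, ax = plt.subplots()
# Flatten to x1, y1, x2, y2, ... so each dataframe becomes exactly one line
flat_args = list(chain.from_iterable((df["c"], df["i"]) for df in d.values()))
lines = ax.plot(*flat_args)
# ax.plot returns the lines in argument order, which matches the dict order
ax.legend(lines, list(d.keys()))
```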

Upvotes: 0
