Steve

Reputation: 634

Matplotlib: (Automating) Plotting unknown number of data sources/line/plots on 1 subplot

(Sorry for my title not being great!)

Ok, so I'm trying to automate one of my plotting codes. I think this is the better board to use (i.e. rather than Code Review) as it's a question with a specific goal rather than generally improving it. Apologies if I'm mistaken in my assessment of this.

For this I need to be able to plot an unknown number of different data sources (each of which will form a line on the plot), all within the same subplot. Most things I have found are geared towards creating a new subplot for each data source, which is not what I'm after.

Example code for, say, 3 data sources:

import matplotlib.pyplot as plt

data_y1 = [1, 2, 3, 4, 5]
data_x1 = [1, 1.5, 2, 2.5, 3]

data_y2 = [1, 2, 3, 4, 5, 6, 2]
data_x2 = [1, 2, 3, 4, 6, 9, 10]

data_y3 = [1, 3, 5, 7]
data_x3 = [1, 4, 9, 16]

fig1 = plt.figure()
ax1 = fig1.add_subplot(111)

a1, = ax1.plot(data_x1, data_y1, label="Data 1", color='g')
a2, = ax1.plot(data_x2, data_y2, label="Data 2", color='r')
a3, = ax1.plot(data_x3, data_y3, label="Data 3", color='c')

ax1.set_xlabel("Number of Hellos", fontsize=15)
ax1.set_ylabel("Number of Worlds", fontsize=18)

fig1.legend( (a1, a2, a3), ("Data 1", "Data 2", "Data 3"), loc='lower center', fancybox=True, ncol=3, fontsize=20)

# Maximize the window (showMaximized() is only available on Qt backends)
mng = plt.get_current_fig_manager()
mng.window.showMaximized()

plt.show()

So that works. However, my problem is that I have no idea how many data sources I'll have. It could be 10 (e.g. data_y10, etc.) or it could be just one, so I can't hard-code something like (a1, a2, a3). I'm struggling to automate this. I've been trying to use dictionaries, but they seem to mess up the legend.

Any advice would be very appreciated.

Upvotes: 0

Views: 3274

Answers (1)

Adobe

Reputation: 13477

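Assuming you have all your data in a list of lists, you can loop over it and give each line its own label; ax.legend() then picks the labels up automatically: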

#!/usr/bin/env python3

import matplotlib.pyplot as plt
from itertools import cycle

def main():

    colors = cycle(["aqua", "black", "blue", "fuchsia", "gray", "green", "lime", "maroon", "navy", "olive", "purple", "red", "silver", "teal", "yellow"])

    data = [
        [
            [1,   2, 3,   4, 5],
            [1, 1.5, 2, 2.5, 3]
        ], 
        [
            [1, 2, 3, 4, 5, 6,  2],
            [1, 2, 3, 4, 6, 9, 10]
        ], 
        [
            [1, 3, 5,  7],
            [1, 4, 9, 16]
        ]
    ]

    fig = plt.figure()
    ax = fig.add_subplot(111)

    # Each item holds [y values, x values], matching data_yN / data_xN in the question
    for i, (y, x) in enumerate(data, start=1):
        ax.plot(x, y, label="Data " + str(i), color=next(colors))

    ax.set_xlabel("Number of Hellos", fontsize=15)
    ax.set_ylabel("Number of Worlds", fontsize=18)

    ax.legend(loc="best")

    ax.margins(0.1)
    fig.tight_layout()

    plt.savefig("mwe.png")

if __name__ == "__main__":
    main()
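
Since the question mentions dictionaries and a figure-level legend, here is a minimal sketch of the same loop over a dict of named sources. The helper name plot_sources and the (x, y) tuple layout are assumptions made for illustration, not something from the question. The idea is to collect the line handles in a list while plotting, so fig.legend() gets them in a known order however many sources there are. The Agg backend line is only there so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; drop this line for interactive use
import matplotlib.pyplot as plt


def plot_sources(sources):
    """Plot every named (x, y) pair on one Axes; return (fig, ax, handles).

    `sources` is a hypothetical dict mapping a label to an (x, y) tuple.
    """
    fig, ax = plt.subplots()
    handles = []
    for name, (x, y) in sources.items():
        # ax.plot returns a list of lines; unpack the single line it created
        line, = ax.plot(x, y, label=name)
        handles.append(line)

    ax.set_xlabel("Number of Hellos", fontsize=15)
    ax.set_ylabel("Number of Worlds", fontsize=18)

    # Build the figure-level legend from the collected handles and their labels
    labels = [h.get_label() for h in handles]
    fig.legend(handles, labels, loc="lower center", fancybox=True,
               ncol=max(1, len(handles)))
    return fig, ax, handles


sources = {
    "Data 1": ([1, 1.5, 2, 2.5, 3], [1, 2, 3, 4, 5]),
    "Data 2": ([1, 2, 3, 4, 6, 9, 10], [1, 2, 3, 4, 5, 6, 2]),
    "Data 3": ([1, 4, 9, 16], [1, 3, 5, 7]),
}
fig, ax, handles = plot_sources(sources)
```

From here, fig.savefig("sources.png") or plt.show() (with an interactive backend) works as before. Colors are left to matplotlib's default color cycle; pass color=next(colors) as above if you want to control them.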


Upvotes: 3
