(Sorry for my title not being great!)
OK, so I'm trying to automate one of my plotting scripts. I think this is the better site to use (i.e. rather than Code Review), as my question has a specific goal rather than being about general improvement. Apologies if I've misjudged this.
For this I need to be able to plot an unknown number of data sources, each of which forms a line on the plot, all within the same subplot. Most things I have found are geared towards giving each data source its own subplot, which is not what I'm after.
Example code for, say, three data sources:
import matplotlib.pyplot as plt
data_y1 = [1, 2, 3, 4, 5]
data_x1 = [1, 1.5, 2, 2.5, 3]
data_y2 = [1, 2, 3, 4, 5, 6, 2]
data_x2 = [1, 2, 3, 4, 6, 9, 10]
data_y3 = [1, 3, 5, 7]
data_x3 = [1, 4, 9, 16]
fig1 = plt.figure()
ax1 = fig1.add_subplot(111)
a1, = ax1.plot(data_x1, data_y1, label="Data 1", color='g')
a2, = ax1.plot(data_x2, data_y2, label="Data 2", color='r')
a3, = ax1.plot(data_x3, data_y3, label="Data 3", color='c')
ax1.set_xlabel("Number of Hellos", fontsize=15)
ax1.set_ylabel("Number of Worlds", fontsize=18)
fig1.legend( (a1, a2, a3), ("Data 1", "Data 2", "Data 3"), loc='lower center', fancybox=True, ncol=3, fontsize=20)
mng = plt.get_current_fig_manager()
# Maximise the window; showMaximized() is a Qt method, so this line assumes a Qt-based backend
mng.window.showMaximized()
plt.show()
So that works. However, my problem is that I have no idea how many data sources I'll have. It could be 10 (e.g. data_y10, etc.) or it could be just one, so I can't hard-code something like (a1, a2, a3), and I'm struggling to automate this. I've tried using dictionaries, but they seem to mess up the legend.
Any advice would be very appreciated.
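Since the question mentions trying dictionaries: a minimal sketch, assuming the data can be gathered into a dict that maps each legend label to an (x, y) pair (the `sources` name and layout are hypothetical). One possible reason dictionaries scramble a legend is iteration order: only from Python 3.7 onward is a dict guaranteed to iterate in insertion order. Collecting the line handles in a list while looping, and taking the labels from those same handles, keeps lines and labels in step regardless.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line to get a window from plt.show()
import matplotlib.pyplot as plt

# Hypothetical container: legend label -> (x values, y values).
# Any number of entries works; nothing below is hard-coded to three.
sources = {
    "Data 1": ([1, 1.5, 2, 2.5, 3], [1, 2, 3, 4, 5]),
    "Data 2": ([1, 2, 3, 4, 6, 9, 10], [1, 2, 3, 4, 5, 6, 2]),
    "Data 3": ([1, 4, 9, 16], [1, 3, 5, 7]),
}

fig1, ax1 = plt.subplots()
handles = []
for label, (xs, ys) in sources.items():
    # ax.plot returns a list of Line2D objects; unpack the single line
    line, = ax1.plot(xs, ys, label=label)
    handles.append(line)

ax1.set_xlabel("Number of Hellos", fontsize=15)
ax1.set_ylabel("Number of Worlds", fontsize=18)
# Figure-level legend as in the question, with one column per data source
fig1.legend(handles, [h.get_label() for h in handles],
            loc="lower center", fancybox=True, ncol=len(handles), fontsize=20)
fig1.savefig("all_sources.png")
```

Leaving out `color=` lets matplotlib take colours from its default property cycle, which has ten entries before it repeats.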
Assuming you have all your data in a list of lists, one [x, y] pair per data source:
#!/usr/bin/env python3
import matplotlib.pyplot as plt
from itertools import cycle


def main():
    # Endless colour iterator: after the last colour it starts over
    colors = cycle(["aqua", "black", "blue", "fuchsia", "gray", "green", "lime", "maroon", "navy", "olive", "purple", "red", "silver", "teal", "yellow"])
    # One [x_values, y_values] entry per data source; add or remove entries freely
    data = [
        [
            [1, 1.5, 2, 2.5, 3],
            [1, 2, 3, 4, 5]
        ],
        [
            [1, 2, 3, 4, 6, 9, 10],
            [1, 2, 3, 4, 5, 6, 2]
        ],
        [
            [1, 4, 9, 16],
            [1, 3, 5, 7]
        ]
    ]
    fig = plt.figure()
    ax = fig.add_subplot(111)
    # One line per entry; labels run "Data 1", "Data 2", ... as in the question
    for i, item in enumerate(data, start=1):
        ax.plot(item[0], item[1], label="Data " + str(i), color=next(colors))
    ax.set_xlabel("Number of Hellos", fontsize=15)
    ax.set_ylabel("Number of Worlds", fontsize=18)
    # The legend picks up each line's label automatically
    ax.legend(loc="best")
    ax.margins(0.1)
    fig.tight_layout()
    plt.savefig("mwe.png")


if __name__ == "__main__":
    main()
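Building on the answer above: a sketch, assuming the sources arrive as any sequence of (x, y) pairs, that wraps the loop in a function so the same code serves one source or ten, and restores the question's figure-level legend at `loc="lower center"`. The helper names `plot_sources` and `make_figure` are hypothetical, and the ten generated lines are placeholder data.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen rendering so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_sources(ax, sources, prefix="Data"):
    """Plot every (x, y) pair in `sources` on one Axes and return the line handles.

    The number of lines is decided by len(sources) at run time, not hard-coded.
    """
    handles = []
    for i, (xs, ys) in enumerate(sources, start=1):
        line, = ax.plot(xs, ys, label=f"{prefix} {i}")
        handles.append(line)
    return handles


def make_figure(sources):
    fig, ax = plt.subplots()
    handles = plot_sources(ax, sources)
    ax.set_xlabel("Number of Hellos", fontsize=15)
    ax.set_ylabel("Number of Worlds", fontsize=18)
    # Labels come from the handles; at most five columns, so many sources wrap onto extra rows
    fig.legend(handles=handles, loc="lower center", ncol=min(len(handles), 5))
    return fig, handles


# A single source...
fig_one, one = make_figure([([1, 2, 3], [3, 1, 2])])
# ...or ten generated placeholder lines (y = k * x for k = 0..9)
many = [(list(range(5)), [k * v for v in range(5)]) for k in range(10)]
fig_many, ten = make_figure(many)
```

Capping `ncol` at five is an arbitrary choice; with more sources the legend wraps onto extra rows instead of growing into one very wide row.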