Reputation: 527
I want to add a grid to every subplot of a plt.subplots figure.
This is my code (data is a dict mapping strings to pandas DataFrames):
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
for row in ((ax1, ax2), (ax3, ax4)):
    for axes in row:
        data['DPTP'].plot(ax=axes, x="time", y="fitness",
                          label="DPTP(r = 30)", linestyle='dashed', grid=True)
for i in range(2):
    lb = "Preset {}".format(i)
    data[str(i)].plot(ax=ax2, x="time", y="fitness", label=lb, grid=True)
ax1.set_title("a")
ax1.grid(True)
for i in range(2, 4):
    lb = "Preset {}".format(i)
    data[str(i)].plot(ax=ax3, x="time", y="fitness", label=lb, grid=True)
ax2.set_title("b")
ax2.grid(True)
for i in range(4, 6):
    lb = "Preset {}".format(i)
    data[str(i)].plot(ax=ax4, x="time", y="fitness", label=lb, grid=True)
ax3.set_title("c")
ax3.grid(True)
for i in range(6, 10):
    lb = "Preset {}".format(i)
    data[str(i)].plot(ax=ax4, x="time", y="fitness", label=lb, grid=True)
ax4.set_title("d")
ax4.grid(True)
plt.show()
I pass grid=True in each for-loop, so I thought every subplot would get a grid, but my plot only has a grid on the first subplot.
And even ax.set_title() does not work for the subplots.
Why?
Upvotes: 0
Views: 3966
Reputation: 3542
You're somewhat new, so I'll be kind and provide a solution, but you should have provided mock data to work with; that's your job when asking a question here.
So here is the part you should have provided, the construction of data:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
# Mock data
time = np.linspace(0.3, 2, 600)
fitness = lambda i: 1 / (time + 0.05 * (i + 1)) - 0.1 * (i + 1)
dataframe = lambda i: pd.DataFrame({"time": time, "fitness": fitness(i)})
data = {"DPTP": dataframe(0)} | {str(i): dataframe(i + 1) for i in range(10)}
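The | operator used above to merge two dicts only exists on Python 3.9 and newer. If you are on an older Python 3 (3.5+), a minimal sketch of the same mock data uses ** unpacking instead; the curves are the same made-up numbers as above, only the merge syntax differs:

```python
import numpy as np
import pandas as pd

# Same made-up curves as above, for illustration only
time = np.linspace(0.3, 2, 600)


def fitness(i):
    return 1 / (time + 0.05 * (i + 1)) - 0.1 * (i + 1)


def dataframe(i):
    return pd.DataFrame({"time": time, "fitness": fitness(i)})


# {**other} inside a dict display merges dicts on Python 3.5+ (no | needed)
data = {"DPTP": dataframe(0), **{str(i): dataframe(i + 1) for i in range(10)}}
print(list(data))
```

Key order is preserved, so "DPTP" comes first, followed by "0" to "9".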
Now let's create the axes. I refactored a bit using dicts so that the plotting part is shorter and independent:
# Prepare axes and axes metadata
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
ax_titles = {ax1: "a", ax2: "b", ax3: "c", ax4: "d"}
key_to_ax = (
    dict.fromkeys(["0", "1"], ax1)
    | dict.fromkeys(["2", "3"], ax2)
    | dict.fromkeys(["4", "5"], ax3)
    | dict.fromkeys(["6", "7", "8", "9"], ax4)
)
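If the chained dict.fromkeys(...) | ... reads awkwardly (or you are on Python older than 3.9), the same two lookups can be derived from a single table. A minimal sketch, assuming the same 2x2 layout and preset grouping; groups is a hypothetical name, and the Agg backend is only there so it runs without a display:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a window
import matplotlib.pyplot as plt

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)

# Hypothetical table: (subplot, title from the question, preset keys drawn on it)
groups = [
    (ax1, "a", ["0", "1"]),
    (ax2, "b", ["2", "3"]),
    (ax3, "c", ["4", "5"]),
    (ax4, "d", ["6", "7", "8", "9"]),
]

# Both lookups come from the one table, so titles and keys cannot drift apart
ax_titles = {ax: title for ax, title, _ in groups}
key_to_ax = {key: ax for ax, _, keys in groups for key in keys}
```

The design choice here is just to keep "which subplot, which title, which presets" in one place; the plotting loop stays the same.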
And finally the plotting part. Here you have two choices: either pass grid=True as an argument to every df.plot call, or call ax.grid(True) once you're done plotting new lines on that ax. Prefer ax.grid(True) over a bare ax.grid(): without arguments, grid() toggles the current state, so it would switch an existing grid off. The code below uses the second choice:
# Plot
for ax in (ax1, ax2, ax3, ax4):
    data["DPTP"].plot(
        ax=ax, x="time", y="fitness", label="DPTP(r = 30)", linestyle="dashed"
    )
for i in range(10):
    data[str(i)].plot(ax=key_to_ax[str(i)], x="time", y="fitness", label=f"Preset {i}")
for ax in (ax1, ax2, ax3, ax4):
    ax.set_title(ax_titles[ax])
    ax.grid(True)
plt.show()
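To check titles and grids without eyeballing the figure, you can query each Axes directly. A minimal sketch, assuming made-up data and the headless Agg backend; has_grid is a hypothetical helper that looks at the visibility of the major x-axis grid lines. It shows both choices working, plus the pitfall of combining grid=True with a bare ax.grid(), which toggles the grid back off:

```python
import matplotlib

matplotlib.use("Agg")  # headless backend, no window needed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

# Made-up data, only so there is something to plot
df = pd.DataFrame({"time": np.linspace(0, 1, 50), "fitness": np.linspace(1, 0, 50)})


def has_grid(ax):
    """Hypothetical helper: True if major x grid lines exist and all are visible."""
    lines = ax.xaxis.get_gridlines()
    return len(lines) > 0 and all(line.get_visible() for line in lines)


fig, (ax_opt1, ax_opt2, ax_toggled) = plt.subplots(1, 3)

# Choice 1: ask pandas for the grid in the plot call
df.plot(ax=ax_opt1, x="time", y="fitness", grid=True)
ax_opt1.set_title("grid=True")

# Choice 2: plot first, then switch the grid on explicitly
df.plot(ax=ax_opt2, x="time", y="fitness")
ax_opt2.grid(True)
ax_opt2.set_title("ax.grid(True)")

# Pitfall: a bare ax.grid() toggles, so after grid=True it turns the grid off
df.plot(ax=ax_toggled, x="time", y="fitness", grid=True)
ax_toggled.grid()
ax_toggled.set_title("toggled off")

for ax in (ax_opt1, ax_opt2, ax_toggled):
    print(ax.get_title(), has_grid(ax))
```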
Upvotes: 3