Reputation: 165
I am doing a plot using python Seaborn:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.colors as c
from matplotlib.patches import Patch
from matplotlib.ticker import FixedLocator

# Example data: one column per "<sample> <state>" pair, each holding
# 5 random integers in the range 1..10 (the upper bound 11 is exclusive).
data = {
    'Sample1 stateA': np.random.randint(1, 11, 5),
    'Sample1 stateB': np.random.randint(1, 11, 5),
    'Sample1 stateC': np.random.randint(1, 11, 5),
    'Sample2 stateA': np.random.randint(1, 11, 5),
    'Sample2 stateB': np.random.randint(1, 11, 5),
    'Sample2 stateC': np.random.randint(1, 11, 5),
}

# Create the DataFrame
df = pd.DataFrame(data)

# Plot the DataFrame (not the raw dict) so that seaborn receives a 2D
# table and can pick up the column names for the x-axis.
plt.figure(figsize=(12, 8))
ax = sns.heatmap(
    df,
    cbar=False,
    linewidths=0.1,
    linecolor='black',
    annot=False,
    vmin=1,
    vmax=10,
)
I want to add different labels at the top and at the bottom of the x-axis, along the lines of the code below. Bottom labels:
# Column names look like "<sample> <condition>". Keep only the condition
# part for the regular (bottom) tick labels of the heatmap.
# The columns are read from the DataFrame `df`; the plain dict `data`
# has no `.columns` attribute.
xticks_labels = [label.split(' ', 1)[1] for label in df.columns]

# Set the xticks to show the condition
ax.set_xticklabels(xticks_labels, rotation=90)
Top labels: this one is more elaborate, since I want to put a single label at the center of each group of columns:
# Create a secondary x-axis (drawn on top) for the sample names.
ax2 = ax.twiny()
# A twinned axis keeps its own x-limits. Copy the heatmap's limits so that
# tick positions expressed in column coordinates fall inside the visible
# range; otherwise the ticks and their labels are not drawn.
ax2.set_xlim(ax.get_xlim())

# Sample name of every column, in column order ("<sample> <condition>").
sample_names = [item.split(" ")[0] for item in df.columns]
ax2.yaxis.tick_left()

# `ticks` collects the left edge of every group of consecutive columns that
# share a sample name; `labels` collects the matching sample names.
ticks = []
labels = []
prev_label = None
for i, label in enumerate(sample_names):
    if label != prev_label:
        ticks.append(i)
        labels.append(label)
        prev_label = label
# Close the last group with the right edge of the last column.
ticks.append(len(sample_names))

# Minor ticks sit on the group boundaries, major ticks (which carry the
# labels) on the centre of each group.
ax2.xaxis.set_minor_locator(FixedLocator(ticks))
ax2.xaxis.set_major_locator(
    FixedLocator([(t0 + t1) / 2 for t0, t1 in zip(ticks[:-1], ticks[1:])])
)
ax2.set_xticklabels(labels, rotation=0)
# Hide the major tick marks and draw long minor ticks as group separators.
ax2.tick_params(axis='both', which='major', length=0)
ax2.tick_params(axis='x', which='minor', length=60)

# Move the top spine outward so the sample labels sit above the heatmap
ax2.spines['top'].set_position(('outward', 40))

# Adjust the layout to make sure both sets of labels are visible
plt.tight_layout()

# Save the plot to a file
# NOTE(review): `heatmap_file` is not defined in this snippet; it is assumed
# to hold the output path.
plt.savefig(heatmap_file, dpi=300, bbox_inches='tight')
If I do the above, creating a secondary axis for the top labels, I only get the bottom labels (the secondary axis is not shown).
If I try to create the top labels on the same axis (using ax instead of ax2), I can see the top labels but not the bottom ones:
# Single-axis variant: the sample-name ticks are installed directly on `ax`,
# replacing whatever tick locations and labels `ax` had on its x-axis.
ax.yaxis.tick_left()

# Walk the columns and record where each run of identical sample names starts.
ticks, labels = [], []
prev_label = None
for col, name in enumerate(sample_names):
    if name == prev_label:
        continue
    ticks.append(col)
    labels.append(name)
    prev_label = name
# Right edge of the last column closes the final group.
ticks.append(col + 1)

# Centre of every group: midpoint between consecutive boundaries.
group_centres = [(left + right) / 2 for left, right in zip(ticks, ticks[1:])]

ax.xaxis.set_minor_locator(FixedLocator(ticks))
ax.xaxis.set_major_locator(FixedLocator(group_centres))
ax.set_xticklabels(labels, rotation=0)

# No major tick marks; long minor ticks on the group boundaries.
ax.tick_params(axis='both', which='major', length=0)
ax.tick_params(axis='x', which='minor', length=60)

# Shift the top spine of `ax` outward by 40 points.
ax.spines['top'].set_position(('outward', 40))

# Adjust the layout to make sure the labels are visible
plt.tight_layout()
How can I see both to generate something like this?
Thank you.
Upvotes: 0
Views: 46
Reputation: 165
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.colors as c
from matplotlib.patches import Patch
from matplotlib.ticker import FixedLocator

# Example data: one column per "<sample> <state>" pair, each holding
# 5 random integers in the range 1..10 (the upper bound 11 is exclusive).
data = {
    'Sample1 stateA': np.random.randint(1, 11, 5),
    'Sample1 stateB': np.random.randint(1, 11, 5),
    'Sample1 stateC': np.random.randint(1, 11, 5),
    'Sample2 stateA': np.random.randint(1, 11, 5),
    'Sample2 stateB': np.random.randint(1, 11, 5),
    'Sample2 stateC': np.random.randint(1, 11, 5),
}

# Create the DataFrame
df = pd.DataFrame(data)

# Plot the DataFrame (not the raw dict) so that seaborn receives a 2D
# table and can pick up the column names for the x-axis.
plt.figure(figsize=(12, 8))
ax = sns.heatmap(
    df,
    cbar=False,
    linewidths=0.1,
    linecolor='black',
    annot=False,
    vmin=1,
    vmax=10,
)
# Bottom labels: the regular x tick labels of the heatmap.
# Column names look like "<sample> <condition>"; keep only the condition.
# The columns are read from the DataFrame `df`; the plain dict `data`
# has no `.columns` attribute.
xticks_labels = [label.split(' ', 1)[1] for label in df.columns]

# Set the xticks to show the condition
ax.set_xticklabels(xticks_labels, rotation=90)
# Top labels: one sample name centred above each group of columns.
# Sample name of every column, in column order ("<sample> <condition>").
sample_names = [item.split(" ")[0] for item in df.columns]

# Calculate the midpoint of each group of conditions for top labels
mid_points = []
labels = []
start = 0
for i, name in enumerate(sample_names):
    # A new group starts whenever the sample name changes.
    if i == 0 or name != sample_names[i - 1]:
        start = i
    # The group ends at the last column or right before the name changes.
    if i == len(sample_names) - 1 or name != sample_names[i + 1]:
        # Column k spans [k, k + 1] on the heatmap's x-axis, so a group
        # covering columns start..i spans [start, i + 1].
        mid_points.append((start + i + 1) / 2)
        labels.append(name)

# Add top labels using ax.text(). With the x-axis transform, x is given in
# data (column) coordinates and y in axes coordinates, where 1.0 is the top
# edge of the heatmap; 1.02 places the text just above it.
for midpoint, label in zip(mid_points, labels):
    ax.text(
        midpoint,
        1.02,
        label,
        ha='center',
        va='bottom',
        transform=ax.get_xaxis_transform(),
        fontsize=12,
    )
Upvotes: 0
Reputation: 321
If I understand your problem correctly, I think you overcomplicated this for yourself by opting to plot a secondary axis. I would use an approach with ax.text()
(by the way, I don't know why the secondary or the primary axis was not displayed).
Try this, I hope it works out:
# Bottom labels: condition part of each "<sample> <condition>" column name
xticks_labels = [label.split(' ', 1)[1] for label in df.columns]
ax.set_xticklabels(xticks_labels, rotation=90)

# Top labels: sample part of each column name
sample_names = [label.split(' ', 1)[0] for label in df.columns]

# Calculate the midpoint of each group of conditions for top labels,
# remembering which sample name belongs to each midpoint.
mid_points = []
group_labels = []
start = 0
for i, name in enumerate(sample_names):
    # A new group starts whenever the sample name changes.
    if i == 0 or name != sample_names[i - 1]:
        start = i
    # The group ends at the last column or right before the name changes.
    if i == len(sample_names) - 1 or name != sample_names[i + 1]:
        # Column k spans [k, k + 1] on the heatmap's x-axis, so a group
        # covering columns start..i spans [start, i + 1].
        mid_points.append((start + i + 1) / 2)
        group_labels.append(name)

# Add top labels using ax.text(). The midpoint is a float and cannot be used
# to index `sample_names`, so the label comes from `group_labels` instead.
# With the x-axis transform, x is in data (column) coordinates and y in axes
# coordinates, where 1.0 is the top edge of the heatmap.
for midpoint, name in zip(mid_points, group_labels):
    ax.text(
        midpoint,
        1.02,
        name,
        ha='center',
        va='bottom',
        transform=ax.get_xaxis_transform(),
        fontsize=12,
    )
Upvotes: 0