user1064285
user1064285

Reputation: 165

Matplotlib/Seaborn - Different x labels on top and bottom

I am doing a plot using python Seaborn:

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.colors as c
from matplotlib.patches import Patch
from matplotlib.ticker import FixedLocator

import numpy as np  # np.random is used below but numpy is not among this snippet's imports

# Six columns named "<sample> <state>", each holding five random integers.
# np.random.randint excludes the upper bound, so the values lie in 1..10,
# matching the vmin/vmax passed to the heatmap below.
data = {
    'Sample1 stateA': np.random.randint(1, 11, 5),
    'Sample1 stateB': np.random.randint(1, 11, 5),
    'Sample1 stateC': np.random.randint(1, 11, 5),
    'Sample2 stateA': np.random.randint(1, 11, 5),
    'Sample2 stateB': np.random.randint(1, 11, 5),
    'Sample2 stateC': np.random.randint(1, 11, 5),
}

# Create the DataFrame
df = pd.DataFrame(data)

# Plot
plt.figure(figsize=(12, 8))
# Plot the DataFrame rather than the raw dict: the labelling code that
# follows reads the column names from a DataFrame.
ax = sns.heatmap(
    df,
    cbar=False,
    linewidths=0.1,
    linecolor='black',
    annot=False,
    vmin=1,
    vmax=10,
)

I want to add different labels on the top and bottom of the x-axis, with something similar to this. Bottom labels:

# Bottom labels: keep only the condition part of each
# "<sample> <condition>" column name (the text after the first space).
# The column names live on the DataFrame `df`; the dict `data` has no
# `.columns` attribute.
xticks_labels = [label.split(' ', 1)[1] for label in df.columns]

# Set the xticks to show the condition
ax.set_xticklabels(xticks_labels, rotation=90)

Top labels; this one is more elaborate, since I want to put one label in the center of each group of columns:

# Top labels: one sample name centred above each run of columns that
# belong to the same sample, drawn on a secondary x-axis.

# Create a secondary x-axis for the sample names
ax2 = ax.twiny()
# The twin axis has its own x-limits. Copy the heatmap's limits so that the
# tick positions computed below (in heatmap column units) fall inside the
# visible range; otherwise the group-centre ticks are off-screen and no
# label is drawn.
ax2.set_xlim(ax.get_xlim())

# Sample name = text before the first space of each column name
sample_names = [item.split(" ")[0] for item in df.columns]

ax2.yaxis.tick_left()

# `ticks` collects the left edge of every run of identical sample names,
# `labels` the name of each run.
ticks = []
labels = []
prev_label = None
for i, label in enumerate(sample_names):
    if label != prev_label:
        ticks.append(i)
        labels.append(label)
        prev_label = label
# Close the last run with its right edge
ticks.append(len(sample_names))

# Minor ticks mark the run boundaries; major ticks sit at the centre of
# each run and carry the sample name.
ax2.xaxis.set_minor_locator(FixedLocator(ticks))
ax2.xaxis.set_major_locator(FixedLocator([(t0 + t1) / 2 for t0, t1 in zip(ticks[:-1], ticks[1:])]))
ax2.set_xticklabels(labels, rotation=0)
# Hide the major tick marks (labels only) and draw long minor ticks as
# separators between the groups
ax2.tick_params(axis='both', which='major', length=0)
ax2.tick_params(axis='x', which='minor', length=60)

# Move the secondary axis outward so it does not get hidden
ax2.spines['top'].set_position(('outward', 40))

# Adjust the layout to make sure both sets of labels are visible
plt.tight_layout()

# Save the plot to a file
# NOTE(review): `heatmap_file` is not defined in this snippet - it must be
# set to the output path before this line runs.
plt.savefig(heatmap_file, dpi=300, bbox_inches='tight')

If I do the above, creating a secondary axis for the top labels, I only get the bottom labels (the secondary axis is not shown).

If I try to do the top labels using the same axis (using ax instead of ax2), I can see the top labels but not the bottom ones:

# Same-axis variant: the group ticks and sample labels are installed on
# `ax` itself, replacing the locators and tick labels `ax` already had.
ax.yaxis.tick_left()

# Walk the sample names and record where each run of equal names starts
ticks, labels = [], []
prev_label = None
for i, label in enumerate(sample_names):
    if label == prev_label:
        continue
    ticks.append(i)
    labels.append(label)
    prev_label = label
# Right edge of the final run
ticks.append(i + 1)

# Centre of every run = mean of its two neighbouring boundaries
group_centres = [(left + right) / 2 for left, right in zip(ticks, ticks[1:])]

# Boundaries become minor ticks, centres become the labelled major ticks
ax.xaxis.set_minor_locator(FixedLocator(ticks))
ax.xaxis.set_major_locator(FixedLocator(group_centres))
ax.set_xticklabels(labels, rotation=0)

# No visible major tick marks; long minor ticks act as group separators
ax.tick_params(axis='both', which='major', length=0)
ax.tick_params(axis='x', which='minor', length=60)

# Push the top spine of `ax` outward by 40 points
ax.spines['top'].set_position(('outward', 40))

# Adjust the layout to make sure both sets of labels are visible
plt.tight_layout()

How can I see both to generate something like this? Desired result

Thank you.

Upvotes: 0

Views: 46

Answers (2)

user1064285
user1064285

Reputation: 165

import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import matplotlib.colors as c
from matplotlib.patches import Patch
from matplotlib.ticker import FixedLocator

import numpy as np  # np.random is used below but numpy is not among this snippet's imports

# Six columns named "<sample> <state>", each holding five random integers.
# np.random.randint excludes the upper bound, so the values lie in 1..10,
# matching the vmin/vmax passed to the heatmap below.
data = {
    'Sample1 stateA': np.random.randint(1, 11, 5),
    'Sample1 stateB': np.random.randint(1, 11, 5),
    'Sample1 stateC': np.random.randint(1, 11, 5),
    'Sample2 stateA': np.random.randint(1, 11, 5),
    'Sample2 stateB': np.random.randint(1, 11, 5),
    'Sample2 stateC': np.random.randint(1, 11, 5),
}

# Create the DataFrame
df = pd.DataFrame(data)

# Plot
plt.figure(figsize=(12, 8))
# Plot the DataFrame rather than the raw dict: the labelling code that
# follows reads the column names from a DataFrame.
ax = sns.heatmap(
    df,
    cbar=False,
    linewidths=0.1,
    linecolor='black',
    annot=False,
    vmin=1,
    vmax=10,
)

# Condition labels, shown as the heatmap's own x tick labels: keep only the
# text after the first space of each "<sample> <condition>" column name.
# The column names live on the DataFrame `df`; the dict `data` has no
# `.columns` attribute.
xticks_labels = [label.split(' ', 1)[1] for label in df.columns]

# Set the xticks to show the condition
ax.set_xticklabels(xticks_labels, rotation=90)



# Sample labels: one name centred over each run of columns that share a
# sample. Sample name = text before the first space of the column name.
sample_names = [item.split(" ")[0] for item in df.columns]

# Calculate the midpoint of each group of conditions.
# Heatmap column k spans x in [k, k + 1], so a run covering columns
# start..i spans [start, i + 1] and its centre is (start + i + 1) / 2.
mid_points = []
labels = []
for i, name in enumerate(sample_names):
    if i == 0 or name != sample_names[i - 1]:
        start = i  # first column of a new run
    if i == len(sample_names) - 1 or name != sample_names[i + 1]:
        # last column of the current run: record its centre and name
        mid_points.append((start + i + 1) / 2)
        labels.append(name)

# Add the sample labels using ax.text(x, y, s).
# With the x-axis transform, x is in data (column) units and y is in axes
# coordinates, where 1.0 is the top edge of the heatmap; y = 1.02 with
# va='bottom' places the text just above the plot.
for midpoint, label in zip(mid_points, labels):
    ax.text(midpoint, 1.02, label, ha='center', va='bottom',
            transform=ax.get_xaxis_transform(), fontsize=12)

Upvotes: 0

If I understand your problem correctly, I think you overcomplicated this for yourself by opting to plot a secondary axis. I would use an approach with ax.text() (by the way, I don't know why the secondary or the primary axis was not displayed). Try this, I hope it works out:

 # Bottom labels
# Condition part (text after the first space) of each column name
xticks_labels = [label.split(' ', 1)[1] for label in df.columns]
ax.set_xticklabels(xticks_labels, rotation=90)

# Top labels
# Sample part (text before the first space) of each column name
sample_names = [label.split(' ', 1)[0] for label in df.columns]

# Calculate the midpoint of each group of conditions for top labels.
# Heatmap column k spans x in [k, k + 1], so a run covering columns
# start..i spans [start, i + 1] and its centre is (start + i + 1) / 2.
mid_points = []
group_names = []
for i, name in enumerate(sample_names):
    if i == 0 or name != sample_names[i - 1]:
        start = i  # first column of a new run
    if i == len(sample_names) - 1 or name != sample_names[i + 1]:
        # last column of the current run: record its centre and its name
        # (the midpoint is a float and cannot be used to index the list)
        mid_points.append((start + i + 1) / 2)
        group_names.append(name)

# Add top labels using ax.text().
# With the x-axis transform, x is in data (column) units and y is in axes
# coordinates, where 1.0 is the top edge of the heatmap; y = 1.02 with
# va='bottom' places the text just above the plot.
for midpoint, name in zip(mid_points, group_names):
    ax.text(midpoint, 1.02, name, ha='center', va='bottom',
            transform=ax.get_xaxis_transform(), fontsize=12)

Upvotes: 0

Related Questions