toothsie

Can I overlay a Seaborn plot onto a Matplotlib graph?

I have a function that produces a matplotlib map. I want to overlay a seaborn heatmap on top of this map so that both are exactly the same size and aligned with each other, while the details of both remain visible. Is this possible? My code is below.

import matplotlib.pyplot as plt
import seaborn

def draw_map(data):
    fig = plt.figure()
    fig.set_size_inches(14.5, 8.8)
    ax = fig.add_subplot(1, 1, 1)

    # Map Outline & Centre Line
    plt.plot([0, 0], [0, 88], color="black")
    plt.plot([0, 145], [88, 88], color="black")
    plt.plot([145, 145], [88, 0], color="black")
    plt.plot([145, 0], [0, 0], color="black")

    ly97 = [39, 49]
    lx97 = [72.5, 72.5]
    plt.plot(lx97, ly97, color="black")

    # Heatmap of the data (values in [0, 1]), drawn on the current axes
    seaborn.heatmap(data)
    plt.ylim(0, 88)
    plt.xlim(0, 145)

    # Display Map
    plt.show()
    

For some reason, the seaborn heatmap appears tiny compared to the matplotlib map. The heatmap data only contains values between 0 and 1, if that helps. Thanks in advance.


Answers (1)

JohanC

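When drawing an MxN array as a heatmap, seaborn places the cell edges at x = 0, 1, ..., N and y = 0, 1, ..., M, with the tick labels at the cell centres (0.5, 1.5, ...). On axes that run from 0 to 145 and from 0 to 88, the heatmap therefore only fills a small corner, which is why it looks tiny. There doesn't seem to be a way to pass your own coordinates to seaborn.heatmap(). However, seaborn draws the heatmap with matplotlib's pcolormesh(), and you can call that directly: pcolormesh() accepts the x and y cell edges as its first two arguments, so the grid can be stretched over any range.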
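To make the size mismatch concrete, here is a small sketch that uses only matplotlib and NumPy, with pcolormesh() standing in for sns.heatmap() (which, as noted above, is built on it). It draws the same random 20 x 30 grid twice, once without coordinates and once with explicit edges spanning the 145 x 88 map, and prints the data limits of each axes. The grid size and the use of the Agg backend are assumptions made only so the sketch runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

M, N = 20, 30  # rows, columns of the example data (assumed sizes)
data = np.random.rand(M, N)  # random values in [0, 1), like the question's data

fig, (ax_default, ax_map) = plt.subplots(1, 2)

# No coordinates: cell edges fall at 0..N in x and 0..M in y (the heatmap default)
ax_default.pcolormesh(data)

# Explicit edges: N+1 x-values and M+1 y-values stretch the same grid over the map
x_edges = np.linspace(0, 145, N + 1)
y_edges = np.linspace(0, 88, M + 1)
ax_map.pcolormesh(x_edges, y_edges, data)

for name, ax in [("default", ax_default), ("map", ax_map)]:
    # dataLim.bounds is (x0, y0, width, height) of everything drawn on the axes
    print(name, ax.dataLim.bounds)
```

The first axes only covers 30 x 20 data units, while the second covers the full 145 x 88 map with the same number of cells.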

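The example below uses matplotlib's standard "object-oriented" interface (calling methods on ax instead of on plt). The heatmap is drawn semi-transparent (alpha=0.4) and the outline uses thick lime-green lines, to give more contrast between the lines and the heatmap in seaborn's default "rocket" colormap.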

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

fig, ax = plt.subplots(figsize=(14.5, 8.8))

# Map Outline & Centre Line
ax.plot([0, 0], [0, 88], color="lime", lw=3)
ax.plot([0, 145], [88, 88], color="lime", lw=3)
ax.plot([145, 145], [88, 0], color="lime", lw=3)
ax.plot([145, 0], [0, 0], color="lime", lw=3)

ly97 = [39, 49]
lx97 = [72.5, 72.5]
ax.plot(lx97, ly97, color="lime", lw=3)

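M = 20  # number of rows in the heatmap data
N = 30  # number of columns
data = np.random.rand(M, N)  # random stand-in for the real data (values in [0, 1))
# sns.heatmap(data)
# Instead of sns.heatmap(), call pcolormesh() with explicit cell edges:
# N+1 x-edges spanning 0..145 and M+1 y-edges spanning 0..88.
# as_cmap=True requires seaborn 0.11 or newer.
ax.pcolormesh(np.linspace(0, 145, N+1), np.linspace(0, 88, M+1), data, alpha=0.4,
              cmap=sns.color_palette("rocket", as_cmap=True))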
# ax.set_ylim(0, 88)
# ax.set_xlim(0, 145)

plt.show()

[Figure: example plot]
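If you need this overlay in several places, the pcolormesh() call can be wrapped in a small helper. This is a minimal sketch, not part of the answer: overlay_heatmap is a hypothetical name, and "viridis" is used as the default colormap only so the helper does not depend on seaborn (you can pass sns.color_palette("rocket", as_cmap=True) instead, as in the answer).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def overlay_heatmap(ax, data, x_range, y_range, alpha=0.4, cmap="viridis"):
    """Hypothetical helper: stretch an M x N array over x_range by y_range on ax."""
    data = np.asarray(data)
    n_rows, n_cols = data.shape
    # "flat" shading needs one more edge than cells along each axis
    x_edges = np.linspace(x_range[0], x_range[1], n_cols + 1)
    y_edges = np.linspace(y_range[0], y_range[1], n_rows + 1)
    return ax.pcolormesh(x_edges, y_edges, data, alpha=alpha, cmap=cmap,
                         shading="flat")


fig, ax = plt.subplots(figsize=(14.5, 8.8))
# Map outline as one closed polyline, plus the centre line
ax.plot([0, 0, 145, 145, 0], [0, 88, 88, 0, 0], color="black")
ax.plot([72.5, 72.5], [39, 49], color="black")
mesh = overlay_heatmap(ax, np.random.rand(20, 30), (0, 145), (0, 88))
ax.set_xlim(0, 145)
ax.set_ylim(0, 88)
```

Call order does not matter for visibility here: lines default to zorder 2 and the mesh (a collection) to zorder 1, so the outline is drawn on top of the heatmap either way.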
