Andrey Ampilogov

Reputation: 382

Make a heatmap of x,y,z data in Python

I am trying to draw a heatmap in Python. I have studied several tutorials but still can't achieve what I need. My data has 3 columns: X, Y (coordinates in the scatterplot) and cluster (the group/cluster each row is assigned to). The desired output should look like this (6 clusters, with the X,Y points distributed over the coloured areas):

enter image description here

My current code:

# libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import griddata

# Get the data (csv file is hosted on the web)
url = 'https://raw.githubusercontent.com/ampil/sandbox/master/latest-sales-sample.csv'
df = pd.read_csv(url, sep = ';')
df = df.dropna()  # drop rows with any missing value; recent pandas rejects passing both how= and thresh=
# create data
x = df['X']
y = np.log(df['Y'])
z = df['cluster']

# target grid to interpolate to
xi = yi = np.arange(0, 1.01, 0.01)
xi, yi = np.meshgrid(xi,yi)

# interpolate
zi = griddata((x,y),z,(xi,yi),method='cubic')

# plot
fig = plt.figure()
ax = fig.add_subplot(111)
ax.axis((x.min(), x.max(), y.min(), y.max()))
plt.contourf(xi, yi, zi, np.arange(0, 1.01, 0.01), cmap='coolwarm')
plt.plot(x,y,'k.')
plt.xlabel('x',fontsize=16)
plt.ylabel('y',fontsize=16)
plt.show()
plt.close(fig)

gives me

my graph

Later on, I plan to publish the graph via dash.

Any help is appreciated!

Upvotes: 2

Views: 2433

Answers (3)

user1269942

Reputation: 3852

My answer is a small edit to the answer provided by warped.

The only difference is the 'extend' parameter in the contourf call, which controls how values below the lowest level and above the highest level are coloured.

https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.contourf.html

For some more information about the minimum/maximum color-map behaviour, see:

https://matplotlib.org/3.1.1/gallery/images_contours_and_fields/contourf_demo.html#sphx-glr-gallery-images-contours-and-fields-contourf-demo-py

# libraries
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import griddata

# Get the data (csv file is hosted on the web)
url = 'https://raw.githubusercontent.com/ampil/sandbox/master/latest-sales-sample.csv'
df = pd.read_csv(url, sep = ';')
df = df.dropna()  # drop rows with any missing value; recent pandas rejects passing both how= and thresh=
# create data
x = df['X']
y = np.log(df['Y'])
z = df['cluster']

#following 2 lines provided by user-warped
xi = np.arange(0, np.max(x), 0.1)
yi = np.arange(0, np.max(y), 0.1)

xi, yi = np.meshgrid(xi,yi)

# interpolate
zi = griddata((x,y),z,(xi,yi),method='cubic')

# define the colour map; you can modify it with 'set_under' and 'set_over'
# as per: https://matplotlib.org/3.1.1/gallery/images_contours_and_fields/contourf_demo.html#sphx-glr-gallery-images-contours-and-fields-contourf-demo-py
cmap = plt.get_cmap("coolwarm")  # plt.cm.get_cmap was removed in Matplotlib 3.9

fig = plt.figure()
ax = fig.add_subplot(111)
ax.axis((x.min(), x.max(), y.min(), y.max()))

#added the 'extend' parameter to user:warped edit as per documentation of plt.contourf
#https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.contourf.html
plt.contourf(xi, yi, zi, levels=[1,2,3,4,5,6], cmap=cmap, extend='both')
plt.plot(x, y,'k.')

plt.xlabel('x',fontsize=16)
plt.ylabel('y',fontsize=16)
plt.show()

result image
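To see what 'extend' changes in isolation, here is a minimal sketch on a synthetic surface rather than the sales data. The surface, the Agg backend and the white/black extreme colours are assumptions made for illustration; with_extremes is used here as the non-mutating counterpart of set_under/set_over.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Synthetic surface from 0 to 7, so some values fall outside the levels 1..6
xi, yi = np.meshgrid(np.linspace(0, 1, 50), np.linspace(0, 1, 50))
zi = 7 * xi

# Copy of coolwarm with explicit colours for out-of-range values
cmap = plt.get_cmap("coolwarm").with_extremes(under="white", over="black")
levels = [1, 2, 3, 4, 5, 6]

fig, (ax_plain, ax_ext) = plt.subplots(1, 2, figsize=(8, 4))
# Default: regions with zi < 1 or zi > 6 are left unfilled
cs_plain = ax_plain.contourf(xi, yi, zi, levels=levels, cmap=cmap)
# extend='both': those regions are filled with the under/over colours
cs_ext = ax_ext.contourf(xi, yi, zi, levels=levels, cmap=cmap, extend="both")
fig.colorbar(cs_ext, ax=ax_ext)
fig.canvas.draw()
plt.close(fig)
```

Note that 'extend' only affects grid cells that hold a value outside the level range. Cells that are NaN (outside the interpolated area) stay blank either way.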

As for extending the colours beyond the area they cover now: you'd get a result that looks very different from the main area and would likely have little meaning. If this were a GIS application, I would leave those exterior pixels as "NODATA".

EDIT: Providing evidence that filling the exterior would look strange...

Using a canned gdal method to fill the NODATA cells, this is what it would look like:

image with filled regions

This was quick and dirty and other methods likely exist but would probably look equally odd. Perhaps numpy.nan_to_num is another solution if you don't have gdal.
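If gdal is not available, one possible substitute (not the method used above) is to fill the NaN cells with a nearest-neighbour interpolation of the same points. The sketch below uses synthetic points and a grid that deliberately extends past them; all names and ranges are assumptions for illustration. Plain numpy.nan_to_num would instead replace every NaN with 0, which falls below the lowest contour level.

```python
import numpy as np
from scipy.interpolate import griddata

rng = np.random.default_rng(1)
# Synthetic stand-ins for the x, y, cluster columns
x = rng.uniform(0, 10, 100)
y = rng.uniform(0, 10, 100)
z = rng.integers(1, 7, 100).astype(float)

# Grid that extends beyond the data on every side
xi, yi = np.meshgrid(np.linspace(-2, 12, 60), np.linspace(-2, 12, 60))

zi = griddata((x, y), z, (xi, yi), method="cubic")            # NaN outside the convex hull
zi_nearest = griddata((x, y), z, (xi, yi), method="nearest")  # defined everywhere

# Keep the cubic values where they exist, fall back to nearest elsewhere
zi_filled = np.where(np.isnan(zi), zi_nearest, zi)
```

The filled exterior simply repeats the label of the closest data point, so it is still an extrapolation and should be read with the same caution as the gdal result.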

In case you're curious...here's the code (continues from previous code block):

from osgeo import gdal  # the top-level "import gdal" shim was removed in newer GDAL releases
ds = gdal.GetDriverByName('MEM').Create('', zi.shape[1], zi.shape[0], 1, gdal.GDT_Float32)
in_band = ds.GetRasterBand(1)
in_band.SetNoDataValue(-9999)
in_band.FlushCache()

raster_data = np.copy(zi)
raster_data[np.isnan(zi)] = -9999

in_band.WriteArray(raster_data)

#this line takes a while to run...grab a coffee
result = gdal.FillNodata(in_band, None, maxSearchDist=20000, smoothingIterations=0)
in_band.FlushCache()

newz = in_band.ReadAsArray()

fig = plt.figure()
ax = fig.add_subplot(111)
ax.axis((x.min(), x.max(), y.min(), y.max()))

#added the 'extend' parameter as per documentation of plt.contourf
#https://matplotlib.org/3.1.1/api/_as_gen/matplotlib.pyplot.contourf.html
plt.contourf(xi, yi, newz, levels=[1,2,3,4,5,6], cmap=cmap, extend='both')
plt.plot(x, y,'k.')

plt.xlabel('x',fontsize=16)
plt.ylabel('y',fontsize=16)
plt.show()

Upvotes: 1

warped

Reputation: 9481

Looking at df.describe():

            id          Y              X            cluster
    count   706.000000  706.000000     706.000000   706.000000
    mean    357.035411  18401.784703   3217.385269  3.002833
    std     205.912934  46147.403750   950.665697   0.532616
    min     1.000000    278.000000     328.000000   1.000000
    25%     178.500000  3546.000000    2498.500000  3.000000
    50%     358.500000  6869.500000    3574.000000  3.000000
    75%     534.750000  17169.000000   3997.500000  3.000000
    max     712.000000  877392.000000  4321.000000  6.000000

X is between 328 and 4321, Y is between 278 and 877392.

Your lines

xi = yi = np.arange(0, 1.01, 0.01) 
xi, yi = np.meshgrid(xi,yi)

create a grid with x,y values between zero and one, so you are trying to interpolate onto a grid that lies far outside your data.
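The effect can be reproduced on synthetic data. The sketch below is an illustration under assumptions: the random points only mimic the ranges from df.describe() (with y standing in for log(Y), roughly 5.6 to 13.7), and method='linear' is used to keep it fast. griddata returns NaN outside the convex hull of the input points for 'cubic' as well.

```python
import numpy as np
from scipy.interpolate import griddata

rng = np.random.default_rng(0)
# Synthetic stand-ins: X in [328, 4321], log(Y) in about [5.6, 13.7]
x = rng.uniform(328, 4321, 200)
y = rng.uniform(np.log(278), np.log(877392), 200)
z = rng.integers(1, 7, 200).astype(float)

# Grid from the question: covers only [0, 1] x [0, 1]
gx, gy = np.meshgrid(np.arange(0, 1.01, 0.01), np.arange(0, 1.01, 0.01))
zi_bad = griddata((x, y), z, (gx, gy), method="linear")

# Grid that spans the data (fixed point count keeps the grid small)
gx2, gy2 = np.meshgrid(np.linspace(x.min(), x.max(), 50),
                       np.linspace(y.min(), y.max(), 50))
zi_ok = griddata((x, y), z, (gx2, gy2), method="linear")

all_nan_bad = bool(np.isnan(zi_bad).all())       # nothing to draw on the unit grid
frac_valid_ok = float(np.isfinite(zi_ok).mean()) # most of the data-spanning grid is filled
```

Using np.linspace with a fixed number of points is a design choice of this sketch: a step of 0.1 over an x range of several thousand produces tens of thousands of grid columns, which makes the interpolation noticeably slower.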

Setting

xi = np.arange(0, np.max(x), 0.1)
yi = np.arange(0, np.max(y), 0.1)

xi, yi = np.meshgrid(xi,yi)

and leaving out the np.arange(...) argument in plt.contourf:

fig = plt.figure()
ax = fig.add_subplot(111)
ax.axis((x.min(), x.max(), y.min(), y.max()))
plt.contourf(xi, yi, zi, cmap='coolwarm') # <-- removed np.arange()
plt.plot(x,y,'k.')
plt.xlabel('x',fontsize=16)
plt.ylabel('y',fontsize=16)
plt.show()

enter image description here

Using levels to draw the contours:

plt.contourf(xi, yi, zi, levels=[1,2,3,4,5,6], cmap='coolwarm')

enter image description here
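One thing to keep in mind with levels=[1,2,3,4,5,6] is that six boundaries give only five filled bands. A possible variation, not proposed in this thread, is to treat the cluster as a label: interpolate with method='nearest' so the grid holds only the values 1..6, and put the boundaries at half-integers so each cluster gets its own band. The data below are synthetic and every name is an assumption for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from scipy.interpolate import griddata

rng = np.random.default_rng(2)
# Synthetic stand-ins for the x, y, cluster columns
x = rng.uniform(0, 10, 300)
y = rng.uniform(0, 10, 300)
z = rng.integers(1, 7, 300)  # cluster labels 1..6

xi, yi = np.meshgrid(np.linspace(0, 10, 80), np.linspace(0, 10, 80))
# Nearest-neighbour keeps the labels intact (no fractional cluster values)
zi = griddata((x, y), z, (xi, yi), method="nearest")

# Boundaries at 0.5, 1.5, ..., 6.5: one band per cluster
levels = np.arange(0.5, 7.5, 1.0)

fig, ax = plt.subplots()
cs = ax.contourf(xi, yi, zi, levels=levels, cmap="coolwarm")
ax.plot(x, y, "k.", markersize=2)
fig.canvas.draw()
plt.close(fig)
```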

Upvotes: 1

seralouk

Reputation: 33197

# x, y, z and plt are as defined in the question's code
import matplotlib.colors
colors = ['red','green','blue','purple','black', 'coral']

fig = plt.figure(figsize=(8,8))
plt.scatter(x, y, c = z, cmap=matplotlib.colors.ListedColormap(colors))
plt.show()

enter image description here
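For reference, a self-contained version of the same idea on synthetic data. The random points are an assumption for illustration, and the BoundaryNorm is an addition to the snippet above: it pins cluster 1 to the first colour and cluster 6 to the last, even if some cluster happens to be missing from the data.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import BoundaryNorm, ListedColormap, to_rgba

rng = np.random.default_rng(3)
# Synthetic stand-ins for the x, y, cluster columns
x = rng.uniform(0, 10, 120)
y = rng.uniform(0, 10, 120)
z = rng.integers(1, 7, 120)  # cluster labels 1..6

colors = ['red', 'green', 'blue', 'purple', 'black', 'coral']
cmap = ListedColormap(colors)
# Bin edges 0.5, 1.5, ..., 6.5 map label k to colors[k - 1]
norm = BoundaryNorm(np.arange(0.5, 7.5, 1.0), cmap.N)

fig, ax = plt.subplots(figsize=(8, 8))
sc = ax.scatter(x, y, c=z, cmap=cmap, norm=norm)
fig.colorbar(sc, ax=ax, ticks=range(1, 7), label="cluster")
fig.canvas.draw()
plt.close(fig)

# Colours assigned to the first and last cluster label
first_and_last = sc.to_rgba(np.array([1.0, 6.0]))
```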

Upvotes: 1
