Tommy Lees

Reputation: 1373

Plot categorical data in matplotlib when the values are unevenly spaced

I need to create a 2D image of gridded data with unevenly spaced values. I am plotting a categorical dataset where the categories are encoded with numerical values corresponding to a specific label.

I need to be able to use the formatter to assign a different color to each category in the dataset. This should preferably be flexible because the true dataset has ~30 unique categories that I am plotting. Thus I should have a unique color for when the value is 10 and when it is 40.

Making the example data to demonstrate

import xarray as xr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

time = pd.date_range('2010-01-31', '2015-12-31', freq='M')
lat = np.linspace(0, 1, 224)
lon = np.linspace(0, 1, 176)
valid_vals = [10., 40., 50., 60.]
labels = ['type_1', 'type_2', 'type_3', 'type_4']
lookup = dict(zip(valid_vals, labels))

values = np.random.choice(valid_vals, size=(len(time), len(lat), len(lon)))
rand_nans = np.random.random(size=(len(time), len(lat), len(lon))) < 0.3
values[rand_nans] = np.nan

coords = {'time': time, 'lat': lat, 'lon': lon}
dims = ['time', 'lat', 'lon']

ds = xr.Dataset({'lc_code': (dims, values)}, coords=coords)

# convert to numpy array (only the first timestep)
im = ds.isel(time=0).lc_code.values

ds
Out[]:
<xarray.Dataset>
Dimensions:  (lat: 224, lon: 176, time: 72)
Coordinates:
  * time     (time) datetime64[ns] 2010-01-31 2010-02-28 ... 2015-12-31
  * lat      (lat) float64 0.0 0.004484 0.008969 0.01345 ... 0.991 0.9955 1.0
  * lon      (lon) float64 0.0 0.005714 0.01143 0.01714 ... 0.9886 0.9943 1.0
Data variables:
    lc_code  (time, lat, lon) float64 50.0 nan 60.0 50.0 ... 40.0 10.0 40.0 10.0

Plotting the image data alone has two problems: 1) the tick labels are not the strings defined in labels; 2) the colorbar is evenly spaced but the values are not (they sit at 10, 40, 50 and 60).

plt.imshow(im, cmap=plt.get_cmap('tab10', len(valid_vals)))
plt.colorbar()

[Image: simple imshow]

So I tried using a FuncFormatter. However, the resulting image still has the problem that no values are mapped to the type_2 color, even though the tick label lines up with the centre of that colorbar segment.

fig, ax = plt.subplots(figsize=(12, 8))

plt.imshow(im, cmap=plt.get_cmap('tab10', len(valid_vals)))

# calculate the POSITION of the tick labels
min_ = min(valid_vals)
max_ = max(valid_vals)
positions = np.linspace(min_, max_, len(valid_vals))
val_lookup = dict(zip(positions, labels))

def formatter_func(x, pos):
    'The two args are the value and tick position'
    val = val_lookup[x]
    return val

formatter = plt.FuncFormatter(formatter_func)

# We must be sure to specify the ticks matching our target names
plt.colorbar(ticks=positions, format=formatter, spacing='proportional');

# set the colorbar limits so that the ticks are evenly spaced
plt.clim(0, 70)

[Image: My attempt at mapping the values to the labels]

But this code forces the second category (values of 40, type_2) to not be shown with the color the tick lines up with. Therefore, the colorbar isn't effectively reflecting the data in the image.

(im == 40).mean()

Out[]:
0.17347301136363635

Upvotes: 3

Views: 2594

Answers (1)

Dominic McLoughlin

Reputation: 278

No values are mapped to the type_2 colour in your first plot because the data contain nothing between about 22.5 and 35, which is the range that gets assigned to red (the second colour of the four).
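To make that concrete, here is a minimal numpy-only sketch of how a linear colour scale splits the data range into equal-width slices. It assumes imshow's default behaviour of scaling linearly from the data minimum to the data maximum; the names edges and colour_index are just for illustration and are not matplotlib API.

```python
import numpy as np

valid_vals = np.array([10., 40., 50., 60.])
n_colours = len(valid_vals)

# Assumption: linear scaling from the data minimum to the data maximum
vmin, vmax = valid_vals.min(), valid_vals.max()

# Edges of the equal-width value ranges, one range per colour
edges = np.linspace(vmin, vmax, n_colours + 1)

# Index of the colour each value lands in (the top edge belongs to the last colour)
scaled = (valid_vals - vmin) / (vmax - vmin)
colour_index = np.minimum((scaled * n_colours).astype(int), n_colours - 1)

print(edges)
print(colour_index)
```

The second slice runs from 22.5 to 35 and none of the values fall inside it: 40 lands in the third colour, while 50 and 60 share the fourth.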

What you could try instead is using a ListedColormap.

import xarray as xr
import matplotlib.pyplot as plt
from matplotlib import colors
import numpy as np
import pandas as pd

time = pd.date_range('2010-01-31', '2015-12-31', freq='M')
lat = np.linspace(0, 1, 224)
lon = np.linspace(0, 1, 176)
valid_values = [10., 40., 50., 60.]
labels = ['type_1', 'type_2', 'type_3', 'type_4']
lookup = dict(zip(valid_values, labels))

values = np.random.choice(valid_values, size=(len(time), len(lat), len(lon)))
rand_nans = np.random.random(size=(len(time), len(lat), len(lon))) < 0.3
values[rand_nans] = np.nan

coords = {'time': time, 'lat': lat, 'lon': lon}
dims = ['time', 'lat', 'lon']

ds = xr.Dataset({'lc_code': (dims, values)}, coords=coords)

# convert to numpy array (only the first timestep)
im = ds.isel(time=0).lc_code.values

# Build a listed colormap.
c_map = colors.ListedColormap(['white', 'red', 'blue', 'green'])
bounds = [-15, 35, 45, 55, 65]
norm = colors.BoundaryNorm(bounds, c_map.N)

# Plot the image with a colorbar; the colorbar takes cmap and norm from the image
img = plt.imshow(im, cmap=c_map, norm=norm)
c_bar = plt.colorbar(img, boundaries=bounds, ticks=[10, 40, 50, 60])
# The colorbar is vertical by default, so set the labels on the colorbar itself
c_bar.set_ticklabels(labels)
plt.show()

This gives the following output: [Image: Plot]

To centre each label within its colorbar segment, make sure each tick value (the ticks parameter of plt.colorbar) lies exactly halfway between the two bounds that enclose it. I hardcoded these, but you could easily compute them automatically. It doesn't matter that the bounds are not equally spaced: the colorbar's default spacing='uniform' draws every BoundaryNorm interval at the same size, so each category gets an equal share of the colorbar.
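For ~30 categories, hardcoding gets tedious, so here is a rough sketch of one way to derive the bounds and tick positions from the category values. The helper category_bounds is hypothetical (not part of matplotlib); it assumes at least two distinct values and places each boundary midway between neighbouring values.

```python
import numpy as np
from matplotlib import colors


def category_bounds(values):
    """Hypothetical helper: bin edges lying midway between sorted category values."""
    vals = np.sort(np.asarray(values, dtype=float))
    mids = (vals[:-1] + vals[1:]) / 2
    # Mirror the nearest midpoint to get the outer edges
    first = vals[0] - (mids[0] - vals[0])
    last = vals[-1] + (vals[-1] - mids[-1])
    return np.concatenate([[first], mids, [last]])


valid_values = [10., 40., 50., 60.]
bounds = category_bounds(valid_values)

# Tick positions at the centre of each bin, so labels sit mid-segment
tick_positions = (bounds[:-1] + bounds[1:]) / 2

# One colour per bin
norm = colors.BoundaryNorm(bounds, ncolors=len(valid_values))
codes = norm(np.array(valid_values))

print(bounds)
print(tick_positions)
print(codes)
```

You would then pass norm to plt.imshow and tick_positions as the ticks argument of plt.colorbar. Note that the ticks are the bin centres rather than the category values themselves, which is what keeps the labels centred when the values are unevenly spaced.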

Hope this helps!

Upvotes: 3
