Another_coder

Reputation: 810

Interplay between plt.get_cmap and np.linspace

I am trying to work through the matplotlib example 'Discrete distribution as horizontal bar chart'. Here is the full code:

import matplotlib.pyplot as plt
import numpy as np



category_names = ['Strongly disagree', 'Disagree',
                  'Neither agree nor disagree', 'Agree', 'Strongly agree']
results = {
    'Question 1': [10, 15, 17, 32, 26],
    'Question 2': [26, 22, 29, 10, 13],
    'Question 3': [35, 37, 7, 2, 19],
    'Question 4': [32, 11, 9, 15, 33],
    'Question 5': [21, 29, 5, 5, 40],
    'Question 6': [8, 19, 5, 30, 38]
}


def survey(results, category_names):
    """
    Parameters
    ----------
    results : dict
        A mapping from question labels to a list of answers per category.
        It is assumed all lists contain the same number of entries and that
        it matches the length of *category_names*.
    category_names : list of str
        The category labels.
    """
    labels = list(results.keys())
    data = np.array(list(results.values()))
    data_cum = data.cumsum(axis=1)
    category_colors = plt.get_cmap('RdYlGn')(
        np.linspace(0.15, 0.85, data.shape[1]))

    fig, ax = plt.subplots(figsize=(9.2, 5))
    ax.invert_yaxis()
    ax.xaxis.set_visible(False)
    ax.set_xlim(0, np.sum(data, axis=1).max())

    for i, (colname, color) in enumerate(zip(category_names, category_colors)):
        widths = data[:, i]
        starts = data_cum[:, i] - widths
        ax.barh(labels, widths, left=starts, height=0.5,
                label=colname, color=color)
        xcenters = starts + widths / 2

        r, g, b, _ = color
        text_color = 'white' if r * g * b < 0.5 else 'darkgrey'
        for y, (x, c) in enumerate(zip(xcenters, widths)):
            ax.text(x, y, str(int(c)), ha='center', va='center',
                    color=text_color)
    ax.legend(ncol=len(category_names), bbox_to_anchor=(0, 1),
              loc='lower left', fontsize='small')

    return fig, ax


survey(results, category_names)
plt.show()

What confuses me is the use of get_cmap when the category colors are set, specifically this line:

category_colors = plt.get_cmap('RdYlGn')(
    np.linspace(0.15, 0.85, data.shape[1]))

I understand that the first part, plt.get_cmap('RdYlGn'), returns a Colormap instance, but I don't understand what the np.linspace part is doing. Could someone explain?

Many thanks in advance!

Upvotes: 1

Views: 2005

Answers (1)

Diziet Asahi

Reputation: 40737

It creates an array of 5 colors sampled from the colormap, one for each answer category. data.shape[1] is the number of columns in data, i.e. the 5 answer categories; the 6 questions are the rows.

A colormap is a way to map a value in the interval [0, 1] to a color.

In this example there are 5 answer categories, and we want to assign a color to each of them. So the code creates 5 values that must lie in the interval [0, 1], but to avoid the extreme colors it picks them from the narrower interval [0.15, 0.85] instead (in many colormaps the extremes could be white, and in cyclic colormaps the two extremes are the same color).

np.linspace(0.15, 0.85, 5)
# [0.15, 0.325, 0.5, 0.675, 0.85]

Those 5 numbers are then mapped to colors in the RdYlGn colormap, and the resulting RGBA colors are paired with category_names in the zip() loop.
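To make the two steps concrete, here is a small sketch that splits the one-liner apart, reusing the survey numbers from the question. It shows that np.linspace just produces evenly spaced numbers (start + k * step, with step = (stop - start) / (n - 1)), and that calling the Colormap object on that array returns one RGBA row per number. The variable names positions, manual and n_colors are my own, chosen for illustration.

```python
import numpy as np
import matplotlib.pyplot as plt

# Survey data from the question: 6 questions (rows) x 5 answer categories (columns)
data = np.array([
    [10, 15, 17, 32, 26],
    [26, 22, 29, 10, 13],
    [35, 37, 7, 2, 19],
    [32, 11, 9, 15, 33],
    [21, 29, 5, 5, 40],
    [8, 19, 5, 30, 38],
])

n_colors = data.shape[1]  # number of columns = number of answer categories (5)

# Step 1: n_colors evenly spaced positions from 0.15 to 0.85, both endpoints included
positions = np.linspace(0.15, 0.85, n_colors)

# The same numbers computed by hand: start + k * step, step = (stop - start) / (n - 1)
step = (0.85 - 0.15) / (n_colors - 1)
manual = 0.15 + step * np.arange(n_colors)

# Step 2: get the Colormap object and call it on the whole array of positions
cmap = plt.get_cmap('RdYlGn')
category_colors = cmap(positions)  # ndarray of shape (n_colors, 4): one RGBA row per position

print(np.round(positions, 3))
print(category_colors.shape)
```

Because each row is an (r, g, b, alpha) tuple, the loop in the example can unpack it with r, g, b, _ = color.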

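To see why the range is trimmed, you can compare sampling the full [0, 1] interval with the example's [0.15, 0.85]. This is a minimal sketch assuming the default colormap settings (no custom under/over colors set); with those defaults, values below 0 reuse the first color and values above 1 reuse the last one.

```python
import numpy as np
import matplotlib.pyplot as plt

cmap = plt.get_cmap('RdYlGn')

# The colormap is defined on [0, 1]: 0.0 gives its first color (dark red), 1.0 its last (dark green)
low_end = cmap(0.0)
high_end = cmap(1.0)

# With default settings, out-of-range values are clipped to the end colors
below = cmap(-0.5)  # same RGBA as low_end
above = cmap(1.5)   # same RGBA as high_end

# Sampling the full range puts the extreme colors on the outer categories...
full_range = cmap(np.linspace(0.0, 1.0, 5))
# ...while the example's trimmed range keeps the outer categories away from them
trimmed = cmap(np.linspace(0.15, 0.85, 5))

print(full_range[0], trimmed[0])
```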

Upvotes: 4
