Will
Will

Reputation: 651

Formatting a broken y axis in python matplotlib

I have a (reasonably complicated) bar chart that I'm working on in matplotlib. It contains summary data from a number of sources, each labelled along the x axis, with a range of results on the y axis. A number of results are outliers, and I've attempted to use a broken y axis to show these results without distorting the whole graph. To do this I combined this method for inserting a broken y axis with this method for aligning subplots on a grid (the outliers are concentrated around a specific point, so the upper graph can be quite small).

The resulting graph looks a bit like this

Example graph

The problem is that the diagonal lines are clearly at different angles above and below the break in the y axis. I don't understand why.

The code I am using is below. Apologies for the complexity, I've had to do a lot of modifications to different axes to make this work...

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from matplotlib.gridspec import GridSpec

    # Example data: one "low" and one "high" value per record (records become the
    # index, "low"/"high" the columns). "Record 6" has a "high" of 10000, far above
    # every other value; this is the outlier the broken y axis is meant to show.
    data = pd.DataFrame.from_dict(
        {
            "low": {
                "Record 1": 5,
                "Record 2": 10,
                "Record 3": 15,
                "Record 4": 20,
                "Record 5": 25,
                "Record 6": 30,
                "Record 7": 35,
                "Record 8": 40,
                "Record 9": 45,
                "Record 10": 50,
            },
            "high": {
                "Record 1": 25,
                "Record 2": 100,
                "Record 3": 225,
                "Record 4": 25,
                "Record 5": 100,
                "Record 6": 10000,
                "Record 7": 25,
                "Record 8": 100,
                "Record 9": 225,
                "Record 10": 25,
            },
        }
    )

    # Figure size is specified in millimetres and converted to inches
    # (25.4 mm per inch) because figsize is given in inches.
    mm = (146, 90)  # x value then y value
    inches = (mm[0] / 25.4, mm[1] / 25.4)

    fig = plt.figure(figsize=inches)
    # One shared y-axis label for the whole figure instead of one per axes.
    fig.text(0.02, 0.6, r"Y axis label", va="center", rotation="vertical", fontsize=12)
    # Two rows with a 1:4 height ratio: a short upper row and a tall lower row.
    gs = GridSpec(2, 2, height_ratios=[1, 4])

    # ax is the upper (short) axes, ax2 the lower (tall) axes; both span both columns.
    ax = fig.add_subplot(gs.new_subplotspec((0, 0), colspan=2))
    ax2 = fig.add_subplot(gs.new_subplotspec((1, 0), colspan=2))
    palette = plt.get_cmap("tab20")

    # x positions of the bars: 0 .. number of records - 1.
    indx = np.arange(len(data.index))

    # Tick labels are the record names with an empty string prepended --
    # presumably so the blank lines up with the extra tick at x = -1 set below.
    labs = data.index.tolist()
    labs.insert(0, "")

    ax.tick_params(axis="both", which="major", labelsize=10)
    ax2.tick_params(axis="both", which="major", labelsize=10)
    # Only the lower axes shows x tick labels; the upper axes gets none.
    # NOTE(review): the labels are assigned before the tick positions are set
    # below; keep this call order unless the tick/label alignment is re-checked.
    ax2.set_xticklabels((labs), rotation=45, fontsize=10, horizontalalignment="right")
    ax.set_xticklabels(())
    # Same x tick positions on both axes: -1, 0, ..., number of records.
    ax.set_xticks(np.arange(-1, len(data.index) + 1, 1.0))
    ax2.set_xticks(np.arange(-1, len(data.index) + 1, 1.0))

    # Same y tick positions on both axes: every 100 from 0 up to the largest "high".
    ax.set_yticks(np.arange(0, max(data["high"]) + 10, 100))
    ax2.set_yticks(np.arange(0, max(data["high"]) + 10, 100))

    # plot the same data on both axes
    # Each bar floats from the record's "low" to its "high" value
    # (bottom=low, height=high - low).
    # NOTE(review): width is negative; with align="center" this presumably looks
    # the same as width=0.5 -- confirm before changing.
    bar_lower = ax2.bar(
        x=indx,
        height=data["high"] - data["low"],
        bottom=data["low"],
        width=-0.5,
        align="center",
        color=palette(1),
        edgecolor="k",
        linewidth=0.5,
        zorder=10,
    )

    bar_upper = ax.bar(
        x=indx,
        height=data["high"] - data["low"],
        bottom=data["low"],
        width=-0.5,
        align="center",
        color=palette(1),
        edgecolor="k",
        linewidth=0.5,
        zorder=10,
    )

    # zoom-in / limit the view to different portions of the data
    ax.set_ylim(9950, 10050)  # outliers only
    ax2.set_ylim(0, 450)  # most of the data
    ax.set_xlim(-0.5, len(data.index) - 0.25)  # outliers only
    ax2.set_xlim(-0.5, len(data.index) - 0.25)  # most of the data


    # Hide the two spines that face each other across the gap so the pair of
    # axes reads as a single axis with a break in it.
    ax.spines["bottom"].set_visible(False)
    ax2.spines["top"].set_visible(False)

    ax.grid(color="k", alpha=0.5, linestyle=":", zorder=1)
    ax2.grid(color="k", alpha=0.5, linestyle=":", zorder=1)

    # Zero-length x ticks on the upper axes so no tick marks show at the break.
    ax.tick_params(axis="x", which="both", length=0)
    # NOTE(review): the string "off" is the spelling older matplotlib releases
    # accepted here; newer releases expect a boolean (labeltop=False) -- confirm
    # against the matplotlib version in use.
    ax.tick_params(labeltop="off")
    ax2.tick_params(labeltop="off")
    ax2.xaxis.tick_bottom()

    # Diagonal "break" marks. Each line is given in the coordinate system of its
    # own axes (transAxes), where (0, 0) is the bottom-left corner and (1, 1) the
    # top-right corner of that axes.
    # NOTE: because the two axes have different heights (height_ratios=[1, 4]),
    # the same vertical extent d corresponds to a different physical length on
    # each axes, so the marks are drawn at different angles above and below the break.
    d = 0.015  # how big to make the diagonal lines in axes coordinates
    # arguments to pass to plot, just so we don't keep repeating them
    kwargs = dict(transform=ax.transAxes, color="k", clip_on=False)  # linewidth=1)
    ax.plot((-d, +d), (-d, +d), **kwargs)  # top-left diagonal
    ax.plot((1 - d, 1 + d), (-d, +d), **kwargs)  # top-right diagonal

    kwargs.update(transform=ax2.transAxes)  # switch to the bottom axes
    ax2.plot((-d, +d), (1 - d, 1 + d), **kwargs)  # bottom-left diagonal
    ax2.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)  # bottom-right diagonal

    # Fixed figure margins and spacing between the subplots.
    plt.subplots_adjust(
        top=0.943, bottom=0.214, left=0.103, right=0.97, hspace=0.133, wspace=0.062
    )
    plt.show()

Upvotes: 4

Views: 9280

Answers (1)

Will
Will

Reputation: 651

OK, I have made some edits and it now works, just not quite as I'd originally intended. There is also a new solution here, which should be pushed to the matplotlib page soon.

The key code is this section

# Original approach from the question: each diagonal is drawn in the coordinate
# system of its own axes (transAxes); d is defined earlier in the question's script.
# arguments to pass to plot, just so we don't keep repeating them
kwargs = dict(transform=ax.transAxes, color='k', clip_on=False)
ax.plot((-d, +d), (-d, +d), **kwargs)        # top-left diagonal
ax.plot((1 - d, 1 + d), (-d, +d), **kwargs)  # top-right diagonal

kwargs.update(transform=ax2.transAxes)  # switch to the bottom axes
ax2.plot((-d, +d), (1 - d, 1 + d), **kwargs)  # bottom-left diagonal
ax2.plot((1 - d, 1 + d), (1 - d, 1 + d), **kwargs)  # bottom-right diagonal

You can amend it to

# Replace the diagonals with short horizontal marks drawn in data coordinates
# (no transform is passed), one at each side of each axes.
axis_break1 = 450  # y value of the marks on the lower axes (ax2)
axis_break2 = 9951  # y value of the marks on the upper axes (ax)
x_min = -0.75  # x position of the left-hand marks
x_max = len(data.index)  # x position of the right-hand marks
half_len = 0.2  # each mark extends this far either side of its x position
kwargs = dict(color="k", clip_on=False, linewidth=1)
# Same drawing order as four explicit calls: upper-left, upper-right,
# lower-left, lower-right.
for axes, y_break in ((ax, axis_break2), (ax2, axis_break1)):
    for x_pos in (x_min, x_max):
        axes.plot((x_pos - half_len, x_pos + half_len), (y_break, y_break), **kwargs)

This leaves us with a neat (if slightly less fancy) result. (Image: resulting graph)

Or a revised (and more elegant) version (from ImportanceOfBeingErnest):

# Draw each break mark as a custom marker instead of a line in axes coordinates.
# The marker path runs from (-1, -d) to (1, d) and is sized by markersize, so the
# slant of the mark does not depend on the height of the axes it is drawn on.
d = .25  # proportion of vertical to horizontal extent of the slanted line
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12,
              linestyle="none", color='k', mec='k', mew=1, clip_on=False)
# One marker at the bottom-left and one at the bottom-right corner of the upper axes ...
ax.plot([0, 1], [0, 0], transform=ax.transAxes, **kwargs)
# ... and one at the top-left and one at the top-right corner of the lower axes.
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)

This results in diagonal lines, as originally intended. (Image: resulting graph with diagonal break marks)

Upvotes: 1

Related Questions