W. MacTurk


Matplotlib clearing old axis labels when re-plotting data

I've got a script with two functions: makeplots(), which makes a figure of blank subplots arranged in a particular way (depending on the number of subplots to be drawn), and drawplots(), which is called later to draw the plots. Both functions are posted below.

The script analyzes data for a given number of 'targets' (anywhere from one to nine) and plots the linear regression for each target. With multiple targets, this works great. But with a single target (i.e. a single subplot in the figure), the y-axis label overlaps the axis itself; this does not happen with multiple targets.

Ideally, each subplot would be square, no labels would overlap, and it would work the same for one target as for multiple targets. But when I tried to decrease the size of the y-axis label and shift it over a bit, it appeared that the actual axes object was drawn over the previously blank, square plot (whose axes ranged from 0 to 1), and the old tick labels were still visible. I'd like those old tick marks removed when calling drawplots(). I've tried changing the subplot_kw={} arguments in makeplots() and removing ax.set_aspect('auto') from drawplots(), both to no avail. Screenshots of the various behaviors are at the end.

def makeplots(targets, active=actwindow):

    def rowcnt(y):
        rownumb = y//3 if (y%3 == 0) else y//3+1
        return rownumb

    def colcnt(x):
        if x <= 3: colnumb = x
        elif x == 4: colnumb = 2
        else: colnumb = 3
        return colnumb

    numsubs = len(targets)
    numrow, numcol = rowcnt(numsubs), colcnt(numsubs)

    if numsubs >= 1:
        if numsubs == 1:
            fig, axs = plt.subplots(num='LOD-95 Plots', nrows=1, ncols=1, figsize = [8,6], subplot_kw={'adjustable': 'box', 'aspect': 1})
            # changed 'box' to 'datalim'
        fig, axs = plt.subplots(num='LOD-95 Plots', nrows=numrow, ncols=numcol, figsize = [numcol*6,numrow*6], subplot_kw={'adjustable': 'box', 'aspect': 1})
        fig.text(0.02, 0.5, 'Probit score\n    $(\\sigma + 5)$', va='center', rotation='vertical', size='16')
    else:
        raise ValueError(f'Error generating plots [call: makeplots({targets},{active}) - invalid numsubs value]')

    axs = np.ravel(axs)
    for i, ax in enumerate(axs):
        ax.set_title(f'Limit of Detection: {targets[i]}', size=11)
        ax.grid()
    return fig, axs

and

def drawplots(ax, dftables, color1, color2):
    y = dftables.probit
    y95 = 6.6448536269514722
    logreg = False
    regfun = lambda m, x, b : (m*x) + b
    regq = scipy.stats.linregress(dftables.qty,y)
    regl = scipy.stats.linregress(dftables.log_qty,y)
    if regq.rvalue**2 >= regl.rvalue**2:
        regression = regq
        x_label = 'input quantity'
        x = dftables.qty
    else:
        regression = regl
        x_label = '$log_{10}$(input quantity)'
        x = dftables.log_qty
        logreg = True
    slope, intercept, r = regression.slope, regression.intercept, regression.rvalue
    r2 = r**2
    lod = (y95-intercept)/slope
    xr = [0, lod*1.2]
    yr = [intercept, regfun(slope, xr[1], intercept)]
    regeqn = f"y = {slope:.2e}x + {intercept:.3f}"

    if logreg:
        lodstr = f'log(LOD) = {lod:.2f}' if lod <= 100 else f'log(LOD) = {lod:.2e}'
    else:
        lodstr = f'LOD = {lod:.2f}' if lod <= 100 else f'LOD = {lod:.2e}'
#        raise ValueError(f'Error raised calling drawplots()')


    ax.set_xlabel(x_label, fontweight='bold')
    ax.plot(xr, yr, color=color1, linestyle='--') # plot regression line
    ax.plot(lod, y95, marker='D', color=color2, markersize=7) # plot point for LoD
    ax.plot(xr, [y95,y95], color=color2, linestyle=':') # horizontal crosshair
    ax.plot([lod,lod],[0, 7.1], color=color2, linestyle=':') # vertical crosshair
    ax.scatter(x, y, s=81, color=color1, marker='.') # actual data points
    ax.annotate(f"{lodstr}", xy=(lod,0.1),
                xytext=(0.9*lod,0.5), fontsize=8, arrowprops = dict(facecolor='black', headlength=5, width=2, headwidth=5))
    ax.set_aspect('auto')
    ax.set_xlim(left=0)
    ax.set_ylim(bottom=0)
    ax.plot()
    if logreg: lod = 10 ** lod

    return r2, lod, regeqn, logreg
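For reference, the hard-coded y95 in drawplots() appears to be the probit of a 95% hit rate, i.e. norm.ppf(0.95) + 5, matching the +5 offset applied when the tables are built, and the LOD is the x value where the fitted line reaches that level. A minimal sketch of that arithmetic, assuming made-up slope and intercept values (not data from the post):

```python
from scipy.stats import norm

# Probit of a 95% hit rate, with the same +5 offset used when building the tables
y95 = norm.ppf(0.95) + 5

# Hypothetical regression results, for illustration only
slope, intercept = 1.5, 2.0

# Solve y95 = slope * lod + intercept for lod
lod = (y95 - intercept) / slope

print(f"y95 = {y95:.4f}")  # y95 = 6.6449
print(f"LOD = {lod:.4f}")  # LOD = 3.0966
```

When the log10 regression wins, this lod is in log10 units, which is why drawplots() converts it with 10 ** lod before returning.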

The context they're called in:

fig, axs = makeplots(targets)
wg.SetForegroundWindow(actwindow)

with open(outName, 'a+') as f:
    print(f"Lower Limit of Detection Analysis on {dt} at {tm}\n", file=f)
    for i, tars in enumerate(targets):
        data[tars] = stripThousands(data[tars])
#        logans = checkyn(f"Analyze {tars} using log10(concentration/quantity)? (y/n): ")
        for idx, val in enumerate(qtys):
            tables[i,idx,2] = hitrate(val,data,tars)
            tables[i,idx,3] = norm.ppf(tables[i,idx,2])+5

        printtables[tars] = pd.DataFrame(tables[i,:,:], columns=["qty","log_qty","probability","probit"])
        # construct dataframes from np.arrays and drop
        #     rows with infinite probit values:
        dftables[tars] = pd.DataFrame(tables[i,:,:], columns=["qty","log_qty","probability","probit"])
        dftables[tars]["probit"] = dftables[tars]["probit"].replace([np.inf, -np.inf], np.nan)
        dftables[tars].dropna(inplace=True)


        r2, lod, eqn, logreg = drawplots(axs[i], dftables[tars], cbcolors[i], cbcolors[i+5])
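The +5 probit transform in this loop produces infinite values whenever a hit rate is exactly 0 or 1, since norm.ppf(0) is -inf and norm.ppf(1) is inf, which is why those rows are dropped before the regression. A small sketch with hypothetical quantities and hit rates; it assigns the column back rather than calling replace(..., inplace=True) through df.probit, because under pandas' Copy-on-Write behavior an in-place call through that chained access may not update the DataFrame:

```python
import numpy as np
import pandas as pd
from scipy.stats import norm

# Hypothetical dilution series and observed hit rates
qty = np.array([1000.0, 100.0, 10.0, 1.0])
hit_rate = np.array([1.0, 0.95, 0.6, 0.0])

df = pd.DataFrame({
    "qty": qty,
    "log_qty": np.log10(qty),
    "probability": hit_rate,
    "probit": norm.ppf(hit_rate) + 5,  # rates of exactly 1.0 or 0.0 give +/-inf
})

# Replace infinite probits with NaN, then drop those rows
df["probit"] = df["probit"].replace([np.inf, -np.inf], np.nan)
df = df.dropna()

print(df)
```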

[Screenshot 1: no old tick marks, but the y-label overlaps the y-axis.]
[Screenshot 2: the y-label is fine, but the original 'default' ticks and labels are visible and undesired on the x-axis (and at the top of the subplot).]
[Screenshot 3: works fine when analyzing multiple targets, though the y-label could be a little closer.]


Answers (1)

tdy

You should clear the axes in each iteration using cla() (ax.cla() on a specific Axes, or pyplot.cla() for the current one).

You posted a lot of code, so I'm not 100% sure of the best location to place it in your code, but the general idea is to clear the axes before each new plot.
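Applied to a grid like the one makeplots() builds, one detail to keep in mind is that cla() also resets the title and grid that were set earlier, so they would need to be re-applied after clearing. A minimal sketch, assuming a 1x2 grid, hypothetical target names, and a hypothetical redraw() helper standing in for drawplots():

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(nrows=1, ncols=2, figsize=[12, 6])
axs = np.ravel(axs)
targets = ["A", "B"]  # hypothetical target names


def redraw(ax, target, data):
    """Hypothetical stand-in for drawplots(): clear the Axes, restore decorations, plot."""
    ax.cla()  # drops previous artists and resets title, labels, limits and grid
    ax.set_title(f"Limit of Detection: {target}", size=11)
    ax.grid(True)  # turn the grid back on; cla() resets it to the rcParams default
    ax.plot(data)


# Redrawing twice still leaves exactly one line per Axes
for _ in range(2):
    for ax, target in zip(axs, targets):
        redraw(ax, target, [1, 2, 3])
```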

Here is a minimal demo without cla():

import matplotlib.pyplot as plt

x = [[1,2,3], [3,2,1]]

fig, ax = plt.subplots()
for index, data in enumerate(x):
    ax.plot(data)

[Figure: result without cla()]

And with cla():

for index, data in enumerate(x):
    ax.cla()
    ax.plot(data)

[Figure: result with cla()]
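If leftover 0-to-1 ticks remain even after clearing, one thing that may be worth checking in makeplots(): when numsubs == 1, plt.subplots() is called twice with the same num='LOD-95 Plots'. pyplot returns the existing figure with that label, so the second call adds a second Axes on top of the first rather than replacing it, and ax.cla() on the returned Axes does not touch the hidden one underneath. A minimal sketch of that behavior, assuming the Agg backend; make_figure() is a hypothetical single-call alternative, not code from the post:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

kw = {"adjustable": "box", "aspect": 1}

# Two calls with the same num: pyplot reuses the existing figure,
# and the second call adds a new Axes on top of the first one
fig1, ax1 = plt.subplots(num="LOD-95 Plots", subplot_kw=kw)
fig2, ax2 = plt.subplots(num="LOD-95 Plots", subplot_kw=kw)
print(fig1 is fig2, len(fig2.axes))  # True 2

ax2.cla()  # clears only ax2; ax1 and its 0-1 ticks are still in the figure
print(len(fig2.axes))  # 2
plt.close("all")


def make_figure(numrow, numcol):
    """Hypothetical fix: call plt.subplots exactly once, choosing figsize per case."""
    figsize = [8, 6] if numrow * numcol == 1 else [numcol * 6, numrow * 6]
    return plt.subplots(num="LOD-95 Plots", nrows=numrow, ncols=numcol,
                        figsize=figsize, subplot_kw=kw)


fig, axs = make_figure(1, 1)
print(len(fig.axes))  # 1
```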
