skrowten_hermit
skrowten_hermit

Reputation: 445

Mismatch in y-axis scale in one or more of the subplots using pyplot in python

I'm trying to plot and compare the frequency spectra of two .wav files. I wrote the following Python code for that:

import gc
import os
import sys
import time

import pylab
from numpy import linspace
from scipy import fft, arange
from scipy.io.wavfile import read

# Default size, in inches, of every pylab figure created below.
params = {}
params['figure.figsize'] = (20, 15)
pylab.rcParams.update(params)


def plotSpec(y, Fs):
    """Plot the one-sided, length-normalised magnitude spectrum of *y*.

    Draws on the current pylab axes.

    :param y: signal samples. A 2-D (samples x channels) array, as returned
        by scipy.io.wavfile.read for multi-channel files, is transformed
        per channel.
    :param Fs: sampling rate in Hz.
    """
    # Local import: in newer SciPy `scipy.fft` is a module (not a callable)
    # and NumPy aliases such as `scipy.arange` are deprecated, so use NumPy.
    import numpy as np

    n = len(y)  # signal length in samples
    half = n // 2  # keep only the non-negative frequency bins
    # Bin k lies at k * Fs / n Hz. float() keeps Python 2 from truncating
    # n / Fs, which previously produced a wrong frequency axis.
    frq = np.arange(half) * (float(Fs) / n)
    # Transform along the sample axis and normalise by the signal length.
    ff_valu = np.fft.fft(y, axis=0)[:half] / n
    pylab.plot(frq, np.abs(ff_valu), 'r')  # plotting the spectrum
    pylab.tick_params(axis='x', labelsize=8)
    pylab.tick_params(axis='y', labelsize=8)
    pylab.xticks(rotation=45)
    pylab.xlabel('Frequency')
    pylab.ylabel('Power')

def graph_plot(in_file, graph_loc, output_folder, count, func_type):
    graph_loc = int(graph_loc)
    rate = 0
    data = 0
    rate, data = read(in_file)
    dlen = len(data)
    print "dlen=", dlen
    lungime = dlen
    timp = dlen / rate
    print "timp=", timp
    t = linspace(0, timp, dlen)

    pylab.subplot(3, 2, graph_loc)
    pylab.plot(t, data)
    fl = in_file.split('/')
    file_name = fl[len(fl) - 1]
    pylab.title(file_name)
    pylab.tick_params(axis='x', labelsize=8)
    pylab.tick_params(axis='y', labelsize=8)
    pylab.xticks(rotation=45)
    pylab.xlabel('Time')
    pylab.ylabel('Numerical level')

    pylab.subplot(3, 2, graph_loc + 2)
    plotSpec(data, rate)

    pylab.subplot(3, 2, graph_loc + 4)
    if rate == 16000:
        frq = 16
    else:
        frq = 8
    pylab.specgram(data, NFFT=128, noverlap=0, Fs=frq)
    pylab.tick_params(axis='x', labelsize=8)
    pylab.tick_params(axis='y', labelsize=8)
    pylab.xticks(rotation=45)
    pylab.xlabel('Time')
    pylab.ylabel('Frequency')

    if graph_loc == 2:
        name = in_file.split("/")
        lnth = len(name)
        name = in_file.split("/")[lnth - 1].split(".")[0]
        print "File=", name
        if func_type == 'a':
            save_file = output_folder + 'RESULT_' + name + '.png'
        else:
            save_file = output_folder + 'RESULT_graph.png'
        pylab.savefig(save_file)
        pylab.gcf()
        pylab.gca()
        pylab.close('all')
        del in_file, graph_loc, output_folder, count, t, rate, data, dlen, timp
        gc.get_referrers()
        gc.collect()

def result_plot(orig_file, rec_file, output_folder, seq):
    """Plot *orig_file* (left column) and *rec_file* (right column), then save.

    The second graph_plot call (graph_loc=2) is the one that writes the PNG.
    *seq* is forwarded as graph_plot's ``count`` argument, which graph_plot
    does not use.
    """
    graph_plot(orig_file, 1, output_folder, seq, 'a')
    graph_plot(rec_file, 2, output_folder, seq, 'a')
    # No sys.exit() here: exiting would skip whatever the caller runs next
    # (e.g. the closing "End" banner and pylab.close).


save_file="~/Documents/Output/"
o_file='~/Documents/audio/orig_8sec.wav'
#o_file='~/Documents/audio/orig_4sec.wav'
r_file='~/Documents/audio/rec_8sec.wav'
#r_file='~/Documents/audio/rec_4sec.wav'
print 10*"#"+"Start"+10*"#"
result_plot(o_file, r_file,save_file, 'a')
print 10*"#"+"End"+10*"#"
pylab.close('all')

With the above code, I see that the y-axis scales of the subplots differ:

8sec

Each subplot clearly uses an automatically assigned scale. As a result, any amplification or attenuation relative to the original file is hard to see unless the reader looks up the axis values.

Since I cannot predict which file will have the larger maximum amplitude when I use multiple samples, how can I make the two y-axes in each row use the same range (the larger of the two), so that the scales match and any amplification is obvious?

Upvotes: 0

Views: 565

Answers (3)

skrowten_hermit
skrowten_hermit

Reputation: 445

Taking cues from the other answers, I got it working as follows:

import matplotlib.pyplot as pl
import time
from scipy import fft, arange
from numpy import linspace
from scipy.io.wavfile import read
import gc
import sys



def plotWavAmplLev(in_file, sub_graph):
    """Plot the samples of a .wav file against time in seconds on *sub_graph*.

    :param in_file: path of the .wav file. Its last '/'-separated component
        becomes the subplot title.
    :param sub_graph: matplotlib Axes to draw on.
    """
    print("Printing Signal graph (amplitude vs seconds)....")
    rate, data = read(in_file)
    dlen = len(data)
    # float() stops Python 2 from truncating the duration to whole seconds.
    timp = dlen / float(rate)
    # Sample k is taken at k / rate seconds, so exclude the end point.
    t = linspace(0, timp, dlen, endpoint=False)

    sub_graph.plot(t, data)

    sub_graph.set_title(in_file.split('/')[-1])
    sub_graph.tick_params(axis='x', labelsize=10)
    sub_graph.tick_params(axis='y', labelsize=10)
    sub_graph.set_xlabel('Time')
    sub_graph.set_ylabel('Numerical level')


def plotSpectralDensity(y, fs, sub_graph):
    """Plot the one-sided, length-normalised magnitude spectrum of *y*.

    :param y: signal samples. A 2-D (samples x channels) array, as returned
        by scipy.io.wavfile.read for multi-channel files, is transformed
        per channel.
    :param fs: sampling rate in Hz.
    :param sub_graph: matplotlib Axes to draw on.
    """
    # Local import: in newer SciPy `scipy.fft` is a module (not a callable)
    # and NumPy aliases such as `scipy.arange` are deprecated, so use NumPy.
    import numpy as np

    # NOTE(review): the message says dB, but the plotted values are linear
    # magnitudes |FFT| / n, not decibels.
    print("Printing Power Spectral Density (dB vs Hz)....")
    n = len(y)  # signal length in samples
    half = n // 2  # keep only the non-negative frequency bins
    # Bin k lies at k * fs / n Hz. float() keeps Python 2 from truncating
    # n / fs, which previously produced a wrong frequency axis.
    frq = np.arange(half) * (float(fs) / n)
    # Transform along the sample axis and normalise by the signal length.
    ff_valu = np.fft.fft(y, axis=0)[:half] / n
    sub_graph.plot(frq, np.abs(ff_valu), 'r')  # plotting the spectrum
    sub_graph.tick_params(axis='x', labelsize=10)
    sub_graph.tick_params(axis='y', labelsize=10)
    sub_graph.set_xlabel('Frequency')
    sub_graph.set_ylabel('Power')


def plotSpectrogram(rate, data, sub_graph):
    """Draw a spectrogram of *data* on *sub_graph* with a kHz frequency axis.

    :param rate: sampling rate in Hz, as returned by scipy.io.wavfile.read.
    :param data: audio samples.
    :param sub_graph: matplotlib Axes to draw on.
    """
    print("Plotting Spectrogram (kHz vs seconds)....")
    # Passing the rate in kHz as Fs gives a kHz frequency axis. Previously
    # every rate other than 16000 was drawn as 8 kHz, which mislabelled
    # e.g. 22.05 or 44.1 kHz recordings.
    # NOTE(review): with Fs in kHz, specgram's time axis is in milliseconds,
    # not the seconds the message mentions.
    sub_graph.specgram(data, NFFT=128, noverlap=0, Fs=rate / 1000.0)
    sub_graph.tick_params(axis='x', labelsize=10)
    sub_graph.tick_params(axis='y', labelsize=10)
    sub_graph.set_xlabel('Time')
    sub_graph.set_ylabel('Frequency')


def graph_plot(in_file_list, output_folder, func_type):
    """Render waveform, PSD and spectrogram of two .wav files in a 3x2 grid and save it.

    Column 0 shows in_file_list[0] (source) and column 1 shows
    in_file_list[1] (recording). Each row shares its x and y axes, so the
    two columns are directly comparable.

    :param in_file_list: sequence whose first two items are the source and
        recorded file paths.
    :param output_folder: prefix prepended verbatim to the PNG file name,
        so it should end with a path separator.
    :param func_type: 'a' saves RESULT_<recorded file stem>.png; any other
        value writes RESULT_graph.png.
    """
    labels = ("Source", "Recorded")
    paths = (in_file_list[0], in_file_list[1])
    # Read each file once instead of once per subplot.
    signals = [read(path) for path in paths]  # (rate, data) per file

    # 3 rows (waveform, PSD, spectrogram) x 2 columns (source, recorded).
    fig, axes = pl.subplots(3, 2, figsize=(20, 15), sharex="row", sharey="row")

    # Fill the grid row by row, as before.
    for ax, label, path in zip(axes[0], labels, paths):
        print(label + " file waveform is being plotted....")
        plotWavAmplLev(path, ax)
    for ax, label, (rate, data) in zip(axes[1], labels, signals):
        print(label + " file PSD is being plotted....")
        plotSpectralDensity(data, rate, ax)
    for ax, label, (rate, data) in zip(axes[2], labels, signals):
        print(label + " file Spectrogram is being plotted....")
        plotSpectrogram(rate, data, ax)
    pl.tight_layout()

    # Name the PNG after the recorded file, dropping only its final extension
    # (so rec.v1.wav and rec.v2.wav no longer both become RESULT_rec.png).
    name = paths[1].split("/")[-1].rsplit(".", 1)[0]
    print("File= " + name)
    if func_type == 'a':
        save_file = output_folder + 'RESULT_' + name + '.png'
    else:
        save_file = output_folder + 'RESULT_graph.png'
    fig.savefig(save_file)
    pl.close('all')


def result_plot(orig_file, rec_file, output_folder, seq):
    """Save the side-by-side comparison of *orig_file* and *rec_file*.

    *seq* is accepted for call compatibility but is not used. The output is
    always named after the recorded file (func_type 'a').
    """
    graph_plot([orig_file, rec_file], output_folder, 'a')


# Output folder prefix: graph_plot appends the PNG name to it by string
# concatenation, so keep the trailing '/'.
s_file = "/<path>/Output/"
# Swap in the short clips for a quicker run:
#   o_file = '/<path>/short_orig.wav'
#   r_file = '/<path>/short_rec.wav'
o_file = '/<path>/orig.wav'
r_file = '/<path>/rec.wav'

print("#" * 10 + "Start" + "#" * 10)
result_plot(o_file, r_file, s_file, 'a')
print("#" * 10 + "End" + "#" * 10)
pl.close('all')

Now the y-axis scales are fixed, and the output looks like this:

enter image description here

This makes comparison a lot easier now.

Upvotes: 0

DavidG
DavidG

Reputation: 25362

An alternative to setting the limits yourself is to create the figure and axes first using

fig, axes = plt.subplots(3, 2)  # 3 rows x 2 columns; axes is a 3x2 array of Axes

This accepts the optional arguments sharex and sharey. From the docs:

sharex, sharey : bool or {'none', 'all', 'row', 'col'}, default: False

Controls sharing of properties among x (sharex) or y (sharey) axes:

        True or 'all': x- or y-axis will be shared among all subplots.
        False or 'none': each subplot x- or y-axis will be independent.
        'row': each subplot row will share an x- or y-axis.
        'col': each subplot column will share an x- or y-axis.

Therefore, we can make the subplots in each row share the same x-axis by passing sharex="row":

fig, axes = plt.subplots(3, 2, sharex="row")  # subplots in the same row share their x-axis

If you want the y-axis to be shared, use sharey="row" instead of, or as well as, sharex="row".

Upvotes: 1

Sheldore
Sheldore

Reputation: 39052

I am posting the explanation you asked for in the comments as an answer. The idea is to selectively modify the x-axis limits of particular subplots:

# 2 rows x 3 columns of Axes, 16x8 inches overall.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(16, 8))

# One period of a sine wave, drawn in every subplot.
x = np.linspace(0, 2 * np.pi, num=100)
y = np.sin(x)

for i, row in enumerate(axes):
    for j, col in enumerate(row):
        col.plot(x, y)
        col.set_title("Title here", fontsize=18)
        # Narrow the x-range only for the 2nd and 3rd subplots of the second row.
        if i == 1 and j in (1, 2):
            col.set_xlim(0, np.pi)
plt.tight_layout()

Output

enter image description here

Upvotes: 1

Related Questions