Nathan Evans

Spectrogram generation in Java using FFT on a .wav file not producing expected output

So I am making an AI project that classifies speech as "up", "down", "left", "right" or background noise, and uses the result to move a character in a video game.

I have written an FFT algorithm, deriving it from the mathematical definition, and I believe it is correct because I have tested its output against this online DFT calculator (https://engineering.icalculator.info/discrete-fourier-transform-calculator.html).
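The same check can also be automated by comparing the FFT against a direct O(N²) evaluation of the DFT definition, X[k] = Σ x[n]·e^(−2πikn/N). Here is a minimal, self-contained sketch. It uses plain real/imaginary `double` arrays instead of the `Complex` class, it mirrors the recursive even/odd split of the `fft` method in the question, and its class and method names (`DftCheck`, `dft`, `maxError`) are hypothetical.

```java
import java.util.Random;

public class DftCheck {
    // Direct O(N^2) DFT: X[k] = sum_n x[n] * exp(-2*pi*i*k*n/N)
    static double[][] dft(double[] re, double[] im) {
        int n = re.length;
        double[] outRe = new double[n], outIm = new double[n];
        for (int k = 0; k < n; k++) {
            for (int t = 0; t < n; t++) {
                double angle = -2 * Math.PI * k * t / n;
                double c = Math.cos(angle), s = Math.sin(angle);
                outRe[k] += re[t] * c - im[t] * s;
                outIm[k] += re[t] * s + im[t] * c;
            }
        }
        return new double[][] {outRe, outIm};
    }

    // Recursive radix-2 FFT with the same even/odd split as the question; n must be a power of 2
    static double[][] fft(double[] re, double[] im) {
        int n = re.length;
        if (n == 1) return new double[][] {re.clone(), im.clone()};
        int m = n / 2;
        double[] evRe = new double[m], evIm = new double[m], odRe = new double[m], odIm = new double[m];
        for (int i = 0; i < m; i++) {
            evRe[i] = re[2 * i];     evIm[i] = im[2 * i];
            odRe[i] = re[2 * i + 1]; odIm[i] = im[2 * i + 1];
        }
        double[][] ev = fft(evRe, evIm), od = fft(odRe, odIm);
        double[] outRe = new double[n], outIm = new double[n];
        for (int k = 0; k < m; k++) {
            double angle = -2 * Math.PI * k / n;
            double c = Math.cos(angle), s = Math.sin(angle);
            // twiddle factor e^(-2*pi*i*k/N) multiplied by the odd-half result
            double tRe = c * od[0][k] - s * od[1][k];
            double tIm = c * od[1][k] + s * od[0][k];
            outRe[k] = ev[0][k] + tRe;     outIm[k] = ev[1][k] + tIm;
            outRe[k + m] = ev[0][k] - tRe; outIm[k + m] = ev[1][k] - tIm;
        }
        return new double[][] {outRe, outIm};
    }

    // Largest absolute difference between the FFT and the direct DFT for a random real signal
    static double maxError(int n, long seed) {
        Random rng = new Random(seed);
        double[] re = new double[n], im = new double[n];
        for (int i = 0; i < n; i++) re[i] = rng.nextDouble() * 2 - 1;
        double[][] a = fft(re, im), b = dft(re, im);
        double err = 0;
        for (int k = 0; k < n; k++) {
            err = Math.max(err, Math.abs(a[0][k] - b[0][k]));
            err = Math.max(err, Math.abs(a[1][k] - b[1][k]));
        }
        return err;
    }

    public static void main(String[] args) {
        System.out.println("max |FFT - DFT| for N = 256: " + maxError(256, 42));
    }
}
```

A difference on the order of floating-point rounding, rather than something comparable to the signal values, suggests the FFT agrees with the definition for that size.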

I then tried to generate a spectrogram, using code based on the main function of the App class from this question (Creating spectrogram from .wav using FFT in java).

I tested my code on a .wav file of me saying "hello", and the generated spectrogram is not what I expected. See below the difference between my Java-made spectrogram and my Python-made spectrogram (ignore the colour difference).

Java Spectrogram

Python Spectrogram

New Java Spectrogram with SleuthEye's help

Here is the original code I have used/written:

package STACKOVERFLOW;

import com.company.Complex;

import javax.imageio.ImageIO;
import java.awt.*;
import java.awt.image.BufferedImage;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Scanner;

public class StackOverFlow {
    private static Color getColour(double power) {
        var H = power * 0.4;
        var S = 1.0;
        var B = 1.0;
        return Color.getHSBColor((float) H, (float) S, (float) B);
    }

    private static double[] getAudioData(String filePath) {
        var path = Paths.get(filePath);
        try {
            var entireFileData = Files.readAllBytes(path);
            var rawData = Arrays.copyOfRange(entireFileData, 44, entireFileData.length);
            var length = rawData.length;

            var newLength = length / 4;
            var dataMono = new double[newLength];

            double left, right;
            for (int i = 0; 2 * i + 3< newLength; i++) {
                left = (short) ((rawData[2 * i + 1] & 0xff) << 8) | (rawData[2 * i] & 0xff);
                right = (short) ((rawData[2 * i + 3] & 0xff) << 8) | (rawData[2 * i + 2] & 0xff);
                dataMono[i] = (left + right) / 2.0;
            }

            return dataMono;
        } catch (IOException e) {
            e.printStackTrace();
        }
        return null;
    }

    private static Complex[] toComplex(double[] samples) {
        var l = samples.length;
        var cOut = new Complex[l];
        for (int i = 0; i < l; i++) {
            cOut[i] = new Complex(samples[i], 0);
        }
        return cOut;
    }

    private static double modulusSquared(Complex a) {
        var real = a.getReal();
        var imaginary = a.getImag();
        return (real * real) + (imaginary * imaginary);
    }

    private static Complex[] fft(Complex[] samples) {
        var N = samples.length; // number of samples
        if (N == 1) return samples; // stops the recursive splits on the samples
        // TODO: M only works for N a power of 2
        var M = N / 2; // middle index of the samples
        var Xeven = new Complex[M]; // array for even split
        var Xodd = new Complex[M]; // array for odd split

        // splits the samples
        for (int i = 0; i < M; i++) {
            Xeven[i] = samples[2 * i];
            Xodd[i] = samples[2 * i + 1];
        }

        // recursive calls on even and odd samples
        var Feven = new Complex[M];
        Feven = fft(Xeven);
        var Fodd = new Complex[M];
        Fodd = fft(Xodd);

        var frequencyBins = new Complex[N];

        for (int i = 0; i < (N / 2); i++) {
            var cExponential = Complex.multiply(
                    Complex.polar(1, -2 * Math.PI * i / N),
                    Fodd[i]
            );

            frequencyBins[i] = Complex.add(
                    Feven[i],
                    cExponential
            );

            frequencyBins[i + N / 2] = Complex.sub(
                    Feven[i],
                    cExponential
            );
        }
        return frequencyBins;
    }

    public static void makeSpectrogram() {
        var scan = new Scanner(System.in);
        System.out.println("Enter file path: ");
        var filePath = scan.nextLine();
        var rawAudioData = getAudioData(filePath);
        assert rawAudioData != null;
        var length = rawAudioData.length;
        var complexAudioData = toComplex(rawAudioData);

        // parameters for FFT
        var windowSize = 256;
        var overlapFactor = 2;
        var windowStep = windowSize / overlapFactor;

        // plotData array
        var nX = (length - windowSize) / windowStep;
        var nY = (windowSize / 2);
        var plotData = new double[nX][nY];

        // amplitudes to normalise
        var maxAmplitude = Double.MIN_VALUE;
        var minAmplitude = Double.MAX_VALUE;
        double amplitudeSquared;

        // application of the FFT
        for (int i = 0; i < nX; i++) {
            var windowSizeArray = fft(Arrays.copyOfRange(complexAudioData, i * windowStep, i * windowStep + windowSize));
            for (int j = 0; j < nY; j++) {
                amplitudeSquared = modulusSquared(windowSizeArray[2 * j]);
                if (amplitudeSquared == 0.0) {
                    plotData[i][nY - j - 1] = amplitudeSquared;
                } else {
                    var threshold = 1.0; // prevents log(0)
                    plotData[i][nY - j - 1] = 10 * Math.log10(Math.max(amplitudeSquared, threshold));
                }

                // find min and max amplitudes
                if (plotData[i][j] > maxAmplitude) {
                    maxAmplitude = plotData[i][j];
                } else if (plotData[i][j] < minAmplitude) {
                    minAmplitude = plotData[i][j];
                }
            }
        }

        // normalisation
        var difference = maxAmplitude - minAmplitude;
        for (int i = 0; i < nX; i++) {
            for (int j = 0; j < nY; j++) {
                plotData[i][j] = (plotData[i][j] - minAmplitude) / difference;
            }
        }

        // plot the spectrogram
        var spectrogram = new BufferedImage(nX, nY, BufferedImage.TYPE_INT_RGB);
        double ratio;
        for (int i = 0; i < nX; i++) {
            for (int j = 0; j < nY; j++) {
                ratio = plotData[i][j];
                var colour = getColour(1.0 - ratio);
                spectrogram.setRGB(i, j, colour.getRGB());
            }
        }

        // write the image to a file
        try {
            var outputFile = new File("saved.png");
            ImageIO.write(spectrogram, "png", outputFile);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    public static void main(String[] args) {
        makeSpectrogram();
    }
}

Here is the Complex class that is used above:

package com.company;

import java.text.DecimalFormat;

public class Complex {

    private final static DecimalFormat df2 = new DecimalFormat("#.##");

    private double r;
    private double i;

    public Complex(double r, double i) {
        this.r = r;
        this.i = i;
    }

    @Override
    public String toString() {
        return "(" + df2.format(this.r) + ", " + df2.format(this.i) + "i) ";
    }

    public double abs() {
        return Math.hypot(this.r, this.i);
    }

    public double getReal() {
        return this.r;
    }

    public double getImag() {
        return this.i;
    }

    public void setReal(double r) {
        this.r = r;
    }

    public void setImag(double i) {
        this.i = i;
    }

    public static Complex polar(double r, double theta) {
        return new Complex(
                r * Math.cos(theta),
                r * Math.sin(theta)
        );
    }

    public static Complex multiply(Complex a, Complex b) {
            /*
             (a + bi) * (c + di) =
             ac + adi + cbi + -bd =
             (ac - bd) + (ad + cb)i
            */
        var real = (a.r * b.r) - (a.i * b.i);
        var imag = (a.r * b.i) + (a.i * b.r);
        return new Complex(real, imag);
    }

    public static Complex add(Complex a, Complex b) {
        return new Complex(
                a.r + b.r,
                a.i + b.i
        );
    }

    public static Complex sub(Complex a, Complex b) {
        return new Complex(
                a.r - b.r,
                a.i - b.i
        );
    }
}

Any guidance would be appreciated.


Answers (1)

SleuthEye

Reading the .wav file

The .wav file decoding included in that other question you linked is hardly a full-blown decoder. It only accounts for that OP's specific use case: stereo, 2 bytes per sample.

It looks like you stumbled upon other decoding issues while trying to adapt it to a different use case. As a general piece of advice, I'd suggest using a more complete .wav decoder, one that takes into account the number of channels, the number of bytes per sample, etc.
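For Java specifically, the standard `javax.sound.sampled` package already ships such a decoder: `AudioSystem.getAudioInputStream` parses the file header, and the `AudioFormat` it reports gives the channel count, sample size and byte order. A minimal sketch, assuming signed 16-bit PCM input (other encodings would need conversion or extra cases); the class name `WavReader` and method `readMono` are hypothetical.

```java
import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import javax.sound.sampled.UnsupportedAudioFileException;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;

public class WavReader {
    // Decodes a signed 16-bit PCM .wav file into mono samples by averaging the channels of each frame
    static double[] readMono(File file) throws IOException, UnsupportedAudioFileException {
        try (AudioInputStream in = AudioSystem.getAudioInputStream(file)) {
            AudioFormat fmt = in.getFormat(); // channel count, sample size and byte order come from the header
            if (!AudioFormat.Encoding.PCM_SIGNED.equals(fmt.getEncoding()) || fmt.getSampleSizeInBits() != 16) {
                throw new UnsupportedAudioFileException("this sketch only handles signed 16-bit PCM: " + fmt);
            }
            int channels = fmt.getChannels();
            int frameSize = fmt.getFrameSize(); // bytes per frame, i.e. channels * 2 here

            // read whole frames only, since AudioInputStream.read works in units of frames
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buf = new byte[frameSize * 1024];
            int n;
            while ((n = in.read(buf)) > 0) {
                out.write(buf, 0, n);
            }
            byte[] bytes = out.toByteArray();

            int frames = bytes.length / frameSize;
            double[] mono = new double[frames];
            for (int f = 0; f < frames; f++) {
                double sum = 0;
                for (int c = 0; c < channels; c++) {
                    int off = f * frameSize + 2 * c;
                    int first = bytes[off] & 0xff, second = bytes[off + 1] & 0xff;
                    // use the byte order reported by the header instead of assuming little-endian
                    short s = fmt.isBigEndian() ? (short) ((first << 8) | second) : (short) ((second << 8) | first);
                    sum += s;
                }
                mono[f] = sum / channels;
            }
            return mono;
        }
    }
}
```

In the question's code, something like `WavReader.readMono(new File(filePath))` could then take the place of the hand-written `getAudioData`, with the exception handling adapted as needed.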

If on the other hand you want to craft your own decoder (for example as a learning exercise), then a slightly more robust implementation may look like the following:

private static short getShort(byte[] buffer, int offset) {
  // little-endian 16-bit value; the outer cast covers the whole OR so the result is a short
  return (short) (((buffer[offset + 1] & 0xff) << 8) | (buffer[offset] & 0xff));
}
private static int getNumberOfChannels(byte[] entireFileData) {
  // NumChannels field of the canonical 44-byte header
  return getShort(entireFileData, 22);
}
private static int getBytesPerSample(byte[] entireFileData) {
  // BitsPerSample field of the canonical header, converted to bytes
  return getShort(entireFileData, 34) / 8;
}

private static double[] getAudioData(String filePath) {

    ...
    var entireFileData = Files.readAllBytes(path);
    var rawData = Arrays.copyOfRange(entireFileData, 44, entireFileData.length);
    var length = rawData.length;

    int numChannels    = getNumberOfChannels(entireFileData);
    int bytesPerSample = getBytesPerSample(entireFileData);
    int newLength      = length / (bytesPerSample * numChannels);
    var dataMono       = new double[newLength];
    if (2 == bytesPerSample) {
      // frame i starts at byte 2*numChannels*i; average its numChannels samples into one mono value
      for (int i = 0; 2 * numChannels * (i + 1) - 1 < length; i++) {
        double sum = 0.0;
        for (int j = 0; j < numChannels; j++) {
          short sample = getShort(rawData, 2 * numChannels * i + 2 * j);
          sum += sample;
        }
        dataMono[i] = sum / numChannels;
      }
    }
    else {
    ... // handle different number of bytes per sample
    }
    return dataMono;
}

Note that it still only covers 16-bit PCM samples, assumes a fixed header structure (see this tutorial, but the .wav file format is actually more flexible), and would get tripped up by files with extension chunks.
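One way to avoid the fixed 44-byte assumption is to walk the RIFF chunk list. After the 12-byte preamble ("RIFF", size, "WAVE"), each chunk is a 4-byte ASCII id, a 4-byte little-endian size and the payload, padded to an even length. The sketch below locates the "fmt " and "data" chunks and skips anything else (such as a LIST chunk). The class `RiffChunks`, the holder `WavInfo` and the method `parse` are hypothetical, and the sketch does not check the audio format code.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

public class RiffChunks {
    // Fields from the "fmt " chunk plus the raw bytes of the "data" chunk (hypothetical holder)
    static final class WavInfo {
        int channels, sampleRate, bitsPerSample;
        byte[] data;
    }

    static WavInfo parse(byte[] file) {
        if (file.length < 12 || !"RIFF".equals(id(file, 0)) || !"WAVE".equals(id(file, 8))) {
            throw new IllegalArgumentException("not a RIFF/WAVE file");
        }
        ByteBuffer bb = ByteBuffer.wrap(file).order(ByteOrder.LITTLE_ENDIAN);
        WavInfo info = new WavInfo();
        int pos = 12; // the first chunk follows "RIFF", the file size and "WAVE"
        while (pos + 8 <= file.length) {
            String chunkId = id(file, pos);
            int size = bb.getInt(pos + 4);
            if (size < 0) break; // chunk larger than 2 GB or corrupt size field
            int body = pos + 8;
            if ("fmt ".equals(chunkId)) {
                // fmt layout: format(2) channels(2) sampleRate(4) byteRate(4) blockAlign(2) bitsPerSample(2)
                info.channels = bb.getShort(body + 2);
                info.sampleRate = bb.getInt(body + 4);
                info.bitsPerSample = bb.getShort(body + 14);
            } else if ("data".equals(chunkId)) {
                int end = (int) Math.min((long) body + size, file.length); // tolerate truncated files
                info.data = Arrays.copyOfRange(file, body, end);
            }
            // skip the payload; chunks are padded to an even number of bytes
            long next = (long) body + size + (size & 1);
            if (next > Integer.MAX_VALUE) break;
            pos = (int) next;
        }
        if (info.data == null) throw new IllegalArgumentException("no data chunk");
        return info;
    }

    private static String id(byte[] b, int off) {
        return new String(b, off, 4, StandardCharsets.US_ASCII);
    }
}
```

For a canonical file the fmt chunk body starts at byte 20, so the channel count and bits per sample land at bytes 22 and 34, the same offsets used above.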

Processing the spectrum

The FFT library used in that other question you linked returns an array of doubles, which is to be interpreted as interleaved real and imaginary parts of the actual complex values. Correspondingly, the magnitude computation uses pairs of elements at indices 2*j and 2*j+1. Your implementation, on the other hand, returns complex values directly, so you should not skip over values with the 2* factor and should instead use:

for (int j = 0; j < nY; j++) {
  amplitudeSquared = modulusSquared(windowSizeArray[j]);
  ...
}
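To illustrate why the 2* factor only belongs with the interleaved layout, this small sketch computes |X[j]|² both ways. One version reads an interleaved `double[]` (real part at 2*j, imaginary part at 2*j+1), as the library in the linked question returns. The other reads a per-bin array, here a hypothetical `double[][]` of {re, im} pairs standing in for the question's `Complex[]`. The class `BinPower` and its methods are hypothetical.

```java
public class BinPower {
    // Interleaved layout: bin j occupies elements 2*j (real) and 2*j + 1 (imaginary)
    static double powerInterleaved(double[] interleaved, int j) {
        double re = interleaved[2 * j], im = interleaved[2 * j + 1];
        return re * re + im * im;
    }

    // One element per bin (like Complex[]): bin j is simply element j
    static double powerPerBin(double[][] bins, int j) {
        double re = bins[j][0], im = bins[j][1];
        return re * re + im * im;
    }

    public static void main(String[] args) {
        double[][] bins = {{1, 0}, {3, 4}, {0, -2}, {5, 12}};
        double[] interleaved = new double[2 * bins.length];
        for (int j = 0; j < bins.length; j++) {
            interleaved[2 * j] = bins[j][0];
            interleaved[2 * j + 1] = bins[j][1];
        }
        for (int j = 0; j < bins.length; j++) {
            // both layouts give the same power per bin: 1.0, 25.0, 4.0, 169.0
            System.out.println(powerInterleaved(interleaved, j) + " " + powerPerBin(bins, j));
        }
        // indexing a per-bin array at 2*j for j < N/2 would read bins 0, 2, ..., N-2:
        // every other bin, running into the mirrored upper half of the spectrum
    }
}
```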
