erikd71

Reputation: 279

deeplearning4j - using an RNN/LSTM for audio signal processing

I'm trying to train a RNN for digital (audio) signal processing using deeplearning4j. The idea is to have 2 .wav files: one is an audio recording, the second is the same audio recording but processed (for example with a low-pass filter). The RNN's input is the 1st (unprocessed) audio recording, the output is the 2nd (processed) audio recording.

I've used the GravesLSTMCharModellingExample from the dl4j examples, and mostly adapted the CharacterIterator class to accept audio data instead of text.

My first project working with audio in dl4j was basically to do the same thing as GravesLSTMCharModellingExample, but generating audio instead of text, using 11025 Hz 8-bit mono audio. That works (to some quite amusing results), so the basics of working with audio in this context seem to be in place.
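For reference, here is roughly how the wav samples could be turned into the 0-255 indices used below (a minimal sketch assuming unsigned 8-bit PCM and the standard javax.sound.sampled API; the helper name readSampleIndices is just for illustration and is not part of the dl4j example):

import java.io.ByteArrayOutputStream;
import java.io.File;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;

// Sketch only: read an 8-bit mono wav file into 0-255 sample indices.
// Assumes unsigned 8-bit PCM; adjust the byte mapping if your files are signed.
static int[] readSampleIndices(File wavFile) throws Exception {
    try (AudioInputStream in = AudioSystem.getAudioInputStream(wavFile)) {
        ByteArrayOutputStream out = new ByteArrayOutputStream();
        byte[] buf = new byte[4096];
        int n;
        while ((n = in.read(buf)) > 0)
            out.write(buf, 0, n);
        byte[] bytes = out.toByteArray();
        int[] indices = new int[bytes.length];
        for (int i = 0; i < bytes.length; i++)
            indices[i] = bytes[i] & 0xFF; // map each byte to 0-255
        return indices;
    }
}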

So step 2 was to adapt this for audio processing instead of audio generation.

Unfortunately, I'm not having much success. The best it seems able to do is output a very noisy version of the input.

As a 'sanity check', I've tested using the same audio file for both the input and the output, which I expected to converge quickly to a model that simply copies the input. But it doesn't. Even after a long time of training, all it seemed able to do was produce a noisier version of the input.
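To make that sanity check concrete, I compared the features and labels coming out of the iterator when the same file is fed in as both input and ideal output (iterator here stands for my adapted CharacterIterator-style class):

// With identical input/ideal files, the one-hot features and labels should match exactly.
DataSet ds = iterator.next(32);
System.out.println("features == labels ? " + ds.getFeatures().equals(ds.getLabels()));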

The most relevant piece of code, I guess, is the DataSetIterator.next() method (adapted from the example's CharacterIterator class), which now looks like this:

public DataSet next(int num) {
    if (exampleStartOffsets.size() == 0)
        throw new NoSuchElementException();

    int currMinibatchSize = Math.min(num, exampleStartOffsets.size());
    // Allocate space:
    // Note the order here:
    // dimension 0 = number of examples in minibatch
    // dimension 1 = size of each vector (i.e., number of possible sample values; 256 for 8-bit audio)
    // dimension 2 = length of each time series/example
    // Why 'f' order here? See http://deeplearning4j.org/usingrnns.html#data
    // section "Alternative: Implementing a custom DataSetIterator"
    INDArray input = Nd4j.create(new int[] { currMinibatchSize, columns, exampleLength }, 'f');
    INDArray labels = Nd4j.create(new int[] { currMinibatchSize, columns, exampleLength }, 'f');

    for (int i = 0; i < currMinibatchSize; i++) {
        int startIdx = exampleStartOffsets.removeFirst();
        int endIdx = startIdx + exampleLength;

        for (int j = startIdx, c = 0; j < endIdx; j++, c++) {
            // inputIndices/idealIndices are audio samples converted to indices.
            // With 8-bit audio, this translates to values between 0-255.
            input.putScalar(new int[] { i, inputIndices[j], c }, 1.0);
            labels.putScalar(new int[] { i, idealIndices[j], c }, 1.0);
        }
    }

    return new DataSet(input, labels);
}

So maybe I have a fundamental misunderstanding of what LSTMs are supposed to do. Is there anything obviously wrong in the posted code that I'm missing? Is there an obvious reason why training on the same file doesn't converge quickly to a model that simply copies the input (let alone training it to do signal processing that actually changes something)?

I've seen "Using RNN to recover sine wave from noisy signal", which seems to be about a similar problem (but using a different ML framework), but that question didn't get an answer.

Any feedback is appreciated!

Upvotes: 26

Views: 1859

Answers (3)

Saksham Gupta

Reputation: 116

The most common issue with problems like this is the training data.

  1. Make sure you have enough training data available. If you don't, you can use a library like audiomentations to augment your training set.
  2. Diversity of training data: the more perturbations you can add to your training set, the better (see the sketch after this list).
  3. Hyperparameter optimization: neural networks in general require a lot of parameter tuning to perform above average (see: Parameter optimization in deeplearning4j).
  4. This is a suggestion based on past experience and may be out of scope, but an autoencoder architecture usually does wonders for these processing use cases (audio, images, etc.).
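For point 2, a minimal sketch of one such perturbation in plain Java, adding uniform noise to 0-255 sample indices (the helper name and noise amplitude are purely illustrative, not from any particular library):

import java.util.Random;

// Sketch only: perturb 0-255 audio sample indices with uniform noise of +/- amplitude.
static int[] addNoise(int[] samples, int amplitude, long seed) {
    Random rnd = new Random(seed);
    int[] noisy = new int[samples.length];
    for (int i = 0; i < samples.length; i++) {
        int v = samples[i] + rnd.nextInt(2 * amplitude + 1) - amplitude;
        noisy[i] = Math.max(0, Math.min(255, v)); // clamp to the valid 8-bit range
    }
    return noisy;
}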

Upvotes: 1

Borislav Markov

Reputation: 1745

If you hear a distorted version of the input, you are on the right track.

The problem might be that the free parameters of your network cannot generalize well from a small number of examples. Make sure you have more samples, at least 50,000 that do not overlap each other (i.e. not all from the same wav file), and try playing with the network parameters: for example, reduce the number of nodes in each layer by 10-15% and try a lower learning rate. A configuration sketch follows below.
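To illustrate where those knobs live, a network configuration along the lines of the GravesLSTMCharModellingExample might look like this (a sketch only; builder method names differ between dl4j versions, and the layer size and learning rate here are placeholders showing what to tune, not recommended values):

import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.lossfunctions.LossFunctions;

MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
    .learningRate(0.01)                         // try a lower learning rate
    .list()
    .layer(0, new GravesLSTM.Builder()
        .nIn(256).nOut(180)                     // try reducing nOut by 10-15%
        .activation(Activation.TANH).build())
    .layer(1, new RnnOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
        .activation(Activation.SOFTMAX)
        .nIn(180).nOut(256).build())
    .build();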

Upvotes: 1

André Abboud

Reputation: 2040

Hello, I think that in the dataset logic you should try to use a long type instead of an integer:

public DataSet next(int num)

replace with

public DataSet next(long num)

Upvotes: -1
