Mark Giaconia
Mark Giaconia

Reputation: 3953

Deeplearning4J RNN Training : Exception 3D input expected to RNN layer expected, got 2

with the following code (tweaked for hours with different params), I keep getting an exception java.lang.IllegalStateException: 3D input expected to RNN layer expected, got 2

What I am trying to accomplish is to train a RNN to predict the next value (double) in a sequence based on a bunch of training sequences. I am generating the features with a simple random data generator, and using the last val in a sequence as the training label (in this case predicted value).

my code:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

import java.util.Random;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * Question version of the example (kept as posted — it reproduces the bug).
 *
 * BUG: DL4J RNN layers require rank-3 input of shape
 * [miniBatchSize, featureCount, timeStepCount], but this code builds rank-2
 * (numRows x features) feature/label arrays, which triggers
 * "java.lang.IllegalStateException: 3D input expected to RNN layer expected, got 2".
 */
public class RnnPredictionExample {

  public static void main(String[] args) {
    //generate 100 rows of data that have 50 columns/features each
    // (51 values per row: the first 50 are features, the 51st is the label)
    DataSet trainingdata = getRandomDataset(100, 51, 1);
    // Train the RNN model...
    MultiLayerNetwork trainedModel = trainRnnModel(trainingdata, 50, 10, 1);

    // generate a sequence, and Perform next value prediction on the sequence
    double[] inputSequence = randomData(50, 1);
    double predictedValue = predictNextValue(trainedModel, inputSequence);
    System.out.println("Predicted Next Value: " + predictedValue);
  }

  /**
   * Builds and trains a two-layer LSTM + RnnOutputLayer network.
   *
   * NOTE(review): sequenceLength and numHiddenUnits are never used — the
   * layer sizes below are hard-coded (nIn(1)/nOut(50) and nIn(50)/nOut(1)).
   */
  public static MultiLayerNetwork trainRnnModel(DataSet trainingdataandlabels, int sequenceLength, int numHiddenUnits, int numEpochs) {
    // ... Create network configuration ...

    // Create and initialize the network
    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            //.seed(123)
            .list()
            .layer(new LSTM.Builder()
                    .nIn(1)
                    .nOut(50)
                    .build()
            )
            .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .activation(Activation.IDENTITY)
                    .nIn(50)
                    .nOut(1) // Set nOut to 1
                    .build()
            )
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(config);
    net.init();

    // One fit() call per epoch over the full (in-memory) DataSet.
    for (int i = 0; i < numEpochs; i++) {
      net.fit(trainingdataandlabels);
    }

    return net;
  }

  /**
   * Feeds one sequence through the network and returns the last output value.
   *
   * BUG: Nd4j.create(double[]) does not produce the rank-3
   * [miniBatch, features, timeSteps] array that rnnTimeStep() expects.
   */
  public static double predictNextValue(MultiLayerNetwork trainedModel, double[] inputSequence) {
    INDArray inputArray = Nd4j.create(inputSequence);
    INDArray predicted = trainedModel.rnnTimeStep(inputArray);

    // Predicted value is the last element of the predicted sequence
    return predicted.getDouble(predicted.length() - 1);
  }

  // Shared RNG for the synthetic data generators below.
  static Random random = new Random();

  /** Returns {@code length} uniform random doubles scaled by {@code rangeMultiplier}. */
  public static double[] randomData(int length, int rangeMultiplier) {

    double[] out = new double[length];
    for (int i = 0; i < out.length; i++) {
      out[i] = random.nextDouble() * rangeMultiplier;
    }
    return out;
  }

  // assumes the label is the last val in each sequence
  // BUG: features (numRows x lengthEach-1) and labels (numRows x 1) are
  // rank-2; an RNN layer needs a rank-3 time-series shape (see class note).
  public static DataSet getRandomDataset(int numRows, int lengthEach, int rangeMultiplier) {
    INDArray training = Nd4j.zeros(numRows, lengthEach - 1);
    INDArray labels = Nd4j.zeros(numRows, 1);

    for (int i = 0; i < numRows; i++) {
      double[] randomData = randomData(lengthEach, rangeMultiplier);
      for (int j = 0; j < randomData.length - 1; j++) {
        training.putScalar(new int[]{i, j}, randomData[j]);
      }
      // Last generated value becomes this row's label.
      labels.putScalar(new int[]{i, 0}, randomData[randomData.length - 1]);

    }

    return new DataSet(training, labels);

  }
}

thanks

For those interested, I made the changes based on the accepted answer; here is the working code again in its entirety:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;

import java.util.Random;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.LSTM;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.lossfunctions.LossFunctions;

/**
 * Working version of the RNN next-value-prediction example.
 *
 * <p>DL4J RNN layers expect rank-3 input of shape
 * [miniBatchSize, featureCount, timeStepCount], so the training features,
 * the labels, and the inference input are all built as 3-D arrays here.
 */
public class RnnPredictionExample {

  /** Shared RNG for the synthetic data generators. */
  static Random random = new Random();

  public static void main(String[] args) {
    // Generate 100 rows of 51 random values each: the first 50 are the
    // features, the 51st is the label (the "next value" to predict).
    DataSet trainingdata = getRandomDataset(100, 51, 1);

    // Train the RNN model.
    MultiLayerNetwork trainedModel = trainRnnModel(trainingdata, 50, 10, 1);

    // Generate a fresh sequence and predict its next value.
    double[] inputSequence = randomData(50, 1);
    double predictedValue = predictNextValue(trainedModel, inputSequence);
    System.out.println("Predicted Next Value: " + predictedValue);
  }

  /**
   * Builds and trains a small LSTM + RnnOutputLayer regression network.
   *
   * @param trainingdataandlabels features of shape [n, sequenceLength, 1]
   *     and labels of shape [n, 1, 1]
   * @param sequenceLength number of input features per example; becomes the
   *     LSTM's nIn
   * @param numHiddenUnits currently unused — the hidden size is fixed at 1
   *     below (NOTE(review): consider wiring it into the LSTM's nOut)
   * @param numEpochs number of full passes over the training data
   * @return the trained network
   */
  public static MultiLayerNetwork trainRnnModel(DataSet trainingdataandlabels, int sequenceLength, int numHiddenUnits, int numEpochs) {
    // Create the network configuration.
    MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
            //.seed(123) // uncomment for reproducible weight initialization
            .list()
            .layer(new LSTM.Builder()
                    .nIn(sequenceLength) // one input unit per sequence position
                    .nOut(1)
                    .build()
            )
            .layer(new RnnOutputLayer.Builder(LossFunctions.LossFunction.MSE)
                    .activation(Activation.IDENTITY) // linear output for regression
                    .nIn(1)  // must match the LSTM's nOut
                    .nOut(1) // single predicted value
                    .build()
            )
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(config);
    net.init();

    // One fit() call per epoch over the full (in-memory) DataSet.
    for (int i = 0; i < numEpochs; i++) {
      net.fit(trainingdataandlabels);
    }

    return net;
  }

  /**
   * Feeds one sequence through the network and returns the predicted value.
   *
   * @param trainedModel a network produced by {@link #trainRnnModel}
   * @param inputSequence the raw sequence values
   * @return the last element of the network's output
   */
  public static double predictNextValue(MultiLayerNetwork trainedModel, double[] inputSequence) {
    // Reshape the flat vector to [miniBatch=1, features, timeSteps=1], the
    // rank-3 layout DL4J RNN layers require.
    INDArray inputArray = Nd4j.create(inputSequence).reshape(1, inputSequence.length, 1);
    INDArray predicted = trainedModel.rnnTimeStep(inputArray);

    // The predicted value is the last element of the returned sequence.
    return predicted.getDouble(predicted.length() - 1);
  }

  /** Returns {@code length} uniform random doubles scaled by {@code rangeMultiplier}. */
  public static double[] randomData(int length, int rangeMultiplier) {
    double[] out = new double[length];
    for (int i = 0; i < out.length; i++) {
      out[i] = random.nextDouble() * rangeMultiplier;
    }
    return out;
  }

  /**
   * Builds a random DataSet in which each row's label is the last value of
   * its generated sequence.
   *
   * @param numRows number of examples
   * @param lengthEach values generated per row (lengthEach - 1 features + 1 label)
   * @param rangeMultiplier scale applied to each random value
   * @return features [numRows, lengthEach - 1, 1] and labels [numRows, 1, 1]
   */
  public static DataSet getRandomDataset(int numRows, int lengthEach, int rangeMultiplier) {
    // Rank-3 shapes: [miniBatch, features, timeSteps].
    INDArray training = Nd4j.zeros(numRows, lengthEach - 1, 1);
    INDArray labels = Nd4j.zeros(numRows, 1, 1);

    for (int i = 0; i < numRows; i++) {
      double[] randomData = randomData(lengthEach, rangeMultiplier);
      for (int j = 0; j < randomData.length - 1; j++) {
        training.putScalar(new int[]{i, j, 0}, randomData[j]);
      }
      // Last generated value becomes this row's label.
      labels.putScalar(new int[]{i, 0, 0}, randomData[randomData.length - 1]);
    }

    return new DataSet(training, labels);
  }
}

Upvotes: 0

Views: 273

Answers (1)

Paul Dubs
Paul Dubs

Reputation: 807

In DL4J you always have a batch of data in a DataSet object. That means, if your training data has the shape (n, f), it will be interpreted as n examples with f features per example.

An RNN expects several steps per example, that means your data needs to have the shape (n, f, t) so you have n examples, f features and t steps.

I suppose you are going for a batch size of n=1 examples. The easiest solution to your predicament is therefore to call Nd4j.expandDims(arr, 0) to give it that additional singleton dimension.

Upvotes: 1

Related Questions