Rudresh Panchal

Reputation: 1000

Google TensorFlow based seq2seq model crashes while training

I have been trying to use Google's RNN-based seq2seq model.

I am training a model for text summarization on roughly 1 GB of text data. The model quickly fills my entire RAM (8 GB), then starts filling the swap space (a further 8 GB), and finally crashes, after which I have to do a hard shutdown.

The configuration of my network is as follows:

model: AttentionSeq2Seq
model_params:
  attention.class: seq2seq.decoders.attention.AttentionLayerDot
  attention.params:
    num_units: 128
  bridge.class: seq2seq.models.bridges.ZeroBridge
  embedding.dim: 128
  encoder.class: seq2seq.encoders.BidirectionalRNNEncoder
  encoder.params:
    rnn_cell:
      cell_class: GRUCell
      cell_params:
        num_units: 128
      dropout_input_keep_prob: 0.8
      dropout_output_keep_prob: 1.0
      num_layers: 1
  decoder.class: seq2seq.decoders.AttentionDecoder
  decoder.params:
    rnn_cell:
      cell_class: GRUCell
      cell_params:
        num_units: 128
      dropout_input_keep_prob: 0.8
      dropout_output_keep_prob: 1.0
      num_layers: 1
  optimizer.name: Adam
  optimizer.params:
    epsilon: 0.0000008
  optimizer.learning_rate: 0.0001
  source.max_seq_len: 50
  source.reverse: false
  target.max_seq_len: 50

I tried decreasing the batch size from 32 to 16, but it did not help. What specific changes should I make to prevent my model from consuming all of my RAM and crashing (for example, reducing the data size, reducing the number of stacked RNN cells, or further decreasing the batch size)?

My system runs Python 2.7.x, TensorFlow 1.1.0, and CUDA 8.0. It has an Nvidia GeForce GTX 1050 Ti (768 CUDA cores) with 4 GB of memory, 8 GB of RAM, and a further 8 GB of swap.

Upvotes: 0

Views: 223

Answers (1)

Bo Shao

Reputation: 143

Your model looks pretty small. The only large thing is the training data. Please check that your get_batch() function has no bugs. It is possible that a bug there causes you to load the whole data set for every batch.

To quickly test this, cut your training data down to something very small (such as 1/10 of the current size) and see if that helps. Note that it should not help, because with mini-batches only one batch needs to be in memory at a time. But if it does resolve the problem, fix your get_batch() function.
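For illustration, here is a minimal sketch of a batch reader that never holds more than one batch in memory. It is not the asker's get_batch() and not part of the seq2seq library. The name iter_batches, its parameters, and the assumption that the data sits in two parallel text files with one example per line are all hypothetical. The optional max_pairs argument limits reading to the first N pairs, which is one way to run the reduced-data experiment without creating new files.

```python
import itertools

try:
    # Python 2: the built-in zip builds a full list, so use the lazy variant.
    from itertools import izip as lazy_zip
except ImportError:
    # Python 3: zip is already lazy.
    lazy_zip = zip


def iter_batches(source_path, target_path, batch_size, max_pairs=None):
    """Yield lists of (source, target) line pairs, reading lazily from disk.

    Hypothetical helper: assumes two parallel files, one example per line.
    """
    with open(source_path) as src, open(target_path) as tgt:
        pairs = lazy_zip(src, tgt)
        if max_pairs is not None:
            # Restrict the stream to the first max_pairs examples.
            pairs = itertools.islice(pairs, max_pairs)
        while True:
            # Pull at most batch_size pairs; earlier batches are not retained.
            batch = [(s.rstrip("\n"), t.rstrip("\n"))
                     for s, t in itertools.islice(pairs, batch_size)]
            if not batch:
                return
            yield batch
```

If memory use stays flat when iterating over a reader like this but grows with the existing code, that points to the batching logic rather than the model size.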

Upvotes: 0
