ItsAmy

Reputation: 844

TensorFlow GPU memory exhausted during mean squared error

I have a TensorFlow model which is a recurrent neural network using long short-term memory (LSTM). The state size is 3000, each time step has 300 inputs, there are about 500 time steps, and there is 1 output for each time step. I am training a sequence-to-sequence model.

It runs fine for inputs with fewer than 500 time steps, but somewhere around 500 time steps it crashes with the following out-of-memory error:

ResourceExhaustedError (see above for traceback): OOM when allocating tensor with shape[20375,20375]
     [[Node: gradients/mean_squared_error/Mul_grad/mul_1 = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](mean_squared_error/Square, gradients/mean_squared_error/Sum_grad/Tile)]]
     [[Node: gradients/MatMul_grad/tuple/control_dependency_1/_225 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_5086_gradients/MatMul_grad/tuple/control_dependency_1", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

This is running on a GPU with 12 GB of memory.

I have tried running it on my laptop CPU, where it seems to use very little memory (about 1 to 2 GB), but it's so slow that it never reached 500 time steps. I'm working on some changes that will let it skip ahead to 500 time steps, to see how much memory it uses when not running on a GPU.

My question is: where could TensorFlow possibly want to allocate a tensor of shape [20375, 20375]? A single float32 tensor of that shape is about 1.6 GB on its own, so only a handful of them would exhaust the card. It seems to be related to the tf.losses.mean_squared_error call, but that doesn't seem like an operation that should require such exorbitant amounts of memory.

I have tried reducing the batch size, but that just pushes the failure point out by a few more time steps, and I'll need up to a few thousand time steps, so this doesn't seem like a good long-term solution. I'd prefer to get to the root of the problem.

Here is the relevant code for the mean squared error:

initial_state_tuple = tf.contrib.rnn.LSTMStateTuple(initial_state, initial_hidden_state)


# Create the actual RNN
with tf.variable_scope(VARIABLE_SCOPE, reuse=None):
    cell = tf.contrib.rnn.BasicLSTMCell(STATE_SIZE)
    rnn_outputs, finalstate = tf.nn.dynamic_rnn(cell=cell, inputs=networkinput,
                                                initial_state=initial_state_tuple)

with tf.variable_scope(VARIABLE_SCOPE, reuse=True):
    weights = tf.get_variable(name=WEIGHTS_NAME, shape=[STATE_SIZE, 1], dtype=tf.float32)
    biases = tf.get_variable(name=BIASES_NAME, shape=[1], dtype=tf.float32)

# Build the output layers
rnn_outputs_reshaped = tf.reshape(rnn_outputs, [-1, STATE_SIZE])
network_outputs = tf.sigmoid(tf.matmul(rnn_outputs_reshaped, weights) + biases)
expected_outputs_reshaped = tf.reshape(expected_outputs, [-1, 1])

# Loss mask just cancels out the inputs that are padding characters, since not all inputs have the same number of time steps
loss_mask_reshaped = tf.reshape(loss_mask, shape=[-1])

expected_outputs_reshaped = loss_mask_reshaped * expected_outputs_reshaped
network_outputs = loss_mask_reshaped * network_outputs

loss = tf.losses.mean_squared_error(labels=expected_outputs_reshaped, predictions=network_outputs)
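
As an aside, and purely as a shape experiment (I don't know yet whether this is related to the error): loss_mask_reshaped above has shape [N] while network_outputs and expected_outputs_reshaped have shape [N, 1], where N is batch size times time steps. TensorFlow's elementwise multiply broadcasts those shapes, and a minimal sketch (with an illustrative N = 6 standing in for 20375) shows the result is square:

# Minimal shape sketch: multiplying a rank-1 [N] tensor by a rank-2
# [N, 1] tensor broadcasts to an [N, N] result.
import tensorflow as tf

N = 6
mask = tf.ones([N])        # like loss_mask_reshaped: shape [N]
outputs = tf.ones([N, 1])  # like network_outputs: shape [N, 1]

product = mask * outputs   # broadcasting aligns shapes from the right
print(product.shape)       # (6, 6)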

If you want all of the code, it can be found here. The relevant functions are buildtower() and buildgraph(). The constants NUM_GPUS and BATCH_SIZE are set to appropriate values when running on the machine with the GPUs.

Update: I replaced the line

loss = tf.losses.mean_squared_error(labels=expected_outputs_reshaped, predictions=network_outputs)

with

error_squared = tf.pow(expected_outputs_reshaped - network_outputs, 2)
loss = tf.reduce_mean(error_squared)

and the same error happened. I reduced the state size to 30 and the batch size to 5, and the error still happened, although it did make it up to about 3000 time steps.

Update: After doing some research, I have found that truncated backpropagation is often used when training an RNN over a large number of time steps. This leads me to believe that backpropagation through many time steps inherently takes a lot of memory, and that my issue is not that I've constructed my graph wrong, but that I had a fundamental misunderstanding of the resource requirements of gradient calculations. To that end, I am changing my code to use truncated backpropagation. I will report back with results.

Upvotes: 1

Views: 973

Answers (1)

ItsAmy

Reputation: 844

This project is my first experience with machine learning and TensorFlow, and after doing some research, it seems I had some fundamental misunderstandings.

I had thought that memory usage would scale linearly with the number of time steps in my data. Because every other dimension of my model (batch size, state size) was small, I expected that I could get up to quite a few time steps before running out of memory. However, the memory needed to compute the gradients appears to grow much faster than linearly with the number of time steps (the tensor in my error message is square in the total number of outputs), so no matter how small I made the state size and batch size, the large number of time steps eventually exhausted all of my memory.

To deal with this, I am using truncated backpropagation, in which each batch is broken up into chunks of some fixed number of time steps. This is not perfect, because it means that errors can only be propagated back through at most that many time steps. However, based on what I've found online, it seems to work well enough, and there aren't many other ways around the memory usage issue. A rough sketch of what I mean is below.
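
Here is a minimal sketch of the kind of training loop I mean, assuming that networkinput, expected_outputs, initial_state, and initial_hidden_state from the question are placeholders, and that finalstate, loss, and a train_op are fetched on each run. TRUNCATE_STEPS and the function name are made up for the example, and feeding the loss mask is omitted for brevity:

TRUNCATE_STEPS = 100  # hypothetical chunk length

def train_one_batch(session, train_op, loss, finalstate,
                    batch_inputs, batch_labels, zero_c, zero_h):
    # batch_inputs: [batch, time, 300]; batch_labels: [batch, time, 1]
    state_c, state_h = zero_c, zero_h
    total_steps = batch_inputs.shape[1]
    for start in range(0, total_steps, TRUNCATE_STEPS):
        end = start + TRUNCATE_STEPS
        # Run one chunk. Gradients only flow within the chunk; the
        # fetched final LSTM state is fed back in as the next chunk's
        # initial state, so the forward state still spans the sequence.
        (state_c, state_h), chunk_loss, _ = session.run(
            [finalstate, loss, train_op],
            feed_dict={networkinput: batch_inputs[:, start:end],
                       expected_outputs: batch_labels[:, start:end],
                       initial_state: state_c,
                       initial_hidden_state: state_h})
    return chunk_loss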

As I said before, this is my first experience with machine learning, so if anything in here is blatantly wrong, please tell me.

Upvotes: 1
