npit

Reputation: 2359

LSTM initial state for each item in the batch in tensorflow

I am using tf.nn.dynamic_rnn to run an LSTM in TensorFlow. I have a tensor of N initial state vectors and a tensor of M = N * n inputs. Each series consists of n consecutive input items, and I want to evaluate the i-th series of input vectors with the i-th initial state vector, as shown below:

inputs[0:n], initial_states[0]
inputs[n:2*n], initial_states[1]
...

Is there a way to do this with a single call to tf.nn.dynamic_rnn using the above tensors directly, or do I have to loop over the initial state vectors and their corresponding inputs (resulting in len(initial_states) calls to tf.nn.dynamic_rnn)?

Upvotes: 0

Views: 236

Answers (1)

Allen Lavoie

Reputation: 5808

(Adding a bit of detail from comments on the question)

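This kind of batching is well supported, and is typically necessary to get good performance. Your initial_state should have a batch dimension of size N (for an LSTM cell, that typically means an LSTMStateTuple whose c and h tensors each have N rows), and the RNN will run for n steps over the whole batch in a single call. You just need to reshape inputs to be [N, n, ...], which is the layout expected with time_major=False (the default). Because each series occupies n consecutive rows of inputs, the reshape pairs inputs[i*n:(i+1)*n] with initial_states[i].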
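To make the reshape concrete, here is a minimal NumPy sketch. It uses a plain tanh RNN as a stand-in for the LSTM (so it runs without TensorFlow), and the names run_rnn, flat_inputs, w_x and w_h are made up for illustration. It only checks that one batched pass over the reshaped inputs gives the same final states as looping over each series with its own initial state; with tf.nn.dynamic_rnn you would pass the reshaped tensor as inputs and the batched state as initial_state.

```python
import numpy as np

def run_rnn(inputs, initial_state, w_x, w_h):
    """Run a plain tanh RNN over batch-major inputs of shape [batch, time, dim].

    Stand-in for an LSTM: the batching logic is the same, only the cell differs.
    Returns the final state, shape [batch, hidden].
    """
    state = initial_state
    for t in range(inputs.shape[1]):
        state = np.tanh(inputs[:, t, :] @ w_x + state @ w_h)
    return state

rng = np.random.default_rng(0)
N, n, dim, hidden = 3, 4, 5, 2                  # N series, n steps each

flat_inputs = rng.normal(size=(N * n, dim))     # M = N * n rows, series stored back to back
initial_states = rng.normal(size=(N, hidden))   # one initial state per series
w_x = rng.normal(size=(dim, hidden))
w_h = rng.normal(size=(hidden, hidden))

# Single batched pass: reshape [M, dim] -> [N, n, dim] (batch-major layout).
batched_inputs = flat_inputs.reshape(N, n, dim)
batched_final = run_rnn(batched_inputs, initial_states, w_x, w_h)

# Reference: one pass per series, each with its own initial state.
looped_final = np.stack([
    run_rnn(flat_inputs[i * n:(i + 1) * n][None], initial_states[i:i + 1], w_x, w_h)[0]
    for i in range(N)
])

print(np.allclose(batched_final, looped_final))
```

Row i of batched_inputs is exactly flat_inputs[i*n:(i+1)*n], which is why no per-series loop is needed.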

It gets trickier when you have variable-length inputs that need to be batched together. Something like SequenceQueueingStateSaver (in tf.contrib.training) can help there.
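For the variable-length case, a common alternative to a state saver (and not the approach named above) is to pad every series to the longest length and pass the true lengths through the sequence_length argument of tf.nn.dynamic_rnn, which copies the state through once a series has ended. The NumPy sketch below imitates that behaviour with a plain tanh RNN standing in for the LSTM. The helpers pad_batch and run_masked_rnn are hypothetical names used only for illustration.

```python
import numpy as np

def pad_batch(sequences, dim):
    """Zero-pad a list of [len_i, dim] arrays to [batch, max_len, dim]; also return lengths."""
    lengths = np.array([len(s) for s in sequences])
    padded = np.zeros((len(sequences), lengths.max(), dim))
    for i, s in enumerate(sequences):
        padded[i, :len(s)] = s
    return padded, lengths

def run_masked_rnn(inputs, lengths, initial_state, w_x, w_h):
    """Plain tanh RNN that freezes each row's state once its sequence has ended."""
    state = initial_state
    for t in range(inputs.shape[1]):
        new_state = np.tanh(inputs[:, t, :] @ w_x + state @ w_h)
        active = (t < lengths)[:, None]          # True while step t is inside the series
        state = np.where(active, new_state, state)
    return state

rng = np.random.default_rng(1)
dim, hidden = 3, 2
sequences = [rng.normal(size=(length, dim)) for length in (2, 5, 3)]
initial_states = rng.normal(size=(len(sequences), hidden))
w_x = rng.normal(size=(dim, hidden))
w_h = rng.normal(size=(hidden, hidden))

padded, lengths = pad_batch(sequences, dim)
batched_final = run_masked_rnn(padded, lengths, initial_states, w_x, w_h)

# Reference: run each series alone at its true length.
single_final = np.stack([
    run_masked_rnn(s[None], np.array([len(s)]), initial_states[i:i + 1], w_x, w_h)[0]
    for i, s in enumerate(sequences)
])

print(np.allclose(batched_final, single_final))
```

Padding wastes some computation on the shorter series, so bucketing series of similar length into the same batch is worth considering when lengths vary widely.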

Upvotes: 1
