David

Reputation: 839

How to set TensorFlow dynamic_rnn zero_state without a fixed batch_size?

According to TensorFlow's official documentation (https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/BasicLSTMCell#zero_state), zero_state requires a batch_size. Many examples I found use this code:

    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)

    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in, 
        initial_state=init_state, time_major=False)

For training, it's fine to fix the batch size. However, at prediction time the test set may not match the training batch size. For example, one batch of my training data has shape [100, 255, 128]: batch size 100, with 255 time steps and 128 inputs per step. My test set, though, has shape [2000, 255, 128]. I can't run prediction because the initial_state passed to dynamic_rnn was already built with a fixed batch_size = 100. How do I fix this?
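The mismatch can be illustrated outside of TensorFlow with plain arrays (a NumPy sketch, not the actual graph; the hidden size of 64 is a hypothetical value):

```python
import numpy as np

state_size = 64                                # hypothetical LSTM hidden size
init_state = np.zeros((100, state_size))       # zero state fixed to batch 100
test_inputs = np.zeros((2000, 255, 128))       # test batch is 2000

# The leading (batch) dimensions disagree, so a state built for
# batch 100 cannot be paired with inputs whose batch is 2000.
assert init_state.shape[0] != test_inputs.shape[0]
```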

Thanks.

Upvotes: 8

Views: 3236

Answers (3)

GoingMyWay

Reputation: 17468

As @陈狗蛋 answered, there is no need to set initial_state in tf.compat.v1.nn.dynamic_rnn because it is optional. You can simply do this:

outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, 
                                         X_inputs, 
                                         time_major=False, 
                                         dtype=tf.float32)

Do not forget to set the dtype; here I use tf.float32, but you can set whatever dtype you need.

As the docs of tf.compat.v1.nn.rnn_cell.LSTMCell say:

batch_size: int, float, or unit Tensor representing the batch size

The batch_size must be an explicit value, so passing a placeholder as the batch_size argument is a workaround rather than a recommended method. I suggest not relying on it, since it may become invalid in future versions.
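Conceptually, when initial_state is omitted, the zero state can be derived from the batch dimension of the inputs at run time instead of being fixed in advance. A minimal NumPy sketch of that idea (not TensorFlow code; zero_state_like is a hypothetical helper):

```python
import numpy as np

def zero_state_like(inputs, state_size):
    """Build an all-zeros (c, h) LSTM state whose batch dimension
    is taken from the inputs at run time, not hard-coded."""
    batch_size = inputs.shape[0]   # derived from the actual inputs
    c = np.zeros((batch_size, state_size), dtype=inputs.dtype)
    h = np.zeros((batch_size, state_size), dtype=inputs.dtype)
    return c, h

# Works for any batch size: a training batch of 100 or a test batch of 2000.
train_batch = np.zeros((100, 255, 128), dtype=np.float32)
test_batch = np.zeros((2000, 255, 128), dtype=np.float32)
c1, h1 = zero_state_like(train_batch, state_size=64)   # shapes (100, 64)
c2, h2 = zero_state_like(test_batch, state_size=64)    # shapes (2000, 64)
```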

Upvotes: 0

陈狗蛋

Reputation: 11

There is a fairly simple fix: just remove the initial_state! The problem arises because the initialization process pre-allocates memory sized to the fixed batch.

Upvotes: 1

VS_FF

Reputation: 2363

You can specify the batch_size as a placeholder rather than a constant. Just make sure to feed the relevant number in feed_dict; it will differ between training and testing.

Importantly, specify [] as the dimensions for the placeholder, because you may get errors if you specify None, as is customary elsewhere. So something like this should work:

# scalar placeholder: note the [] shape, not None
batch_size = tf.placeholder(tf.int32, [], name='batch_size')
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
        initial_state=init_state, time_major=False)
# rest of your code
out = sess.run(outputs, feed_dict={batch_size: 100})
out = sess.run(outputs, feed_dict={batch_size: 10})

Obviously, make sure that the batch parameter matches the shape of your inputs, which dynamic_rnn will interpret as [batch_size, seq_len, features], or as [seq_len, batch_size, features] if time_major is set to True.
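To make the two layout conventions concrete, here is a small NumPy check (hypothetical array names; a transpose of the first two axes converts between the layouts):

```python
import numpy as np

batch_size, seq_len, features = 100, 255, 128
x = np.zeros((batch_size, seq_len, features), dtype=np.float32)

# time_major=False (the default): [batch_size, seq_len, features]
assert x.shape == (100, 255, 128)

# time_major=True expects [seq_len, batch_size, features]
x_time_major = np.transpose(x, (1, 0, 2))
assert x_time_major.shape == (255, 100, 128)
```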

Upvotes: 12
