Reputation: 839
According to TensorFlow's official documentation (https://www.tensorflow.org/api_docs/python/tf/contrib/rnn/BasicLSTMCell#zero_state), zero_state requires you to specify a batch_size. Many examples I found use this code:
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
                                         initial_state=init_state, time_major=False)
For training, it's okay to fix the batch size. However, when predicting, the test set might not have the same batch size as the training data. For example, one batch of my training data has shape [100, 255, 128]: a batch size of 100, with 255 steps and 128 inputs. The test set, however, has shape [2000, 255, 128]. I can't predict, because the initial_state passed to dynamic_rnn has already fixed batch_size = 100. How do I fix this?
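For reference, here is a minimal sketch of the setup that fails (assuming the TF 1.x graph API; the cell size is an illustrative assumption, not from my actual model):
import tensorflow as tf  # TF 1.x graph API

# Shapes follow my data: [batch, steps, features]
X_in = tf.placeholder(tf.float32, [100, 255, 128])
lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=64)  # num_units is an assumption

# Fixing batch_size=100 here bakes that size into the graph...
init_state = lstm_cell.zero_state(100, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
                                         initial_state=init_state, time_major=False)
# ...so a [2000, 255, 128] test batch can no longer be fed through this graph.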
Thanks.
Upvotes: 8
Views: 3236
Reputation: 17468
As @陈狗蛋 answered, there is no need to set initial_state in tf.compat.v1.nn.dynamic_rnn, because it is optional. You can simply do this:
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell,
                                         X_inputs,
                                         time_major=False,
                                         dtype=tf.float32)
Do not forget to set the dtype; here I set tf.float32, but you can set whatever dtype you need.
As the docs for tf.compat.v1.nn.rnn_cell.LSTMCell say:
batch_size: int, float, or unit Tensor representing the batch size
The batch_size must be an explicit value. So using a placeholder for the batch_size argument is a workaround, not a recommended method; I recommend you not use it, because it may become invalid in future versions.
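A runnable sketch of this approach under TF 1.x (the input shapes, cell size, and zero-filled feeds are illustrative assumptions):
import numpy as np
import tensorflow as tf  # TF 1.x graph API

# Leave the batch dimension as None so any batch size is accepted
X_inputs = tf.placeholder(tf.float32, [None, 255, 128])
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=64)  # num_units is an assumption

# With no initial_state, dynamic_rnn builds a zero state from the runtime
# batch size, which is why dtype must be passed explicitly.
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell,
                                         X_inputs,
                                         time_major=False,
                                         dtype=tf.float32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The same graph now handles both train- and test-sized batches
    train_out = sess.run(outputs, {X_inputs: np.zeros((100, 255, 128), np.float32)})
    test_out = sess.run(outputs, {X_inputs: np.zeros((2000, 255, 128), np.float32)})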
Upvotes: 0
Reputation: 11
There is a fairly simple fix: just remove the initial_state! This works because the initialization process may pre-allocate batch-sized memory.
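If you still want an explicit initial state (for example, to feed a saved state back in later), one common alternative is to derive the batch size from the input tensor at run time instead of hard-coding it; a minimal sketch, assuming TF 1.x and the shapes from the question (the cell size is an assumption):
import tensorflow as tf  # TF 1.x graph API

X_in = tf.placeholder(tf.float32, [None, 255, 128])      # batch dim left open
lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=64)   # num_units is an assumption

# zero_state accepts a scalar Tensor, so the batch size can come from the input
init_state = lstm_cell.zero_state(tf.shape(X_in)[0], dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
                                         initial_state=init_state, time_major=False)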
Upvotes: 1
Reputation: 2363
You can specify the batch_size as a placeholder, not a constant. Just make sure to feed the relevant number in feed_dict, which will be different for training and for testing. Importantly, specify [] as the dimensions for the placeholder, because you might get errors if you specify None, as is customary elsewhere. So something like this should work:
batch_size = tf.placeholder(tf.int32, [], name='batch_size')
init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
                                         initial_state=init_state, time_major=False)
# rest of your code
out = sess.run(outputs, feed_dict={batch_size: 100})
out = sess.run(outputs, feed_dict={batch_size: 10})
Obviously, make sure that the batch parameter matches the shape of your inputs, which dynamic_rnn will interpret as [batch_size, seq_len, features], or [seq_len, batch_size, features] if time_major is set to True.
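Putting it together, a minimal end-to-end sketch (the zero-filled feeds and cell size are illustrative assumptions; in practice X_in is fed alongside batch_size, and their leading dimensions must match):
import numpy as np
import tensorflow as tf  # TF 1.x graph API

batch_size = tf.placeholder(tf.int32, [], name='batch_size')
X_in = tf.placeholder(tf.float32, [None, 255, 128])
lstm_cell = tf.contrib.rnn.BasicLSTMCell(num_units=64)  # num_units is an assumption

init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
                                         initial_state=init_state, time_major=False)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # The fed batch_size must equal the leading dimension of the fed inputs
    out = sess.run(outputs, feed_dict={batch_size: 100,
                                       X_in: np.zeros((100, 255, 128), np.float32)})
    out = sess.run(outputs, feed_dict={batch_size: 10,
                                       X_in: np.zeros((10, 255, 128), np.float32)})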
Upvotes: 12