Silver

Reputation: 1468

Tensorflow Changing Batch Size for RNN During Text Generation

I built a vanilla character-level RNN and trained it on some data. Everything worked fine up to that point.

But now I want to use the model to generate text. The problem is that during this text-generation phase the batch_size is 1, and num_steps per batch is also different.

This leads to several errors, and the hacky fixes I tried aren't working. What's the usual way to deal with this?

Edit: More specifically, my input placeholders have a shape of [None, num_steps]; the problem is the initial state, which doesn't accept a shape of [None, hidden_size].

Upvotes: 5

Views: 1787

Answers (3)

Amir

Reputation: 16587

As mentioned in chasep255's solution, the two tricky parts are the initial_state, and the batch_size and sequence length.

First tricky part:

If we set both batch_size and sequence length to None, we can change them during inference. Our first step is to define the input shape as [None, None]:

self.inputs = tf.placeholder(tf.int32, shape=(None, None), name='inputs')
self.targets = tf.placeholder(tf.int32, shape=(None, None), name='targets')

Second tricky part:

The next step is to define a dynamic initial_state. For this part, as mentioned in chasep255's solution, we can feed the state in ourselves, passing the zero_state to the RNN explicitly. To this end, I used the tf.shape API to get the batch size dynamically from the input sequence (in my case: self.inp):

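# Build a fresh cell per layer; [cell] * n would reuse one cell object for every layer
cells = [tf.nn.rnn_cell.GRUCell(self.rnn_size) for _ in range(self.layer_size)]
rnn_cell = tf.nn.rnn_cell.MultiRNNCell(cells)
# self.inp is the float sequence fed to the RNN (e.g. an embedding of self.inputs);
# its first dimension is the batch size, so the zero state matches any fed batch
self.initial_state = rnn_cell.zero_state(tf.shape(self.inp)[0], tf.float32)
self.rnn_outputs, self.final_state = tf.nn.dynamic_rnn(rnn_cell, self.inp,
                                                       initial_state=self.initial_state,
                                                       dtype=tf.float32)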

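Now, in training, I call sess.run() in two places. First, to fill initial_state with zero values of the right shape: I feed a dummy all-zeros input whose first dimension is the training batch size, so tf.shape picks up that batch size and zero_state returns zeros of shape [training_batch_size, hidden_size] for each layer. Second, inside the loop, I feed the state returned by the previous step back in through initial_state (in TF 1.x any feedable tensor, not only a placeholder, can be overridden through feed_dict), something like: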

# Evaluate the zero state once; only the batch dimension of the dummy input matters
new_state = sess.run(self.initial_state,
                     feed_dict={self.inputs: np.zeros([self.batch_size_in_train, 1], dtype=np.int32)})

for x, y in batch_gen:
    feed_dict = {
        self.inputs: x,
        self.targets: y,
        # carry the final state of the previous batch into this one
        self.initial_state: new_state
    }
    _, step, new_state, loss = sess.run([self.optimizer,
                                         self.global_step,
                                         self.final_state,
                                         self.loss],
                                        feed_dict)
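To see why feeding final_state back in as initial_state matters, here is a minimal NumPy sketch (not TensorFlow; the toy vanilla RNN, its weights and sizes are hypothetical). It checks that running a long sequence in num_steps-sized chunks, while carrying the state from one call to the next, gives the same outputs as running the whole sequence in one go.

```python
import numpy as np

# Toy vanilla RNN in NumPy, standing in for the TF graph (hypothetical sizes/weights).
rng = np.random.default_rng(0)
vocab_size, hidden_size = 5, 8
W_xh = rng.normal(scale=0.1, size=(vocab_size, hidden_size))
W_hh = rng.normal(scale=0.1, size=(hidden_size, hidden_size))
b_h = np.zeros(hidden_size)


def run_rnn(token_ids, initial_state):
    """Run over int ids of shape [batch, steps]; return (outputs, final_state)."""
    h = initial_state
    outputs = []
    for t in range(token_ids.shape[1]):
        x_t = np.eye(vocab_size)[token_ids[:, t]]  # one-hot, [batch, vocab]
        h = np.tanh(x_t @ W_xh + h @ W_hh + b_h)
        outputs.append(h)
    return np.stack(outputs, axis=1), h


def zero_state(token_ids):
    # Like zero_state(tf.shape(inputs)[0], ...): the batch size comes from the input.
    return np.zeros((token_ids.shape[0], hidden_size))


# One long sequence per batch row, cut into num_steps-sized chunks.
batch_size, total_steps, num_steps = 3, 12, 4
data = rng.integers(0, vocab_size, size=(batch_size, total_steps))

# Reference: the whole sequence in a single call.
full_out, full_state = run_rnn(data, zero_state(data))

# Chunked: each call's final state becomes the next call's initial state.
new_state = zero_state(data)
chunk_outs = []
for start in range(0, total_steps, num_steps):
    out, new_state = run_rnn(data[:, start:start + num_steps], new_state)
    chunk_outs.append(out)
chunked_out = np.concatenate(chunk_outs, axis=1)

print(np.allclose(chunked_out, full_out), np.allclose(new_state, full_state))  # True True
```

If new_state were reset to zeros for every chunk instead, the outputs after the first chunk would generally no longer match the full run, which is the state-forgetting problem the feed_dict trick avoids.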

In inference we can do the same thing. This time we evaluate the zero state with a dummy input of shape [1, 1] (batch size 1, one time step). Our inference part would be:

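new_state = sess.run(self.initial_state, feed_dict={self.inputs: np.zeros([1, 1], dtype=np.int32)})
for i in range(400):
    # c is the id of the current character (the seed character on the first step)
    x = np.zeros((1, 1), dtype=np.int32)
    x[0, 0] = c
    feed_dict = {
        self.inputs: x,
        self.keep_prob: 1.0,  # no dropout at inference time
        self.initial_state: new_state
    }
    preds, new_state = sess.run(
        [self.prediction, self.final_state],
        feed_dict=feed_dict)
    # pick the next character id c from preds here (sampling code not shown)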

See complete code here.
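The key point of this inference loop is that the RNN weights do not depend on the batch size; only the state does. A minimal NumPy sketch of that idea (not TensorFlow; the vocabulary, weights and greedy argmax decoding are hypothetical choices for illustration) applies one set of weights to a batch of 32 and then generates one character at a time with a batch of 1, carrying the state between steps:

```python
import numpy as np

# Toy NumPy character RNN; the vocabulary and weights are hypothetical.
rng = np.random.default_rng(1)
vocab = list("helo ")
V, H = len(vocab), 16
W_xh = rng.normal(scale=0.5, size=(V, H))
W_hh = rng.normal(scale=0.5, size=(H, H))
W_hy = rng.normal(scale=0.5, size=(H, V))


def step(token_ids, state):
    """One time step for ids of shape [batch]; weights are batch-size independent."""
    x = np.eye(V)[token_ids]
    state = np.tanh(x @ W_xh + state @ W_hh)
    return state @ W_hy, state


# The same weights accept a training-sized batch...
train_logits, _ = step(np.zeros(32, dtype=np.int32), np.zeros((32, H)))


# ...and a batch of 1 during generation, one character per call.
def generate(start_char, n_chars):
    c = vocab.index(start_char)
    state = np.zeros((1, H))                  # zero state for batch_size = 1
    out = [start_char]
    for _ in range(n_chars):
        x = np.array([c], dtype=np.int32)     # shape [1]: batch 1, one step
        logits, state = step(x, state)        # carry the state forward
        c = int(np.argmax(logits[0]))         # greedy pick; sampling also works
        out.append(vocab[c])
    return "".join(out)


text = generate("h", 20)
print(len(text))  # 21
```

Greedy argmax keeps the sketch deterministic; a real generator would usually sample from a softmax over the logits instead.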

Upvotes: 0

Cauchyzhou

Reputation: 31

How about using TensorFlow's variable reuse?

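import tensorflow as tf

class Model():
    def __init__(self, batch_size, reuse):
        self.batch_size = batch_size
        self.reuse = reuse
        self.input_x = tf.placeholder(.....)
        self.input_y = tf.placeholder(.....)
        self.inference()
        if not reuse:
            self.build_train_op()  # only the training copy needs an optimizer

    def inference(self):
        # reuse=True makes this copy share the variables created by the first copy
        with tf.variable_scope('xxx', reuse=self.reuse):
            ...
            cell = tf.contrib.rnn.LSTMCell(xxx, reuse=self.reuse)
            # each copy gets a zero state with its own fixed batch size
            init_state = cell.zero_state(self.batch_size, dtype=tf.float32)
            ...
            # self.prediction = ...

    def build_train_op(self):
        ....
        # self.train_op = ...

if __name__ == '__main__':
    train_model = Model(batch_size=128, reuse=False)
    test_model = Model(batch_size=1, reuse=True)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_model.train_op, feed_dict={...})
        sess.run(test_model.prediction, feed_dict={...})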

Of course, this effectively defines two branches in one TF graph, which may not look very elegant. But if you don't want to pass the RNN cell's init_state in yourself, it is one way to do it.
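The mechanism this answer relies on is that a variable scope with reuse=True hands back variables that already exist instead of creating new ones. The sketch below is a toy imitation of those semantics in plain Python/NumPy, purely for illustration; ToyVariableStore and ToyModel are hypothetical and are not TensorFlow's API.

```python
import numpy as np


class ToyVariableStore:
    """Toy imitation of variable_scope reuse semantics (not TensorFlow)."""

    def __init__(self):
        self._vars = {}

    def get_variable(self, name, shape, reuse):
        if reuse:
            # Reusing requires the variable to exist already.
            if name not in self._vars:
                raise ValueError(f"{name} does not exist; cannot reuse it")
            return self._vars[name]
        # Creating requires that it does not exist yet.
        if name in self._vars:
            raise ValueError(f"{name} already exists; did you mean reuse=True?")
        self._vars[name] = np.random.default_rng(0).normal(size=shape)
        return self._vars[name]


class ToyModel:
    def __init__(self, store, batch_size, reuse, hidden_size=4):
        # Weights come from the shared store; only the state depends on batch size.
        self.W = store.get_variable("rnn/W", (hidden_size, hidden_size), reuse)
        self.init_state = np.zeros((batch_size, hidden_size))


store = ToyVariableStore()
train_model = ToyModel(store, batch_size=128, reuse=False)
test_model = ToyModel(store, batch_size=1, reuse=True)
print(train_model.W is test_model.W)                              # True
print(train_model.init_state.shape, test_model.init_state.shape)  # (128, 4) (1, 4)
```

Both "models" share a single weight matrix while each keeps a zero state sized for its own batch, which is what lets the batch-128 training copy and the batch-1 generation copy coexist.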

Upvotes: 0

chasep255

Reputation: 12175

I have dealt with this same problem. There are two issues you need to handle. The first is adjusting the batch size and step size to 1. You can easily do this by setting the batch and length dimensions of the input sequence to None, i.e. [None, None, 128], where 128 represents the 128 ASCII characters (although you could probably use fewer, since you likely only need a subset of them).
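With an input shape of [None, None, 128], the same placeholder accepts a whole training batch or a single character. A minimal sketch of building such an input, assuming plain one-hot ASCII encoding (encode_ascii is a hypothetical helper, not from the answer):

```python
import numpy as np


def encode_ascii(text, n_chars=128):
    """One-hot encode a string into shape [1, len(text), n_chars] (a batch of 1)."""
    ids = np.frombuffer(text.encode("ascii"), dtype=np.uint8)  # ASCII codes
    one_hot = np.zeros((1, len(ids), n_chars), dtype=np.float32)
    one_hot[0, np.arange(len(ids)), ids] = 1.0
    return one_hot


x = encode_ascii("hello")
print(x.shape)  # (1, 5, 128)
```

During generation you would call it with a single character, giving shape (1, 1, 128), and feed that to the same placeholder used for training batches.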

Dealing with the initial state is the trickiest part, because you need to save it between calls to session.run(). Since your num_steps is 1 during generation, the state would otherwise be reset to zero at the beginning of every step, and the network would forget everything it has generated so far. What I recommend is allowing the initial state to be passed in as a placeholder and returned from session.run(). This way the user of the model can carry the current state over between batches. The easiest way to do this is to make sure state_is_tuple is set to False for every RNN you use; you then simply get a single final state tensor back from tf.nn.dynamic_rnn.

I personally don't like setting state_is_tuple to False since it is deprecated, so I wrote my own code to flatten the state tuple. The following code is from my project to generate sound.

        # Batch size is read from the input at run time, so it can be 1 during generation
        batch_size = tf.shape(self.input_sound)[0]
        rnn = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.LSTMCell(self.hidden_size) for _ in range(self.n_hidden)])
        # The nested state is flattened into one [batch, total_state_size] tensor
        zero_state = pack_state_tupel(rnn.zero_state(batch_size, tf.float32))
        # Defaults to the zero state unless a state is fed in explicitly
        self.input_state = tf.placeholder_with_default(zero_state, None)
        state = unpack_state_tupel(self.input_state, rnn.state_size)

        # When training, the last frame is not fed as input (presumably it is only a prediction target)
        rnn_input_seq = tf.cond(self.is_training, lambda: self.input_sound[:, :-1], lambda: self.input_sound)
        output, final_state = tf.nn.dynamic_rnn(rnn, rnn_input_seq, initial_state = state)

        with tf.variable_scope('output_layer'):
            output = tf.reshape(output, (-1, self.hidden_size))
            W = tf.get_variable('W', (self.hidden_size, self.sample_length))
            b = tf.get_variable('b', (self.sample_length,))
            output = tf.matmul(output, W) + b
            output = tf.reshape(output, (batch_size, -1, self.sample_length))

        # Flattened final state, to be fed back in as input_state on the next call
        self.output_state = pack_state_tupel(final_state)
        self.output_sound = output

It uses the following two functions, which should work for any type of RNN, although I have only tested them with this model.

import tensorflow as tf

def pack_state_tupel(state_tupel):
    # Recursively concatenate a (nested) state tuple into one [batch, total_size] tensor
    if isinstance(state_tupel, tf.Tensor) or not hasattr(state_tupel, '__iter__'):
        return state_tupel
    else:
        # tf.concat(values, axis) in TF >= 1.0 (TF 0.x used tf.concat(axis, values))
        return tf.concat([pack_state_tupel(item) for item in state_tupel], 1)

def unpack_state_tupel(state_tensor, sizes):
    # Slice the flat tensor back into the structure of `sizes` (e.g. rnn.state_size)
    def _unpack_state_tupel(state_tensor_, sizes_, offset_):
        if isinstance(sizes_, tf.Tensor) or not hasattr(sizes_, '__iter__'):
            # Leaf: take the next sizes_ columns
            return tf.reshape(state_tensor_[:, offset_ : offset_ + sizes_], (-1, sizes_)), offset_ + sizes_
        else:
            result = []
            for size in sizes_:
                s, offset_ = _unpack_state_tupel(state_tensor_, size, offset_)
                result.append(s)
            # Rebuild LSTMStateTuple(c, h) where the cell expects one
            if isinstance(sizes_, tf.nn.rnn_cell.LSTMStateTuple):
                return tf.nn.rnn_cell.LSTMStateTuple(*result), offset_
            else:
                return tuple(result), offset_
    return _unpack_state_tupel(state_tensor, sizes, 0)[0]
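The flatten/unflatten logic of these two helpers can be checked without building a graph. Below is a NumPy mirror of the same recursion (a sketch, not the answer's code): LSTMState is a hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple, and the two-layer sizes are made up. Packing a two-layer LSTM state and unpacking it again should give back the original arrays and tuple types.

```python
import numpy as np
from collections import namedtuple

# Hypothetical stand-in for tf.nn.rnn_cell.LSTMStateTuple (c, h).
LSTMState = namedtuple("LSTMState", ["c", "h"])


def pack_state(state):
    """Concatenate a nested tuple of [batch, n] arrays along axis 1."""
    if isinstance(state, np.ndarray):
        return state
    return np.concatenate([pack_state(s) for s in state], axis=1)


def unpack_state(packed, sizes):
    """Split a [batch, total] array back into the nested structure of `sizes`."""
    def _unpack(sizes_, offset):
        if isinstance(sizes_, int):
            return packed[:, offset:offset + sizes_], offset + sizes_
        parts = []
        for size in sizes_:
            part, offset = _unpack(size, offset)
            parts.append(part)
        # Rebuild the same tuple type (LSTMState or plain tuple).
        if isinstance(sizes_, LSTMState):
            return LSTMState(*parts), offset
        return tuple(parts), offset
    return _unpack(sizes, 0)[0]


batch = 2
sizes = (LSTMState(3, 3), LSTMState(5, 5))  # two layers of different widths
state = tuple(LSTMState(np.full((batch, n.c), float(i)), np.full((batch, n.h), i + 0.5))
              for i, n in enumerate(sizes))
packed = pack_state(state)
print(packed.shape)  # (2, 16)
restored = unpack_state(packed, sizes)
```

Rebuilding the LSTMState type on the way back matters because the cell expects its state in the same nested structure that zero_state produced.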

Finally, here is my generate function; note how it manages the hidden state s.

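def generate(self, seed, steps):
    def _step(x, s = None):
        # Feed one sequence of shape [1, frames, sample_length], plus the previous state if any
        feed_dict = {self.input_sound: np.reshape(x, (1, -1, self.sample_length))}
        if s is not None:
            feed_dict[self.input_state] = s
        return self.session.run([self.output_sound, self.output_state], feed_dict)

    # Left-pad the seed with zeros to a whole number of frames (0 if already aligned)
    seed_pad = (-len(seed)) % self.sample_length
    if seed_pad: seed = np.pad(seed, (seed_pad, 0), 'constant')

    # Prime the state with the whole seed, then keep only the last predicted frame
    y, s = _step(seed)
    y = y[:, -1:]

    result = [seed, y.flatten()]
    for _ in range(steps):
        # One frame in, one frame out; s carries the hidden state between calls
        y, s = _step(y, s)
        result.append(y.flatten())

    return np.concatenate(result)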
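The seed has to be a whole number of sample_length frames before it can be reshaped to (1, -1, sample_length). A small sketch of the padding arithmetic (left_pad_amount is a hypothetical helper name): an expression of the form k - n % k returns k rather than 0 when n is already a multiple of k, so it would prepend a full frame of zeros, whereas (-n) % k gives 0 in that case.

```python
import numpy as np


def left_pad_amount(n, sample_length):
    """Zeros to prepend so that n becomes a multiple of sample_length."""
    return (-n) % sample_length


for n in (0, 5, 8, 10):
    # Compare with the k - n % k form, which over-pads exact multiples.
    print(n, left_pad_amount(n, 4), 4 - n % 4)

seed = np.arange(1, 7)  # 6 samples, not a multiple of 4
pad = left_pad_amount(len(seed), 4)
padded = np.pad(seed, (pad, 0), "constant")
print(padded)  # [0 0 1 2 3 4 5 6]
```

Padding on the left keeps the most recent samples at the end of the seed, which is the part that matters for priming the state.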

Upvotes: 4
