Anmol Jawandha

Reputation: 63

Calling the same batch in TensorFlow

I have a TensorFlow graph that reads from .tfrecords files, following the input-pipeline pattern described in the TensorFlow docs:

def read_my_file_format(filename_queue):
  reader = tf.SomeReader()
  key, record_string = reader.read(filename_queue)
  example, label = tf.some_decoder(record_string)
  processed_example = some_processing(example)
  return processed_example, label

def input_pipeline(filenames, batch_size, num_epochs=None):
  filename_queue = tf.train.string_input_producer(
      filenames, num_epochs=num_epochs, shuffle=True)
  example, label = read_my_file_format(filename_queue)
  # min_after_dequeue defines how big a buffer we will randomly sample
  #   from -- bigger means better shuffling but slower start up and more
  #   memory used.
  # capacity must be larger than min_after_dequeue and the amount larger
  #   determines the maximum we will prefetch.  Recommendation:
  #   min_after_dequeue + (num_threads + a small safety margin) * batch_size
  min_after_dequeue = 10000
  capacity = min_after_dequeue + 3 * batch_size
  example_batch, label_batch = tf.train.shuffle_batch(
      [example, label], batch_size=batch_size, capacity=capacity,
      min_after_dequeue=min_after_dequeue)
  return example_batch, label_batch
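For reference, here is a minimal sketch of how such a pipeline is typically driven in a TF 1.x session, assuming read_my_file_format is filled in with a real reader and decoder; the filenames below are hypothetical placeholders:

import tensorflow as tf

# Hypothetical .tfrecords paths -- substitute your own files.
filenames = ["data/train-0.tfrecords", "data/train-1.tfrecords"]
example_batch, label_batch = input_pipeline(filenames, batch_size=32)

with tf.Session() as sess:
    sess.run(tf.group(tf.global_variables_initializer(),
                      tf.local_variables_initializer()))
    # Queue runners must be started, otherwise fetching the batch hangs.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        examples, labels = sess.run([example_batch, label_batch])
    finally:
        coord.request_stop()
        coord.join(threads)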

In my code, a single batch (as returned by input_pipeline above) is used as the input to multiple networks (call them A and B) in my graph on each iteration. So if I call:

# ...define networks A and B, both consuming example_batch...
sess.run([A, B])

does TensorFlow guarantee that it will use the same batch for both networks within that sess.run call?

Upvotes: 0

Views: 241

Answers (1)

eaksan

Reputation: 575

If the input of models A and B is example_batch and you evaluate the models simultaneously (as in your example sess.run([A,B])), then I expect them to see the same batch, because both models are fed by the same dequeue operation. As soon as you break that synchronization (i.e., run them in separate sess.run calls), the inputs will differ.

The following code snippet looks trivial, but it illustrates the point.

import tensorflow as tf
import numpy as np
import time

batch_size = 16
input_shape, target_shape = (128,), () # inputs with dimensionality 128, scalar targets.
num_threads = 4 # for input pipeline
queue_capacity = 10 # for input pipeline


def get_random_data_sample():
    # Random inputs and targets
    np_input = np.float32(np.random.normal(0,1, input_shape))
    np_target = np.int32(1)

    # Sleep randomly between 1 and 3 seconds.
    #time.sleep(np.random.randint(1,3,1)[0])

    return np_input, np_target

tensorflow_input, tensorflow_target = tf.py_func(
    get_random_data_sample, [], [tf.float32, tf.int32])

def create_model(inputs, hidden_size, num_hidden_layers):
    # Create a dummy dense network.
    dense_layer = inputs
    for i in range(num_hidden_layers):
        dense_layer = tf.layers.dense(
                inputs=dense_layer,
                units=hidden_size,
                kernel_initializer= tf.zeros_initializer(),
                bias_initializer= tf.zeros_initializer(),
                activation=None,
                use_bias=True,
                reuse=False)
    return dense_layer, inputs

# input pipeline
batch_inputs, batch_targets = tf.train.batch([tensorflow_input, tensorflow_target], 
                                             batch_size=batch_size, 
                                             num_threads=num_threads, 
                                             shapes=[input_shape, target_shape], 
                                             capacity=queue_capacity)

# Different models A and B using the same input operation.
modelA, modelA_inputs = create_model(batch_inputs, 32, 1) # 1 hidden layer
modelB, modelB_inputs = create_model(batch_inputs, 64, 2) # 2 hidden layers

sess = tf.InteractiveSession()
tf.train.start_queue_runners()
sess.run(tf.group(tf.global_variables_initializer(), tf.local_variables_initializer()))

# (1) Evaluate the models simultaneously.
resultA, resultB, inputsA, inputsB = sess.run([modelA, modelB, modelA_inputs, modelB_inputs])
assert (inputsA == inputsB).all()  # passes: a single dequeue feeds both models

# (2) Evaluate the models separately: each sess.run dequeues a fresh batch.
resultA2, inputsA2 = sess.run([modelA, modelA_inputs])
resultB2, inputsB2 = sess.run([modelB, modelB_inputs])
assert (inputsA2 == inputsB2).all()  # fails: the two calls saw different batches

Naturally, the second evaluation uses different input batches, so the assertion fails. I hope this helps.
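If you do need both models to see the exact same batch across separate sess.run calls, a minimal sketch of one workaround is to dequeue once and feed the fetched values back in. The placeholder and the models built on it below are hypothetical additions, not part of the snippet above:

# Build the models on a placeholder instead of the queue output.
inputs_ph = tf.placeholder(tf.float32, shape=[batch_size] + list(input_shape))
modelC, _ = create_model(inputs_ph, 32, 1)
modelD, _ = create_model(inputs_ph, 64, 2)
sess.run(tf.global_variables_initializer())  # re-init: new variables were added

# Dequeue exactly once, then reuse the NumPy values in separate calls.
np_batch = sess.run(batch_inputs)
resultC = sess.run(modelC, feed_dict={inputs_ph: np_batch})
resultD = sess.run(modelD, feed_dict={inputs_ph: np_batch})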

Upvotes: 1
