Reputation: 1529
I am migrating from the older queue-based data pipeline to the newer tf.data API. Suppose I have code like the following; how can I explicitly set different batch sizes for my training and validation iterators?
filenames = tf.placeholder(tf.string, shape=[None])
dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(...)  # Parse the record into tensors.
dataset = dataset.repeat()  # Repeat the input indefinitely.
dataset = dataset.batch(32)
iterator = dataset.make_initializable_iterator()

# Initialize `iterator` with training data.
training_filenames = ["/var/data/file1.tfrecord",
                      "/var/data/file2.tfrecord"]
sess.run(iterator.initializer, feed_dict={filenames: training_filenames})

# Initialize `iterator` with validation data.
validation_filenames = ["/var/data/validation1.tfrecord", ...]
sess.run(iterator.initializer, feed_dict={filenames: validation_filenames})
EDIT:
Thank you. Based on the reply, my implementation is as follows, but I'm not able to figure out why I'm getting this error:
import glob
import tensorflow as tf

def _parse(filename, label):
    image_string = tf.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string)
    image_resized = tf.image.resize_images(image_decoded, [224, 224])
    image_resized.set_shape([224, 224, 3])
    return image_resized, label

def input_pipeline(imglist, labellist, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices((imglist, labellist))
    dataset = dataset.map(_parse)       # Parse each filename into an image tensor.
    dataset = dataset.repeat()          # Repeat the input indefinitely.
    dataset = dataset.batch(batch_size)
    return dataset
imglist = glob.glob('/var/temp/*.jpg')
train_imgs = imglist[0:100]
train_labels = [i for i in range(100)]
val_imgs = imglist[200:250]
val_labels = [i for i in range(50)]

training_batch_size = 4
validation_batch_size = 1

training_ds = input_pipeline(train_imgs, train_labels, training_batch_size)
validation_ds = input_pipeline(val_imgs, val_labels, validation_batch_size)

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_ds.output_types, training_ds.output_shapes)
input_batch = iterator.get_next()

train_iter = training_ds.make_initializable_iterator()
val_iter = validation_ds.make_initializable_iterator()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Define training and validation handles
    training_handle = sess.run(train_iter.string_handle())
    validation_handle = sess.run(val_iter.string_handle())

    # Initialize training and validation datasets
    sess.run(train_iter)
    sess.run(val_iter)

    # If we use training_handle, the input_batch tensor comes from the training data
    training_batch = sess.run(input_batch, feed_dict={handle: training_handle})

    # If we use validation_handle, the input_batch tensor comes from the validation data
    validation_batch = sess.run(input_batch, feed_dict={handle: validation_handle})
But I end up getting the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
281 self._unique_fetches.append(ops.get_default_graph().as_graph_element(
--> 282 fetch, allow_tensor=True, allow_operation=True))
283 except TypeError as e:
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in as_graph_element(self, obj, allow_tensor, allow_operation)
3589 with self._lock:
-> 3590 return self._as_graph_element_locked(obj, allow_tensor, allow_operation)
3591
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _as_graph_element_locked(self, obj, allow_tensor, allow_operation)
3678 raise TypeError("Can not convert a %s into a %s." % (type(obj).__name__,
-> 3679 types_str))
3680
TypeError: Can not convert a Iterator into a Tensor or Operation.
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
<ipython-input-31-50c4f3464d03> in <module>()
47
48 # Initialize training and validation dataset
---> 49 sess.run(train_iter)
50 sess.run(val_iter)
51
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
898 try:
899 result = self._run(None, fetches, feed_dict, options_ptr,
--> 900 run_metadata_ptr)
901 if run_metadata:
902 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1118 # Create a fetch handler to take care of the structure of fetches.
1119 fetch_handler = _FetchHandler(
-> 1120 self._graph, fetches, feed_dict_tensor, feed_handles=feed_handles)
1121
1122 # Run request and get response.
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, graph, fetches, feeds, feed_handles)
425 """
426 with graph.as_default():
--> 427 self._fetch_mapper = _FetchMapper.for_fetch(fetches)
428 self._fetches = []
429 self._targets = []
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in for_fetch(fetch)
251 if isinstance(fetch, tensor_type):
252 fetches, contraction_fn = fetch_fn(fetch)
--> 253 return _ElementFetchMapper(fetches, contraction_fn)
254 # Did not find anything.
255 raise TypeError('Fetch argument %r has invalid type %r' % (fetch,
~/anaconda3/lib/python3.6/site-packages/tensorflow/python/client/session.py in __init__(self, fetches, contraction_fn)
284 raise TypeError('Fetch argument %r has invalid type %r, '
285 'must be a string or Tensor. (%s)' %
--> 286 (fetch, type(fetch), str(e)))
287 except ValueError as e:
288 raise ValueError('Fetch argument %r cannot be interpreted as a '
TypeError: Fetch argument <tensorflow.python.data.ops.iterator_ops.Iterator object at 0x7fa2c0697c88> has invalid type <class 'tensorflow.python.data.ops.iterator_ops.Iterator'>, must be a string or Tensor. (Can not convert a Iterator into a Tensor or Operation.)
Upvotes: 4
Views: 5682
Reputation: 636
I would create two tf.data.Dataset pipelines, one for the training subset and one for the validation subset. Once both pipelines are defined (this is where you can set two different batch sizes), you can join them in the graph by creating a single tf.data.Iterator driven by a handle (in my case, the tf.placeholder called handle).
import tensorflow as tf

def input_pipeline(filenames, batch_size):
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(...)  # Parse the record into tensors.
    dataset = dataset.repeat()  # Repeat the input indefinitely.
    dataset = dataset.batch(batch_size)
    return dataset
training_filenames = ["/var/data/file1.tfrecord",
                      "/var/data/file2.tfrecord"]
training_batch_size = 32

validation_filenames = ["/var/data/validation1.tfrecord",
                        "/var/data/validation2.tfrecord"]
validation_batch_size = 16

training_ds = input_pipeline(training_filenames, training_batch_size)
validation_ds = input_pipeline(validation_filenames, validation_batch_size)

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_ds.output_types, training_ds.output_shapes)
input_batch = iterator.get_next()
Before requesting batches from either dataset, get the corresponding handle for each one with string_handle(). After that, whenever you run input_batch, you decide whether the batch comes from the training or the validation dataset by what you feed to the handle placeholder.
train_iter = training_ds.make_initializable_iterator()
val_iter = validation_ds.make_initializable_iterator()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Get training and validation handles
    training_handle = sess.run(train_iter.string_handle())
    validation_handle = sess.run(val_iter.string_handle())

    # Initialize training and validation datasets
    sess.run(train_iter.initializer)
    sess.run(val_iter.initializer)

    # If we use training_handle, the input_batch tensor comes from the training tfrecords
    training_batch = sess.run(input_batch, feed_dict={handle: training_handle})

    # If we use validation_handle, the input_batch tensor comes from the validation tfrecords
    validation_batch = sess.run(input_batch, feed_dict={handle: validation_handle})
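For completeness, here is a minimal sketch of how I'd drive this pattern in a full training loop. The step counts are arbitrary and the actual model/optimizer ops are omitted; it also assumes you build the validation pipeline without .repeat(), so that tf.errors.OutOfRangeError marks the end of a validation pass. This would run inside the same tf.Session block as above:

    # Arbitrary step counts; replace the "..." comments with your model ops.
    num_train_steps = 1000
    validate_every = 100

    sess.run(train_iter.initializer)
    for step in range(num_train_steps):
        # The handle routes input_batch to the training dataset.
        batch = sess.run(input_batch, feed_dict={handle: training_handle})
        # ... run your train_op on this batch here ...

        if (step + 1) % validate_every == 0:
            # Fresh pass over the validation set each time.
            sess.run(val_iter.initializer)
            while True:
                try:
                    val_batch = sess.run(input_batch,
                                         feed_dict={handle: validation_handle})
                    # ... accumulate validation metrics here ...
                except tf.errors.OutOfRangeError:
                    break  # Validation set exhausted; back to training.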
Hope it helps!
EDIT:
Your current error comes from calling sess.run() directly on a tf.data.Iterator object. Replace sess.run(train_iter) with sess.run(train_iter.initializer) (and do the same for the validation iterator). train_iter.initializer is the tf.Operation that actually initializes the train_iter iterator. Everything should work after that.
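In code, the two initialization lines become:

# Run the initializer ops, not the Iterator objects themselves.
sess.run(train_iter.initializer)
sess.run(val_iter.initializer)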
Upvotes: 7
Reputation: 1529
A slight modification was needed to get this working:
import glob
import tensorflow as tf

imglist = glob.glob('/var/temp/*.jpg')
train_imgs = imglist[0:100]
train_labels = [i for i in range(100)]
val_imgs = imglist[200:250]
val_labels = [i for i in range(50)]

training_ds = tf.data.Dataset.from_tensor_slices((train_imgs, train_labels)).batch(4)
validation_ds = tf.data.Dataset.from_tensor_slices((val_imgs, val_labels)).batch(1)

handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, training_ds.output_types, training_ds.output_shapes)
input_batch = iterator.get_next()

train_iter = training_ds.make_initializable_iterator()
val_iter = validation_ds.make_initializable_iterator()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # Get training and validation handles
    training_handle = sess.run(train_iter.string_handle())
    validation_handle = sess.run(val_iter.string_handle())

    sess.run(train_iter.initializer)
    # If we use training_handle, the input_batch tensor comes from the training dataset
    training_batch = sess.run(input_batch, feed_dict={handle: training_handle})
    print("Training...")
    print(training_batch)

    sess.run(val_iter.initializer)
    # If we use validation_handle, the input_batch tensor comes from the validation dataset
    print("Validation")
    validation_batch = sess.run(input_batch, feed_dict={handle: validation_handle})
    print(validation_batch)
Upvotes: 0
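Note that this version batches the raw filename strings and integer labels. If you want decoded image tensors as in my earlier attempt, the _parse map can be added back into each pipeline, along these lines:

# Sketch: reuse the _parse function defined in the question above, so that
# batches contain decoded, resized images rather than raw filename strings.
training_ds = (tf.data.Dataset.from_tensor_slices((train_imgs, train_labels))
               .map(_parse)
               .batch(4))
validation_ds = (tf.data.Dataset.from_tensor_slices((val_imgs, val_labels))
                 .map(_parse)
                 .batch(1))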