Manuel Schmidt

Reputation: 2489

Creating `input_fn` from iterator

Most tutorials focus on the case where the entire training dataset fits into memory. However, I have an iterator that acts as an infinite stream of (features, labels) tuples, creating them cheaply on the fly.

When implementing the `input_fn` for a TensorFlow Estimator, can I return the next element from the iterator, like this:

```python
import tensorflow as tf

# `it` is my infinite iterator of (features, labels) batches.
def input_fn():
    feature_batch, label_batch = next(it)
    return tf.constant(feature_batch), tf.constant(label_batch)
```

or does `input_fn` have to return the same (features, labels) tuple on each call?

Moreover, is this function called multiple times during training, as I hope, like in the following pseudocode?

```python
for i in range(max_iter):
    learn_op(input_fn())
```

Upvotes: 6

Views: 4426

Answers (3)

user3722836

Reputation: 141

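```python
# TensorFlow 1.x only: tf.contrib was removed in TensorFlow 2.
from tensorflow.contrib.learn.python.learn.learn_io import generator_io
import numpy as np

# Define a generator that yields one sample (a dict of arrays) at a time.
def generator():
    for index in range(2):
        yield {'a': np.ones(1) * index,
               'b': np.ones(1) * index + 32,
               'label': np.ones(1) * index - 32}

# The 'label' entry becomes the target; the remaining keys become features.
input_fn = generator_io.generator_input_fn(
    generator, target_key='label', batch_size=2, shuffle=False, num_epochs=1)
# features and target are queue-backed tensors: evaluate them in a session
# after starting the queue runners.
features, target = input_fn()
```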

Refer to the test case https://github.com/tensorflow/tensorflow/pull/7045/files
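The `tf.contrib` module used above no longer exists in TensorFlow 2. As a hedged alternative that is not part of this answer, roughly the same pipeline can be sketched with `tf.data.Dataset.from_generator`; the `output_signature` argument assumes TensorFlow 2.4 or later, and `split_label` is a hypothetical helper that mimics `target_key='label'`.

```python
import numpy as np
import tensorflow as tf

def generator():
    # Same toy generator as in the answer: two samples with keys 'a', 'b', 'label'.
    for index in range(2):
        yield {'a': np.ones(1) * index,
               'b': np.ones(1) * index + 32,
               'label': np.ones(1) * index - 32}

def split_label(sample):
    # Hypothetical helper mimicking target_key='label': pop the label off the feature dict.
    features = dict(sample)
    label = features.pop('label')
    return features, label

def input_fn():
    # np.ones(1) is float64 with shape (1,), so every entry gets that spec.
    signature = {key: tf.TensorSpec(shape=(1,), dtype=tf.float64)
                 for key in ('a', 'b', 'label')}
    dataset = tf.data.Dataset.from_generator(generator, output_signature=signature)
    return dataset.map(split_label).batch(2)

for features, label in input_fn():
    print(features['a'].numpy().ravel(), label.numpy().ravel())
```

The generator is consumed lazily as the dataset is iterated, so each step receives a new batch, and an infinite generator simply gives an infinite dataset.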

Upvotes: 0

Manuel Schmidt

Reputation: 2489

I found a pull request that converts a generator into an `input_fn`: https://github.com/tensorflow/tensorflow/pull/7045/files

The relevant part is:

```python
  # Excerpt: x, queue_capacity, shuffle, num_threads, batch_size, num_epochs,
  # target_key and input_keys come from the enclosing generator_input_fn.
  def _generator_input_fn():
    """generator input function."""
    queue = feeding_functions.enqueue_data(
      x,
      queue_capacity,
      shuffle=shuffle,
      num_threads=num_threads,
      enqueue_size=batch_size,
      num_epochs=num_epochs)

    features = (queue.dequeue_many(batch_size) if num_epochs is None
                else queue.dequeue_up_to(batch_size))
    if not isinstance(features, list):
      features = [features]
    features = dict(zip(input_keys, features))
    if target_key is not None:
      if len(target_key) > 1:
        target = {key: features.pop(key) for key in target_key}
      else:
        target = features.pop(target_key[0])
      return features, target
    return features
  return _generator_input_fn
```
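To make the post-processing in the excerpt concrete, here is a pure-Python sketch, with no queues or threads, of what `_generator_input_fn` does with each batch: stack samples into per-key arrays and split off the target key(s). `batch_and_split` and `toy_generator` are hypothetical names, the real code batches through a TensorFlow queue rather than `itertools.islice`, and accepting a plain string for `target_key` is an assumption about how the enclosing function normalizes it.

```python
import itertools

import numpy as np

def batch_and_split(generator_fn, batch_size, target_key=None):
    """Hypothetical pure-Python analogue of _generator_input_fn (no queues or threads).

    Yields (features, target) per batch, or only features when target_key is None.
    """
    if isinstance(target_key, str):
        target_key = [target_key]  # assumption: a single key is normalised to a list
    samples = generator_fn()
    while True:
        chunk = list(itertools.islice(samples, batch_size))
        if not chunk:
            return
        # Like dequeue_up_to (the num_epochs-is-not-None case), the last batch may be short.
        features = {key: np.stack([s[key] for s in chunk]) for key in chunk[0]}
        if target_key is None:
            yield features
            continue
        if len(target_key) > 1:
            target = {key: features.pop(key) for key in target_key}
        else:
            target = features.pop(target_key[0])
        yield features, target

def toy_generator():
    for index in range(2):
        yield {'a': np.ones(1) * index,
               'b': np.ones(1) * index + 32,
               'label': np.ones(1) * index - 32}

for features, target in batch_and_split(toy_generator, batch_size=2, target_key='label'):
    print(features['a'].tolist(), target.tolist())  # [[0.0], [1.0]] [[-32.0], [-31.0]]
```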

Upvotes: 2

P-Gn

Reputation: 24591

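The tensors returned by `input_fn` are used throughout training, but the function itself is called only once, when the graph is built. So writing an `input_fn` that does more than return a constant array, as in the tutorial, is not as straightforward: with your code, `next(it)` would run only once, and the same batch would be reused at every step.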
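A minimal sketch of that consequence for the `input_fn` in the question, assuming TensorFlow 2 with the `tf.compat.v1` API and using a plain graph and session as a rough stand-in for what an Estimator does internally: `input_fn` runs once while the graph is built, `tf.constant` freezes that single batch into the graph, and every later step sees the same values. The iterator `it` is a hypothetical stand-in for the one in the question, and only features are returned, for brevity.

```python
import itertools

import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the question's infinite iterator: batch i is filled with i.
it = (np.full((2, 1), i, dtype=np.float32) for i in itertools.count())

calls = 0

def input_fn():
    global calls
    calls += 1
    # next(it) runs here, once, while the graph is built; the batch becomes a constant.
    return tf.constant(next(it))

graph = tf.Graph()
with graph.as_default():
    batch = input_fn()  # graph construction, roughly what an Estimator's train() does first

with tf.compat.v1.Session(graph=graph) as sess:
    # Three "training steps": each one evaluates the same frozen tensor.
    seen = [sess.run(batch) for _ in range(3)]

print(calls)                            # 1
print([float(b[0, 0]) for b in seen])   # [0.0, 0.0, 0.0]
```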

TensorFlow provides two examples of such non-trivial `input_fn`s, for NumPy arrays and pandas DataFrames, but they start from data already in memory, so they do not help with your problem.

You could also look at their code, by following the links above, to see how they implement an efficient non-trivial `input_fn`, but you may find that it requires more code than you would like.

If you are willing to use TensorFlow's lower-level interface, things are IMHO simpler and more flexible. There is a tutorial that covers most needs, and the solutions it proposes are easier to implement.

In particular, if you already have an iterator that returns data as described in your question, using placeholders (the "Feeding" section of that tutorial) should be straightforward.
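A minimal sketch of the feeding approach, assuming TensorFlow 2 with the `tf.compat.v1` API (in TensorFlow 1.x the same calls exist without the `compat.v1` prefix). `make_batches` is a hypothetical infinite iterator yielding (features, labels) batches with labels = 3 * features, and the one-weight linear model exists only to have something to train.

```python
import numpy as np
import tensorflow as tf

def make_batches(batch_size=32, seed=0):
    """Hypothetical infinite stream of (features, labels) batches with labels = 3 * features."""
    rng = np.random.default_rng(seed)
    while True:
        x = rng.uniform(-1.0, 1.0, size=(batch_size, 1)).astype(np.float32)
        yield x, 3.0 * x

it = make_batches()

graph = tf.Graph()
with graph.as_default():
    # Placeholders are the entry points that each training step feeds.
    x_ph = tf.compat.v1.placeholder(tf.float32, shape=[None, 1])
    y_ph = tf.compat.v1.placeholder(tf.float32, shape=[None, 1])
    w = tf.compat.v1.get_variable(
        "w", shape=[1, 1], initializer=tf.compat.v1.zeros_initializer())
    loss = tf.reduce_mean(tf.square(tf.matmul(x_ph, w) - y_ph))
    train_op = tf.compat.v1.train.GradientDescentOptimizer(0.5).minimize(loss)
    init = tf.compat.v1.global_variables_initializer()

with tf.compat.v1.Session(graph=graph) as sess:
    sess.run(init)
    for _ in range(200):
        features, labels = next(it)  # a fresh batch from the iterator at every step
        sess.run(train_op, feed_dict={x_ph: features, y_ph: labels})
    learned_w = sess.run(w)[0, 0]

print(learned_w)  # close to 3.0
```

Unlike the `input_fn` approach, `next(it)` here runs in the Python training loop, so every `sess.run` call is fed a fresh batch, which is the per-step behavior the question's pseudocode hoped for.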

Upvotes: 3
