Mateusz Kalinowski

Reputation: 33

tensorflow dataset tf.estimator.inputs.numpy_input_fn

I'm writing code that reads images and labels from disk in TensorFlow and then calls tf.estimator.inputs.numpy_input_fn. How can I pass the whole dataset instead of a single image? My code looks like this:

filenames = tf.constant(filenames)
labels = tf.constant(labels)

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_parse_function)
dataset_batched = dataset.batch(10)
iterator = dataset_batched.make_one_shot_iterator()
features, labels = iterator.get_next()

with tf.Session() as sess:

  print(dataset_batched)
  print(np.shape(sess.run(features)))
  print(np.shape(sess.run(labels)))

  mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_mk, model_dir=dir)
  train_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": np.array(sess.run(features))},
                                                  y=np.array(sess.run(labels)),
                                                  batch_size=1,
                                                  num_epochs=None,
                                                  shuffle=False)
  mnist_classifier.train(input_fn=train_input_fn, steps=1)

My question is: how can I pass the dataset here, in x={"x": np.array(sess.run(features))}?

Upvotes: 1

Views: 2948

Answers (1)

xdurch0

Reputation: 10475

There is no need for numpy_input_fn here. Wrap the code at the top (everything up to and including the iterator) in a function (say, my_input_fn) that returns iterator.get_next(), and then pass input_fn=my_input_fn to the train call. This passes the full dataset to the training code in batches of 10.
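As a rough sketch of that suggestion, the dataset code from the question can be moved into an input function like the one below. Several things here are assumptions rather than part of the question: the body of _parse_function (the question does not show it), the make_input_fn helper and its parse_fn and batch_size parameters, and the {"x": ...} feature key, which is guessed from the numpy_input_fn call. It uses tf.compat.v1.data.make_one_shot_iterator so that it should also work on TensorFlow versions where dataset.make_one_shot_iterator() is no longer available; on older 1.x releases the method form from the question does the same job.

```python
import tensorflow as tf


def _parse_function(filename, label):
    """Hypothetical stand-in for the question's _parse_function (not shown there).

    Assumes all images are JPEGs of the same size, so they can be batched.
    """
    image = tf.io.read_file(filename)
    image = tf.image.decode_jpeg(image, channels=1)
    image = tf.cast(image, tf.float32) / 255.0
    # Assumption: the model_fn reads its input from features["x"].
    return {"x": image}, label


def make_input_fn(filenames, labels, parse_fn=_parse_function, batch_size=10):
    """Build a zero-argument input function for Estimator.train (illustrative helper)."""

    def my_input_fn():
        # The dataset is created inside the function, so its ops are added
        # to whatever graph is current when the function is called.
        dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
        dataset = dataset.map(parse_fn)
        dataset = dataset.batch(batch_size)
        iterator = tf.compat.v1.data.make_one_shot_iterator(dataset)
        # Returns a (features, labels) pair of tensors for one batch.
        return iterator.get_next()

    return my_input_fn


# Usage with the names from the question (left as comments, since
# cnn_model_mk, filenames and labels are defined elsewhere):
#
#   mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_mk, model_dir=dir)
#   mnist_classifier.train(input_fn=make_input_fn(filenames, labels), steps=1)
```

No tf.Session or sess.run is needed on the caller's side with this approach, because the Estimator invokes the input function and runs the resulting tensors itself. If training should run for more than one pass over the files, adding dataset.repeat() before batching is one option.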

numpy_input_fn is for when you have the full dataset available in an array already and want a quick way to do batching/shuffling/repeating etc.

Upvotes: 7
