Reputation: 785
I followed this tutorial https://www.tensorflow.org/tutorials/layers and trained a model to recognize handwritten digits from the MNIST dataset.
The following code works as expected and prints the predicted class and probabilities for each image in the set:
import tensorflow as tf

mnist = tf.contrib.learn.datasets.load_dataset("mnist")
train_data = mnist.train.images  # Returns np.array of shape (55000, 784)

tf.reset_default_graph()

with tf.Session() as sess:
    # cnn_model_fn is the model function defined in the linked tutorial
    mnist_classifier = tf.estimator.Estimator(model_fn=cnn_model_fn, model_dir="model/")
    pred = mnist_classifier.predict(input_fn=tf.estimator.inputs.numpy_input_fn(
        x={"x": train_data},
        shuffle=False))
    for p in pred:
        print(p)
However, when I instead try to predict for just one image with
mnist_classifier.predict(input_fn=tf.estimator.inputs.numpy_input_fn(
    x={"x": train_data[0]},
    shuffle=False))
my program fails and TensorFlow reports:
InvalidArgumentError: Input to reshape is a tensor with 128 values,
but the requested shape requires a multiple of 784
This puzzles me because when I print the length of the first image from the set, it reports 784:
print("length of input: {}".format(len(train_data[0])))
How do I get the predictions for just one image?
Upvotes: 2
Views: 2113
Reputation: 646
You can also use tf.expand_dims. The documentation says:

This operation is useful if you want to add a batch dimension to a single element. For example, if you have a single image of shape [height, width, channels], you can make it a batch of one image with expand_dims(image, 0), which will make the shape [1, height, width, channels].
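Applied to this question, a minimal sketch (assuming the TF 1.x graph API used in the question; the zeros array is stand-in data for a real 784-pixel image):

import numpy as np
import tensorflow as tf

# A single flat MNIST image, as in the question (shape: (784,))
single_image = np.zeros(784, dtype=np.float32)

# Add a leading batch dimension: (784,) -> (1, 784)
batched = tf.expand_dims(single_image, 0)

with tf.Session() as sess:
    print(sess.run(batched).shape)  # (1, 784)

Note that numpy_input_fn expects plain NumPy arrays rather than tensors, so np.expand_dims(single_image, 0) achieves the same shape change without needing a session.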
Upvotes: 0
Reputation: 10474
This is most likely because you are dropping the batch dimension when creating the single-item dataset. What I mean is that you should use
mnist_classifier.predict(input_fn=tf.estimator.inputs.numpy_input_fn(
    x={"x": np.array([train_data[0]])},
    shuffle=False))
instead. Note the additional list wrapped around train_data[0]. This produces an array of shape [1, 784] and hence a dataset with one element, that element being a vector of 784 values. As your code stands, you are instead creating a dataset with 784 elements, each of which is a single number; numpy_input_fn then serves these up in its default batches of 128 scalars, which is exactly the "tensor with 128 values" that cannot be reshaped to a multiple of 784 in your error message.
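A quick shape sanity check illustrates the difference (a sketch, assuming train_data is the array loaded in the question):

import numpy as np

print(train_data[0].shape)              # (784,)  -- no batch dimension
print(np.array([train_data[0]]).shape)  # (1, 784) -- a batch of one
# Slicing keeps the batch dimension without copying:
print(train_data[0:1].shape)            # (1, 784)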
Upvotes: 1