Reputation: 11
I have recently been working on a project that uses a TensorFlow CNN on the MNIST dataset, with a server interface.
In the prediction step, I use tf.argmax() to get the index of the largest logit, which should be the predicted digit. However, the value it returns does not seem to be the correct answer.
The predict function looks roughly like this:
self.img = tf.reshape(tf.image.convert_image_dtype(img, tf.float32), shape=[1, 28, 28, 1])
self._create_model()
saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state('../checkpoints/')
if ckpt and ckpt.model_checkpoint_path:
    saver.restore(sess, ckpt.model_checkpoint_path)
pred = tf.nn.softmax(self.logits)
prediction = tf.argmax(pred, 1)
logit = sess.run(pred)
result = sess.run(prediction)[0]
print(logit)
print(result)
return result
And the results are:
127.0.0.1 - - [19/Apr/2018 21:35:47] "POST /index.html HTTP/1.1" 200 -
[[ 0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]]
1
As you can see, the softmax output has its maximum at index 5, but tf.argmax() gave me 1 instead.
By the way, my model is the basic MNIST CNN model, as you can see in the link.
So is something wrong with tf.argmax(), or is there a mistake in my code?
Upvotes: 1
Views: 383
Reputation: 5722
Since your logit (pred) and result (prediction[0]) come from two different sess.run calls, I wonder whether something differs between the runs. For example, you might have an iterator in the graph feeding inputs to the model. With separate runs, the iterator sends in different data each time, which leads to different predictions. It would be interesting to see what happens if you put pred and prediction in the same sess.run, like this:
logit, result = sess.run((pred, prediction))
print(logit)
print(result[0])
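To illustrate why two separate runs could disagree, here is a toy sketch in plain Python. It is not TensorFlow code: ToySession, one_hot, and the two inputs are made up for illustration, and the only assumption is that every run() call pulls a fresh input, the way an iterator in the graph would.

```python
# Toy stand-in for a session whose input comes from an iterator.
# ToySession is hypothetical; it only mimics "one run() = one new input".
class ToySession:
    def __init__(self, inputs):
        self._inputs = iter(inputs)

    def run(self, fetches):
        scores = next(self._inputs)  # every run() consumes the next input
        values = {
            "pred": scores,
            "prediction": scores.index(max(scores)),  # plain-Python argmax
        }
        if isinstance(fetches, tuple):
            return tuple(values[name] for name in fetches)
        return values[fetches]


def one_hot(index, size=10):
    """Return a list of zeros with a single 1.0 at `index`."""
    return [1.0 if i == index else 0.0 for i in range(size)]


# Assumed inputs: the first scores highest at 5, the second at 1.
inputs = [one_hot(5), one_hot(1)]

# Two separate runs: each fetch sees a different input.
sess = ToySession(inputs)
logit = sess.run("pred")
result = sess.run("prediction")
print(logit.index(max(logit)), result)  # 5 1

# One run: both fetches are computed from the same input.
sess = ToySession(inputs)
logit_same, result_same = sess.run(("pred", "prediction"))
print(logit_same.index(max(logit_same)), result_same)  # 5 5
```

If your real graph behaves like this, fetching both tensors in a single sess.run should make the printed probabilities and the argmax agree. If they still disagree, the cause is somewhere else.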
Upvotes: 1