Reputation: 31
I am new to TensorFlow. I am trying to run a pre-trained network for digit recognition, 'wide_resnet_28_10', from GitHub: https://github.com/Curt-Park/handwritten_digit_recognition. When I try to predict on an image, I get an error saying the input was expected to have 4 dimensions. This is what I tried:
from tensorflow.keras.models import load_model
import tensorflow as tf
import cv2
import numpy as np
model = load_model(r'C:\Users\sesha\Desktop\python\Deep learning NN\handwritten_digit_recognition-master\models\WideResNet28_10.h5')
image = cv2.imread(r'C:\Users\sesha\Desktop\python\Deep learning NN\test_org01.png')
img = tf.convert_to_tensor(image)
predictions = model.predict([img])
print(np.argmax(predictions))
Most tutorials are vague. I did try np.reshape(1,X,X,-1), but that didn't work.
Upvotes: 0
Views: 941
Reputation: 95
The model expects a batch of images, i.e. a 4D input of shape (batch, height, width, channels). You can turn a single image into a 4D tensor by adding a batch dimension:
predictions = model.predict(tf.expand_dims(img, 0))
If this does not work, try predict_on_batch instead of predict.
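To illustrate what adding the batch dimension does, here is a minimal sketch using NumPy only. The 28x28x3 array is a made-up stand-in for a loaded image, not the shape the model necessarily expects.

```python
import numpy as np

# Stand-in for a loaded image: height 28, width 28, 3 channels (assumed sizes).
img = np.zeros((28, 28, 3), dtype=np.uint8)

# Add a leading batch axis: (height, width, channels) -> (batch, height, width, channels).
batch = np.expand_dims(img, axis=0)

print(img.shape)    # (28, 28, 3)
print(batch.shape)  # (1, 28, 28, 3)

# Equivalent alternatives: img[np.newaxis, ...] or np.reshape(img, (1,) + img.shape)
```

Note that np.reshape takes the array as its first argument and the new shape as a tuple, so np.reshape(1, X, X, -1) from the question would fail regardless of the image; np.reshape(img, (1, X, X, -1)) is probably what was intended.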
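Also, I'm not sure your image loading gives the model what it expects. cv2.imread returns a uint8 array in BGR channel order (or None if the file could not be read), so the size and channel count may not match the model's input.
Reading and decoding the file with TensorFlow should work: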
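# img_path is the path to the image file; X and Y are the height and width the model expects
path = tf.constant(img_path)
image = tf.io.read_file(path)  # raw file contents as a byte string
image = tf.io.decode_image(image)  # decoded pixel tensor (height, width, channels)
image = tf.image.resize(image, (X, Y)) # if necessary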
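Putting the decode, resize and batch steps together, a rough end-to-end sketch could look like the following. It builds a dummy PNG in memory so it runs without a file on disk. The 28x28 single-channel target is only an assumption for a handwritten-digit model; check model.input_shape for the real values.

```python
import tensorflow as tf

# Build a small in-memory PNG so the sketch runs without a file on disk.
# With a real image, use: png_bytes = tf.io.read_file(img_path)
dummy = tf.zeros((64, 48, 3), dtype=tf.uint8)
png_bytes = tf.io.encode_png(dummy)

# Decode to a single-channel uint8 tensor (channels=1 is an assumption here).
# expand_animations=False guarantees a 3-D result even for GIF input.
image = tf.io.decode_image(png_bytes, channels=1, expand_animations=False)

# Resize to the assumed model input size; tf.image.resize returns float32.
image = tf.image.resize(image, (28, 28))

# Add the batch axis expected by model.predict.
batch = tf.expand_dims(image, 0)
print(batch.shape)  # (1, 28, 28, 1)
```

The resulting batch can then be passed to model.predict(batch). Whether the pixel values also need rescaling (for example to the range 0 to 1) depends on how the model was trained, which this sketch does not cover.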
Upvotes: 3