DataOrc

Reputation: 839

Why does a Keras model object's predict method not allow batch size of 1?

I've trained a fine-tuned ELMo model using Keras that will only predict with a batch_size of 2. Here's some example code:

model_input = np.repeat(np.array([str(user_input)]), 2)
model.predict(model_input, batch_size=2)

This code runs perfectly fine. However, if I run this:

model_input = np.array([str(user_input)])
model.predict(model_input, batch_size=1)

I get this error:

Traceback (most recent call last):
  File "nlu/nlu_classifiers/elmo_scratch.py", line 67, in <module>
    main()
  File "nlu/nlu_classifiers/elmo_scratch.py", line 61, in main
    model.predict(model_input, batch_size=1)
  File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/keras/engine/training.py", line 1169, in predict
    steps=steps)
  File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/keras/engine/training_arrays.py", line 294, in predict_loop
    batch_outs = f(ins_batch)
  File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2715, in __call__
    return self._call(inputs)
  File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py", line 2675, in _call
    fetched = self._callable_fn(*array_vals)
  File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/tensorflow/python/client/session.py", line 1439, in __call__
    run_metadata_ptr)
  File "/Users/mjs/anaconda3/envs/nlucp36/lib/python3.6/site-packages/tensorflow/python/framework/errors_impl.py", line 528, in __exit__
    c_api.TF_GetCode(self.status.status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: input must be a vector, got shape: []
         [[{{node lambda_1/module_apply_default/StringSplit}}]]

Why is this? And is there a way to predict on a single example without having to use np.repeat? It's not a huge problem because it's basically the same speed, but it's been annoying me for a little while.
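For reference, the two inputs differ in length but not in rank: both are 1D NumPy arrays. The error reports a scalar (shape []) reaching StringSplit. The model code isn't shown, so the sketch below rests on two assumptions. First, Keras expands a 1D input to shape (batch, 1) for a string Input of shape (1,). Second, the Lambda layer calls tf.squeeze on that input with no axis, as several public Keras ELMo examples do. It uses np.squeeze as a stand-in for tf.squeeze, since both drop every size-1 axis when no axis is given. The string "hello" is a placeholder for user_input.

```python
import numpy as np

user_input = "hello"  # placeholder for the real user input

# The working call: np.repeat duplicates the single string, giving shape (2,)
repeated = np.repeat(np.array([str(user_input)]), 2)
# The failing call: one string, shape (1,)
single = np.array([str(user_input)])

print(repeated.shape)  # (2,)
print(single.shape)    # (1,)

# Assumption: each batch reaches the model with shape (batch, 1) and is then
# squeezed with no axis. np.squeeze mimics tf.squeeze here.
batch_of_two = repeated.reshape(-1, 1)  # (2, 1)
batch_of_one = single.reshape(-1, 1)    # (1, 1)

print(np.squeeze(batch_of_two).shape)  # (2,) -> a vector, which StringSplit accepts
print(np.squeeze(batch_of_one).shape)  # ()   -> a scalar, matching "got shape: []"
```

Under those assumptions, a batch of two survives the squeeze as a vector. A batch of one collapses to a scalar, which would explain why only batch_size=1 fails.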

Upvotes: 1

Views: 386

Answers (1)

rob

Reputation: 18513

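np.repeat(np.array([str(user_input)]), 2) gives a 1D array of two strings (shape (2,)), and np.array([str(user_input)]) is a 1D array of one string (shape (1,)). The error says the StringSplit op received a scalar (shape []) instead of a vector, which suggests the single-example batch loses a dimension somewhere inside the model. Try passing an explicitly 2D array of shape (1, 1) instead: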

model_input = np.array([[str(user_input)]])
model.predict(model_input, batch_size=1)
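The input above has shape (1, 1). Whether it reaches StringSplit as a vector depends on how the model reduces that trailing axis, and the model definition isn't shown. If the Lambda layer squeezes its input (an assumption), a squeeze restricted to the feature axis keeps the batch dimension for any batch size. In the model that would mean tf.squeeze(..., axis=1) inside the Lambda, which is a hypothetical change to code we can't see. The sketch below shows only the shape arithmetic, with NumPy standing in for TensorFlow. The string "hello" is a placeholder for user_input.

```python
import numpy as np

user_input = "hello"  # placeholder for the real user input

# The suggested input: one example, one feature -> shape (1, 1)
model_input = np.array([[str(user_input)]])
print(model_input.shape)  # (1, 1)

# Squeezing every size-1 axis also removes the batch axis for a single example.
print(np.squeeze(model_input).shape)          # ()
# Squeezing only the feature axis (axis=1) keeps a vector of length batch_size.
print(np.squeeze(model_input, axis=1).shape)  # (1,)

# The axis-restricted squeeze behaves the same way for larger batches.
batch = np.array([["first"], ["second"]])
print(np.squeeze(batch, axis=1).shape)        # (2,)
```

Naming the axis makes the reduction independent of the batch size, so a single example and a batch of many both stay 1D.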

Upvotes: 1
