I'm trying to visualize the output of a convolutional autoencoder using the TensorFlow Estimator API:
import tensorflow as tf

input_fn = tf.estimator.inputs.numpy_input_fn(
    x={"x": train_data},
    y=None,
    batch_size=8,
    num_epochs=None,
    shuffle=True)
autoencoder = tf.estimator.Estimator(model_fn=autoencoder_model_fn, model_dir=model_dir)
tensors_to_log = {"loss": "loss"}
logging_hook = tf.train.LoggingTensorHook(tensors=tensors_to_log, every_n_iter=1000)
autoencoder.train(
    input_fn=input_fn,
    steps=50000,
    hooks=[logging_hook])
input_fn_predict = tf.estimator.inputs.numpy_input_fn(
    x={"x": train_data},
    y=None,
    batch_size=1,
    num_epochs=None,
    shuffle=False)
predictions = autoencoder.predict(input_fn=input_fn_predict)
predictions = [p['decoded_image'] for p in predictions]
print(predictions[0].shape)
I get the following error:
Traceback (most recent call last):
  File "AutoEncoder.py", line 164, in <module>
    main()
  File "AutoEncoder.py", line 157, in main
    predictions = [p['decoded_image'] for p in predictions]
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 425, in predict
    for i in range(self._extract_batch_length(preds_evaluated)):
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/estimator/estimator.py", line 592, in _extract_batch_length
    'different batch length then others.' % key)
ValueError: Batch length of predictions should be same. features has different batch length then others.
Can anybody see what I did wrong? As I understand it, my batch size during prediction is a constant, equal to 1... Thanks in advance!
I managed to solve this issue by updating TensorFlow. I had version 1.6, and you need 1.7 or later to handle prediction outputs with different batch sizes.
To check your TensorFlow version, run this in a terminal:
$ python
>>> import tensorflow as tf
>>> print(tf.__version__)
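If you want the script itself to verify this before calling predict(), a small helper can compare the version numerically. This is a minimal sketch: the function name supports_yield_single_examples is hypothetical, and the 1.7 threshold is the one given in this answer. Comparing integers instead of raw strings matters, because as strings "1.10" sorts before "1.7".

```python
def supports_yield_single_examples(version):
    """Return True if a TensorFlow version string such as "1.7.0" is 1.7 or
    newer (the threshold this answer gives for yield_single_examples).

    Compares (major, minor) as integers, because plain string comparison
    would wrongly treat "1.10.0" as older than "1.7.0".
    """
    major, minor = version.split(".")[:2]
    return (int(major), int(minor)) >= (1, 7)


# Usage (requires TensorFlow to be installed):
#   import tensorflow as tf
#   if not supports_yield_single_examples(tf.__version__):
#       raise RuntimeError("Upgrade TensorFlow to 1.7 or later")
```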
Once you have version 1.7 or later, pass this argument to predict():
yield_single_examples=False
(It defaults to True, which gives you the same error.)
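To see why the default fails: with yield_single_examples=True, predict() slices every array in the predictions dict along its first axis to produce single examples, so all of those first dimensions must match. The sketch below is a simplified, pure-NumPy imitation of that check, not TensorFlow's source; the function name and the example shapes (a 784x1 "features" array next to a 1x28x28x1 "decoded_image") are hypothetical.

```python
import numpy as np


def extract_batch_length(preds):
    """Simplified imitation (not TensorFlow's actual code) of the check
    Estimator.predict runs when yield_single_examples=True.

    Every array in `preds` must share the same first (batch) dimension,
    because predict() slices them row by row into single examples.
    """
    batch_length = None
    for key, value in preds.items():
        if batch_length is None:
            batch_length = value.shape[0]
        elif value.shape[0] != batch_length:
            raise ValueError(
                "Batch length of predictions should be same. "
                "%s has different batch length than others." % key)
    return batch_length


# Hypothetical shapes: one decoded image per batch, but a "features"
# entry whose first dimension is not the batch size.
preds = {
    "decoded_image": np.zeros((1, 28, 28, 1)),
    "features": np.zeros((784, 1)),
}
try:
    extract_batch_length(preds)
except ValueError as err:
    print(err)
```

Passing yield_single_examples=False skips this slicing entirely, which is why it avoids the error.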
In your code, that becomes:
predictions = autoencoder.predict(input_fn=input_fn_predict, yield_single_examples=False)
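With yield_single_examples=False, each item yielded by predict() is a dict of whole batches, so here p['decoded_image'] has shape (1, height, width, channels) rather than (height, width, channels). Note also that input_fn_predict uses num_epochs=None, which makes numpy_input_fn repeat the data forever; since predict() only stops when the input is exhausted, the list comprehension would never finish. Setting num_epochs=1 stops it after one pass over train_data. The sketch below joins the per-batch arrays back into one array; the helper name collect_predictions is hypothetical, and it is exercised on fake batches rather than a real Estimator.

```python
import numpy as np


def collect_predictions(batched_predictions, key="decoded_image"):
    """Join per-batch arrays from predict(..., yield_single_examples=False).

    `batched_predictions` is any iterable of dicts mapping prediction names
    to NumPy arrays whose first dimension is the batch size. Returns one
    array with all examples stacked along axis 0. Raises ValueError if the
    iterable is empty (np.concatenate needs at least one array).
    """
    return np.concatenate([p[key] for p in batched_predictions], axis=0)


# Hypothetical usage with the question's Estimator (TensorFlow >= 1.7),
# assuming input_fn_predict is rebuilt with num_epochs=1 so predict() ends:
#
#   predictions = autoencoder.predict(input_fn=input_fn_predict,
#                                     yield_single_examples=False)
#   decoded = collect_predictions(predictions)
#   print(decoded.shape)  # first dimension == len(train_data)
```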