desert_ranger

Reputation: 1737

InvalidArgumentError: Input 1 of node while_1/Merge_1 was passed float from while_1/NextIteration_1:0

I have a TensorArray (`a`) that stores the values computed inside a `tf.while_loop`. However, I cannot convert the TensorArray to a NumPy array: there seems to be a mismatch between int32 and float32.

import time
import tensorflow as tf
import numpy as np


#Importing a generic dataset from Keras
# Only the training images (x_train) are used below; the labels and test split are unused.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    path='mnist.npz'
)


x_batch = tf.convert_to_tensor(x_train)
# NOTE(review): x_batch is already a tensor, so this presumably yields the same
# data -- the SSIM computed in `body` then compares the batch with itself.
s_pred_im = tf.convert_to_tensor(x_batch)
# Loop bound: an int32 scalar tensor (TensorFlow's default dtype for a Python int).
iters = tf.constant(10)
# One float32 slot per iteration; the size 10 must stay in sync with `iters`.
a = tf.TensorArray(tf.float32, size=10)

def cond(value, a, s_pred_im, x_batch, i, iters):
    """tf.while_loop predicate: keep looping while the counter ``i`` is below ``iters``.

    tf.while_loop passes every loop variable; only ``i`` and ``iters`` are used.
    """
    return tf.less(i, iters)

def body(value, a, s_pred_im, x_batch, i, iters):
    """One tf.while_loop iteration.

    Computes the summed SSIM between ``s_pred_im`` and ``x_batch``, writes it to
    slot ``i`` of TensorArray ``a``, and advances the counter. The incoming
    ``value`` is ignored and replaced by a float tensor, so the initial value
    given for this loop variable must also be a float (the int32/float error in
    the traceback comes from it being the int ``0``).

    Note: none of the SSIM inputs depend on ``i``, so every iteration writes the
    same number.
    """
    value = tf.math.reduce_sum(tf.image.ssim(s_pred_im, x_batch, max_val=255, filter_size = 28))
    # TensorArray.write returns the updated array; it must be returned so the
    # write carries into the next iteration.
    a = a.write(i,value)
    # Loop variables must come back in the same order (and dtypes) as passed in.
    return [value, a, s_pred_im, x_batch, tf.add(i,1), iters]

# BUG: the first loop variable starts as the Python int 0 (an int32 tensor),
# but `body` returns a float tensor in its place. tf.while_loop requires each
# loop variable to keep its dtype across iterations, hence the
# "passed float ... incompatible with expected int32" error. Starting it at
# 0.0 instead fixes this.
res = tf.while_loop(cond, body, [0, a, s_pred_im, x_batch, 0, iters])

# Stack the values written by each iteration (loop var 1 is the TensorArray).
b = res[1].stack()

with tf.Session() as sess:
    # NOTE(review): the evaluated array is discarded here; wrap it in print(...)
    # to actually see it.
    b.eval()

Running this gives the following error:

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)

~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in _do_call(self, fn, *args)
   1364     try:
-> 1365       return fn(*args)
   1366     except errors.OpError as e:

~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
   1347       # Ensure any changes to the graph are reflected in the runtime.
-> 1348       self._extend_graph()
   1349       return self._call_tf_sessionrun(options, feed_dict, fetch_list,

~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in _extend_graph(self)
   1387     with self._graph._session_run_lock():  # pylint: disable=protected-access
-> 1388       tf_session.ExtendSession(self._session)
   1389 

InvalidArgumentError: Input 1 of node while_1/Merge_1 was passed float from while_1/NextIteration_1:0 incompatible with expected int32.

During handling of the above exception, another exception occurred:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-72-5642d29d3bf6> in <module>
      1 with tf.Session() as sess:
----> 2     b.eval()

~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\framework\ops.py in eval(self, feed_dict, session)
    796 
    797     """
--> 798     return _eval_using_default_session(self, feed_dict, self.graph, session)
    799 
    800   def experimental_ref(self):

~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\framework\ops.py in _eval_using_default_session(tensors, feed_dict, graph, session)
   5405                        "the tensor's graph is different from the session's "
   5406                        "graph.")
-> 5407   return session.run(tensors, feed_dict)
   5408 
   5409 

~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
    954     try:
    955       result = self._run(None, fetches, feed_dict, options_ptr,
--> 956                          run_metadata_ptr)
    957       if run_metadata:
    958         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1178     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1179       results = self._do_run(handle, final_targets, final_fetches,
-> 1180                              feed_dict_tensor, options, run_metadata)
   1181     else:
   1182       results = []

~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1357     if handle is None:
   1358       return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1359                            run_metadata)
   1360     else:
   1361       return self._do_call(_prun_fn, handle, feeds, fetches)

~\anaconda3\envs\test\lib\site-packages\tensorflow_core\python\client\session.py in _do_call(self, fn, *args)
   1382                     '\nsession_config.graph_options.rewrite_options.'
   1383                     'disable_meta_optimizer = True')
-> 1384       raise type(e)(node_def, op, message)
   1385 
   1386   def _extend_graph(self):

InvalidArgumentError: Input 1 of node while_1/Merge_1 was passed float from while_1/NextIteration_1:0 incompatible with expected int32.

PS: This is an edited version of an earlier post, in which I was evaluating the TensorArray incorrectly.

Upvotes: 1

Views: 264

Answers (1)

user11530462

Reputation:

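This is a dtype mismatch in `tf.while_loop`. The first loop variable is initialized with the Python int `0`, which becomes an int32 tensor, but `body` returns a float32 SSIM value in its place. Every loop variable must keep the same dtype across iterations, so initializing it with `0.0` instead fixes the error. Take a look at the working code below.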

import time
import tensorflow as tf
import numpy as np
#tf.compat.v1.disable_eager_execution()


# Importing a generic dataset from Keras (only the training images are used).
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data(
    path='mnist.npz'
)

# Number of while-loop iterations. It also sizes the TensorArray, so the loop
# bound and the number of writable slots cannot drift apart.
num_iters = 10

x_batch = tf.convert_to_tensor(x_train)
# NOTE: x_batch is already a tensor, so this presumably yields the same data;
# the SSIM in `body` therefore compares the batch with itself (hence the
# all-ones output).
s_pred_im = tf.convert_to_tensor(x_batch)
# int32 scalar loop bound (TensorFlow's default dtype for a Python int).
iters = tf.constant(num_iters)
# One float32 slot per iteration, filled by `body`.
a = tf.TensorArray(tf.float32, size=num_iters)

def cond(value, a, s_pred_im, x_batch, i, iters):
    """Continue looping while the step counter ``i`` is still below ``iters``.

    tf.while_loop hands over every loop variable; only ``i`` and ``iters``
    take part in the decision.
    """
    keep_going = i < iters
    return keep_going

def body(value, a, s_pred_im, x_batch, i, iters):
    """Run one loop iteration and return the updated loop variables.

    The incoming ``value`` is not read. A fresh summed SSIM between
    ``s_pred_im`` and ``x_batch`` replaces it and is recorded at slot ``i`` of
    TensorArray ``a``; the counter then advances by one. The list is returned
    in the same order tf.while_loop passed the variables in.
    """
    ssim_scores = tf.image.ssim(s_pred_im, x_batch, max_val=255, filter_size=28)
    total = tf.math.reduce_sum(ssim_scores)
    updated_array = a.write(i, total)
    next_i = tf.add(i, 1)
    return [total, updated_array, s_pred_im, x_batch, next_i, iters]

# Loop variables must keep the same dtype on every iteration. `body` returns a
# float32 SSIM value for the first slot, so seed it as float32 explicitly; a
# bare Python `0` would become int32 and trigger the question's
# "expected int32" error.
initial_value = tf.constant(0.0, dtype=tf.float32)
# The counter is compared against `iters`, so give it the same dtype.
initial_i = tf.constant(0, dtype=iters.dtype)
res = tf.while_loop(cond, body, [initial_value, a, s_pred_im, x_batch, initial_i, iters])

# Stack the values written by each iteration (loop var 1 is the TensorArray).
b = res[1].stack()

# Fetch the result as a NumPy array. Under eager execution (the TF2 default)
# the tensor already holds its value and `.eval()` is not supported. In graph
# mode a Session is required, and the `with` block guarantees it is closed.
if tf.executing_eagerly():
    print(b.numpy())
else:
    with tf.compat.v1.Session() as sess:
        print(sess.run(b))

Output

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]

Upvotes: 1

Related Questions