Reputation: 4650
(Updated question) The original question was confusing, so here is a simpler way to ask it.
#!/usr/bin/python
import tensorflow as tf
x = tf.Variable([2], tf.float32)
print x.dtype
If we try the above code segment, then the output is as follows:
<dtype: 'int32_ref'>
Since I explicitly specified the type of x as tf.float32, I expected its dtype to be float32. However, the dtype is int32.
Could somebody explain why?
(original question)
I tried the following code to replace one element of a 2-D TensorFlow variable.
#!/usr/bin/python
import tensorflow as tf
import numpy as np
ref = tf.Variable(np.arange(0, 12).reshape((4, 3)).astype(np.float64))
indices = tf.constant([[2, 2]])
updates = tf.Variable([8.0], tf.float64)
ref = tf.scatter_nd_update(ref, indices, updates)
with tf.Session() as sess:
    init = tf.global_variables_initializer()
    sess.run(init)
    print sess.run(ref)
Strangely, I encountered the following type error:
TypeError: Input 'updates' of 'ScatterNdUpdate' Op has type float32 that does not match type float64 of argument 'ref'.
After changing tf.Variable([8.0], tf.float64) to the following line, it worked.
updates = tf.Variable(np.array([8.0]).astype(np.float64), tf.float64)
So it seems that the dtype of tf.Variable([8.0], tf.float64) is not tf.float64, even though I explicitly specified it as tf.float64. Could anyone tell me the reason? Thank you!
Upvotes: 1
Views: 1012
Reputation: 1802
Update:
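Pass dtype as a keyword argument to tf.Variable(), e.g. tf.Variable([2], dtype=tf.float32). The second positional parameter of tf.Variable is trainable, not dtype, so a type passed positionally does not set the dtype. The dtype is then inferred from the initial value: [2] gives int32 and [8.0] gives float32.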
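A minimal sketch of the keyword form, assuming a TensorFlow install where base_dtype is available on the variable's dtype (it strips the "_ref" suffix that TF 1.x prints, so the check should behave the same on TF 1.x and TF 2.x):

```python
import tensorflow as tf

# dtype is passed by keyword, so the integer literal is converted to float32
# instead of the dtype being inferred as int32 from the value.
x = tf.Variable([2], dtype=tf.float32)

# Compare via base_dtype so a TF 1.x 'float32_ref' also matches tf.float32.
is_float32 = x.dtype.base_dtype == tf.float32
print(is_float32)
```

The same keyword form applies to the updates variable in the original question: tf.Variable([8.0], dtype=tf.float64).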
Original Question:
Changing the dtype of ref to float32 works for me.
Do
ref = tf.Variable(np.arange(0, 12).reshape((4, 3)).astype(np.float32))
or
ref = tf.Variable(np.arange(0, 12).reshape((4, 3)), dtype=tf.float32)
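This works because ref and updates then have the same dtype (float32), which is the mismatch the TypeError reports. ScatterNdUpdate is not limited to float32: as the question shows, float64 also works once updates really is float64.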
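For the float64 case, here is a hedged sketch with both dtypes made explicit. It assumes TensorFlow 1.14+ or 2.x, where the tensorflow.compat.v1 module and tf.disable_v2_behavior() are available. The value -1.0 is used instead of the question's 8.0 only because element [2, 2] of the original array is already 8, so the update would otherwise be invisible; update_op and result are illustrative names.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.14+ or TF 2.x

tf.disable_v2_behavior()  # use graph mode and Session, as in the question

# Both ref and updates are explicitly float64, so their dtypes match.
ref = tf.Variable(np.arange(0, 12, dtype=np.float64).reshape((4, 3)),
                  dtype=tf.float64)
indices = tf.constant([[2, 2]])
updates = tf.constant([-1.0], dtype=tf.float64)
update_op = tf.scatter_nd_update(ref, indices, updates)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    result = sess.run(update_op)  # value of ref after the update

print(result)
```

Here updates is a constant rather than a variable, which is enough for a one-off update and avoids the positional-argument pitfall altogether.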
Upvotes: 0