chanwcom

Reputation: 4650

Why is the type of tf.Variable([8.0], tf.float64) float32 rather than float64 in TensorFlow?

(Updated question) The original question was confusing, so here is a simpler way to ask it.

#!/usr/bin/python

import tensorflow as tf
x = tf.Variable([2], tf.float32)
print(x.dtype)

If we try the above code segment, then the output is as follows:

<dtype: 'int32_ref'>

Since I explicitly specified the type of x as tf.float32, I had thought that the type should be float32. However, it seems like the type is int32.

Could somebody answer this question?


(original question)

I tried the following code to replace one element of a 2-D TensorFlow variable.

#!/usr/bin/python

import tensorflow as tf
import numpy as np

ref = tf.Variable(np.arange(0, 12).reshape((4, 3)).astype(np.float64))
indices = tf.constant([[2, 2]])
updates = tf.Variable([8.0], tf.float64)
ref = tf.scatter_nd_update(ref, indices, updates)

with tf.Session() as sess:
  init = tf.global_variables_initializer()
  sess.run(init)
  print(sess.run(ref))

Strangely, I encountered the following type error:

TypeError: Input 'updates' of 'ScatterNdUpdate' Op has type float32 that does not match type float64 of argument 'ref'.

After changing tf.Variable([8.0], tf.float64) to the following line, it worked.

updates = tf.Variable(np.array([8.0]).astype(np.float64), tf.float64)

So, it seems like the type of tf.Variable([8.0], tf.float64) is not tf.float64, even though I explicitly specified the type as tf.float64. Could anyone tell me the reason? Thank you!

Upvotes: 1

Views: 1012

Answers (2)

Harsha Pokkalla

Reputation: 1802

Update:

Pass dtype as a keyword argument when calling tf.Variable().

Original Question:

Changing the type of ref to float32 works for me.

Use either

ref = tf.Variable(np.arange(0, 12).reshape((4, 3)).astype(np.float32))

or

ref = tf.Variable(np.arange(0, 12).reshape((4, 3)),dtype=tf.float32)

My guess is that the ScatterNdUpdate operation works only with float32, not float64.

Upvotes: 0

pfm

Reputation: 6328

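The reason is really simple: the second positional parameter of tf.Variable is trainable, not dtype, so your tf.float64 is interpreted as a true value for the trainable argument. The dtype is then inferred from the initial value, which is why [8.0] gives float32 and [2] gives int32. If you pass the type with the dtype keyword, it will work: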

    updates = tf.Variable([8.0], dtype=tf.float64)
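
To make the argument binding visible without running TensorFlow, here is a minimal pure-Python sketch. Note that `make_variable` and `FLOAT64` are hypothetical stand-ins, not TensorFlow APIs: the function only copies the parameter order of the TF 1.x tf.Variable constructor, which I am assuming matches your installed version.

```python
def make_variable(initial_value=None, trainable=True, collections=None,
                  validate_shape=True, caching_device=None, name=None,
                  variable_def=None, dtype=None):
    """Hypothetical stand-in that mirrors only the parameter order of
    tf.Variable in TF 1.x. It creates nothing; it just reports which
    parameter each argument was bound to."""
    return {"initial_value": initial_value,
            "trainable": trainable,
            "dtype": dtype}


# Placeholder for tf.float64 (assumption: any object would behave the same
# way here, because only the argument position matters).
FLOAT64 = "float64"

# Positional call: the second argument lands in `trainable`, dtype stays None.
positional = make_variable([8.0], FLOAT64)

# Keyword call: the type reaches `dtype`, trainable keeps its default.
keyword = make_variable([8.0], dtype=FLOAT64)

print(positional)
print(keyword)
```

In the positional call nothing ever reaches dtype, so the real constructor has to fall back to inferring the type from the initial value.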

Actually, there was a similar Q&A.

Upvotes: 3
