Alexander Soare

Why does tf.constant give a dtype error if we pass in a tensor?

The following code

import tensorflow as tf

a = tf.range(10)
b = tf.constant(a, dtype=tf.float32)

gives the following error:

TypeError: Expected tensor with type tf.float32 not tf.int32

However, according to the documentation, setting dtype should make tf.constant cast a to the specified data type, so I don't see why this raises a type error.

I also know that the following does not give an error:

import numpy as np
import tensorflow as tf

a = np.arange(10)
b = tf.constant(a, dtype=tf.float32)

So I'm mainly wondering what's happening under the hood here.


Answers (1)

thushv89

If you look at the TensorFlow source for tf.constant, you will see that EagerTensor inputs get special treatment: if the dtype of an EagerTensor doesn't match the requested dtype, a TypeError is raised instead of the value being cast.
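The check described above can be modelled in a few lines of plain Python. This is a simplified, hypothetical sketch: constant_sketch and FakeEagerTensor are made-up names, not TensorFlow APIs, and real TensorFlow handles many more dtypes and input kinds. It only mimics the control flow: an existing tensor is returned unchanged or rejected, while raw data is converted to the requested dtype.

```python
# Simplified, hypothetical model of the dtype check in tf.constant.
# None of these names are TensorFlow APIs; they only mimic the control flow.
from dataclasses import dataclass
from typing import Optional

# Assumption: only two dtypes, each mapped to a Python element converter.
_CASTERS = {"int32": int, "float32": float}


@dataclass(frozen=True)
class FakeEagerTensor:
    """Stand-in for an already-materialized tensor with a fixed dtype."""
    values: tuple
    dtype: str


def constant_sketch(value, dtype: Optional[str] = None) -> FakeEagerTensor:
    if isinstance(value, FakeEagerTensor):
        # An existing tensor is returned as-is: a mismatched dtype is an error, not a cast.
        if dtype is not None and value.dtype != dtype:
            raise TypeError(f"Expected tensor with type {dtype} not {value.dtype}")
        return value
    # Raw data (a NumPy array, in real TF) is converted element by element,
    # so the requested dtype is applied as a cast.
    values = list(value)
    if dtype is None:
        # Simplified inference: float32 if any float is present, else int32.
        dtype = "float32" if any(isinstance(v, float) for v in values) else "int32"
    cast = _CASTERS[dtype]
    return FakeEagerTensor(tuple(cast(v) for v in values), dtype)


a = constant_sketch(range(10))                          # like tf.range(10): int32
b = constant_sketch(list(range(10)), dtype="float32")   # raw data: converted, works
try:
    constant_sketch(a, dtype="float32")                 # existing tensor: rejected
except TypeError as err:
    print(err)  # Expected tensor with type float32 not int32
```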

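Here, tf.range() produces an EagerTensor, so it hits that check. A NumPy array is not an EagerTensor, so it is converted to the requested dtype instead, which is why your second example works. I'm not sure why EagerTensors get this special treatment, though; it could be a performance-related restriction.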
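If you just need a float32 tensor, here is a minimal sketch of two workarounds, assuming TensorFlow 2.x with eager execution enabled (the default): cast the tensor explicitly with tf.cast, or pass tf.constant a NumPy array via .numpy() so that dtype is applied as a conversion rather than checked.

```python
import tensorflow as tf

a = tf.range(10)  # int32 EagerTensor

# Option 1: cast the existing tensor explicitly (returns a new float32 tensor).
b = tf.cast(a, tf.float32)

# Option 2: give tf.constant a NumPy array instead of a tensor,
# so the requested dtype is applied as a conversion.
c = tf.constant(a.numpy(), dtype=tf.float32)

# Passing the tensor itself still raises, as in the question.
try:
    tf.constant(a, dtype=tf.float32)
except TypeError as err:
    print("TypeError:", err)

print(b.dtype, c.dtype)
```

tf.cast is usually the more direct choice, since it stays within TensorFlow and avoids a round trip through NumPy.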
