Reputation: 2074
How to check if a tf.Tensor is mutable?
I want to assert that the arguments of a function have the correct types.
A tf.Tensor can be mutable:
import tensorflow as tf
import numpy as np
x = tf.get_variable('x', shape=(2,), dtype=np.float32)
print(x[1]) # x[1] is a tf.Tensor
tf.assign(x[1], 1.0)
Upvotes: 2
Views: 1787
Reputation: 631
You can check their dtype attribute, e.g. assert my_tensor.dtype == tf.float32.
Tensors are immutable outside of variables: they describe relationships between quantities. A tensor's dtype never changes; to get a different dtype you have to add a cast operation (e.g. tf.cast) to the graph, which produces a new tensor. If a value whose type differs from the expected type is fed to a tensor, e.g. when loading data into a pipeline, an error is raised. You can see this by assigning a value of the wrong type; TensorFlow raises an error.
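Try this code:
import tensorflow as tf
# A float32 variable; x[1] is a scalar float32 slice that supports assignment
x = tf.get_variable('x', shape=(2,), dtype=tf.float32)
# Assigning an int32 value to the float32 slice fails when the op is built
tf.assign(x[1], tf.ones(shape=(), dtype=tf.int32))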
You should get an error to the effect of "TypeError: Input 'value' of 'StridedSliceAssign' Op has type int32 that does not match type float32 of argument 'ref'."
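To make the dtype assertion reusable across a function's arguments, a small helper could look like the sketch below. The names check_dtype and scale are hypothetical. Comparing base_dtype is my own choice rather than something from the answer: a TF1 ref-typed value reports a dtype such as float32_ref, which does not compare equal to tf.float32, while its base_dtype does.

```python
import tensorflow as tf


def check_dtype(value, expected, name="argument"):
    """Raise TypeError unless `value` has the expected dtype (hypothetical helper).

    `expected` may be a tf.DType, a NumPy dtype or a string such as "float32";
    tf.as_dtype normalizes all of them. Comparing base_dtype makes a TF1
    ref-typed value (e.g. float32_ref) count as its plain dtype.
    """
    expected = tf.as_dtype(expected)
    if value.dtype.base_dtype != expected:
        raise TypeError(
            f"{name} must have dtype {expected.name}, got {value.dtype.name}"
        )
    return value


def scale(x, factor):
    # Hypothetical function whose arguments we want to validate up front.
    check_dtype(x, tf.float32, "x")
    check_dtype(factor, tf.float32, "factor")
    return x * factor
```

Calling scale with an int32 tensor, e.g. scale(tf.constant([1, 2]), tf.constant(2.0)), raises a TypeError from the helper before any TensorFlow op is built.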
Upvotes: 0
Reputation: 59711
This is not part of the public API, but looking at how tf.assign is implemented, I think you can just do:
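import tensorflow as tf
def is_assignable(x):
    # Ref-typed values (e.g. TF1 variables) are assignable, and so are plain
    # tensors that carry an .assign method (slices of a variable get one)
    return x.dtype._is_ref_dtype or (isinstance(x, tf.Tensor) and hasattr(x, 'assign'))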
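To try is_assignable against the question's setup, here is a minimal sketch. It assumes TensorFlow 2.x with the v1 graph-mode behaviour re-enabled through tf.compat.v1.disable_v2_behavior(), so that tf.compat.v1.get_variable creates a ref variable as the question's TF1 code does. Because it relies on the private _is_ref_dtype attribute, it may behave differently in other TensorFlow versions.

```python
import tensorflow as tf

# Assumption: TF 2.x with v1 behaviour (graph mode, ref variables) switched on,
# mirroring the TF1 code in the question.
tf.compat.v1.disable_v2_behavior()


def is_assignable(x):
    # Same check as in the answer: ref-typed values, or tensors that carry
    # an .assign method (slices of a variable get one attached).
    return x.dtype._is_ref_dtype or (isinstance(x, tf.Tensor) and hasattr(x, "assign"))


x = tf.compat.v1.get_variable("x", shape=(2,), dtype=tf.float32)
results = {
    "variable": is_assignable(x),              # dtype is float32_ref
    "slice": is_assignable(x[1]),              # plain tensor with .assign attached
    "sum": is_assignable(x[1] + 1.0),          # ordinary op output
    "constant": is_assignable(tf.constant(1.0)),
}
print(results)
```

As written, the check returns False for a resource variable object itself (the default tf.Variable in TF 2.x), because its dtype is a plain float32 and it is not a tf.Tensor; an extra isinstance(x, tf.Variable) clause would cover that case if needed.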
Upvotes: 2