user9799714

Reputation: 75

How to convert tensor dtype=tf.float32_ref to dtype=tf.float32?

I want to change the dtype of word_embeddings from float32_ref to float32 using tf.cast():

   word_embeddings_modify=tf.cast(word_embeddings,dtype=tf.float32)

But it did not work as expected: the dtype of word_embeddings_modify is still tf.float32_ref. Here is the surrounding code:

   word_embeddings = tf.scatter_nd_update(var_output, error_word_f, sum_all)
   word_embeddings_modify = tf.cast(word_embeddings, dtype=tf.float32)
   word_embeddings_dropout = tf.nn.dropout(word_embeddings_modify, dropout_pl)

Upvotes: 4

Views: 1381

Answers (1)

nessuno

Reputation: 27042

You can dereference a _ref type using tf.identity:

word_embeddings = tf.identity(word_embeddings)
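For a fuller picture, here is a minimal sketch of the same idea in graph mode. It assumes TensorFlow is installed and goes through the tf.compat.v1 API so that it also runs on TensorFlow 2.x; the variable shape, indices, and update values are made up for illustration and are not taken from the question. Printing both dtypes lets you check that the output of tf.identity is no longer a reference type.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph API, also available in TF 2.x

graph = tf.Graph()
with graph.as_default():
    # use_resource=False asks for a legacy reference variable (a *_ref dtype).
    # Shape and initializer are illustrative assumptions.
    var_output = tf1.get_variable(
        "var_output",
        shape=[4, 3],
        dtype=tf.float32,
        initializer=tf1.zeros_initializer(),
        use_resource=False,
    )

    # Hypothetical stand-ins for error_word_f and sum_all from the question.
    indices = tf.constant([[1], [3]])
    updates = tf.ones([2, 3])

    # Updating a reference variable returns a reference-typed tensor.
    word_embeddings = tf1.scatter_nd_update(var_output, indices, updates)

    # tf.identity produces an ordinary (non-reference) tensor.
    dereferenced = tf.identity(word_embeddings)

    ref_dtype_name = word_embeddings.dtype.name
    plain_dtype_name = dereferenced.dtype.name

print(ref_dtype_name)
print(plain_dtype_name)
```

The dereferenced tensor can then be passed on to ops such as tf.nn.dropout in place of the original one.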

Upvotes: 2
