Reputation: 601
I want to use scatter_nd_update to change the content of the tensor returned from tf.nn.embedding_lookup(). However, the returned tensor is not mutable, and scatter_nd_update() requires a mutable tensor as input.
I spent a lot of time trying to find a solution, including using gen_state_ops._temporary_variable and tf.sparse_to_dense, but unfortunately all attempts failed.
Is there a clean solution to this?
with tf.device('/cpu:0'), tf.name_scope("embedding"):
    self.W = tf.Variable(
        tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
        name="W")
    self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
    updates = tf.constant(0, shape=[embedding_size])
    for i in range(1, sequence_length - 2):
        indices = [None, i]
        tf.scatter_nd_update(self.embedded_chars, indices, updates)
    self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)
Upvotes: 1
Views: 834
Reputation: 601
This problem stemmed from not clearly understanding the difference between a tensor and a variable in TensorFlow. Later, with a better understanding of tensors, I came up with this solution:
with tf.device('/cpu:0'), tf.name_scope("embedding"):
    self.W = tf.Variable(
        tf.random_uniform([vocab_size, embedding_size], -1.0, 1.0),
        name="W")
    self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
    for i in range(0, sequence_length - 1, 2):
        # A size of -1 keeps every element along the batch dimension
        # (a size of 0 would produce an empty tensor).
        self.tslice = tf.slice(self.embedded_chars, [0, i, 0], [-1, 1, 128])
        self.tslice2 = tf.slice(self.embedded_chars, [0, i+1, 0], [-1, 1, 128])
        self.tslice3 = tf.slice(self.embedded_chars, [0, i+2, 0], [-1, 1, 128])
        self.toffset1 = tf.subtract(self.tslice, self.tslice2)
        self.toffset2 = tf.subtract(self.tslice2, self.tslice3)
        self.tconcat = tf.concat([self.toffset1, self.toffset2], 1)
    self.embedded_chars_expanded = tf.expand_dims(self.embedded_chars, -1)
The functions used here, tf.slice, tf.subtract, and tf.concat, all accept tensors as input. Just avoid functions like tf.scatter_nd_update that require a variable as input.
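As a hedged sketch of the same tensor-only idea applied to the original goal (zeroing positions 1 through sequence_length - 3 of every sequence), the lookup result can be multiplied by a mask instead of being updated in place. This assumes TensorFlow 2.x with eager execution; the shapes and the stand-in tensor `embedded` are made up for illustration and are not from the question.

```python
import tensorflow as tf

# Hypothetical sizes, chosen only for this illustration.
batch_size, sequence_length, embedding_size = 2, 6, 4

# Stand-in for the result of tf.nn.embedding_lookup(W, input_x).
embedded = tf.ones([batch_size, sequence_length, embedding_size])

# Keep position 0 and the last two positions; zero out the positions
# covered by range(1, sequence_length - 2) in the question.
positions = tf.range(sequence_length)
keep = tf.logical_or(positions < 1, positions >= sequence_length - 2)

# Reshape the mask to (1, sequence_length, 1) so it broadcasts over
# the batch and embedding dimensions.
mask = tf.cast(keep, embedded.dtype)[tf.newaxis, :, tf.newaxis]

# Produces a new tensor; nothing is modified in place.
masked = embedded * mask
```

Because the multiplication returns a new tensor, no variable is needed, and `masked` can be passed on to tf.expand_dims like any other tensor.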
Upvotes: 0
Reputation: 53758
tf.nn.embedding_lookup simply returns a slice of the larger matrix, so the simplest solution is to update the values of that matrix itself; in your case, that is self.W:
self.embedded_chars = tf.nn.embedding_lookup(self.W, self.input_x)
Since self.W is a variable, it is compatible with tf.scatter_nd_update. Note that you can't update just any tensor, only variables.
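A minimal sketch of the variable/tensor distinction, assuming TensorFlow 2.x with eager execution (the thread itself uses the 1.x API). The small matrix `W` and the chosen row indices are hypothetical. In 2.x, a variable can be updated in place through its `scatter_nd_update` method, while `tf.tensor_scatter_nd_update` takes a plain tensor and returns an updated copy.

```python
import tensorflow as tf

# Hypothetical 5 x 3 embedding matrix, filled with ones for clarity.
W = tf.Variable(tf.ones([5, 3]))

# Overwrite rows 1 and 3 with zero vectors.
indices = tf.constant([[1], [3]])
updates = tf.zeros([2, 3])

# In-place update: only possible because W is a variable.
W.scatter_nd_update(indices, updates)

# A plain tensor cannot be modified in place, but an updated copy
# can be created from it; t itself stays unchanged.
t = tf.ones([5, 3])
t_new = tf.tensor_scatter_nd_update(t, indices, updates)
```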
Another option is to create a new variable just for the selected slice, assign self.embedded_chars to it, and perform the update afterwards.
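This second option might look roughly like the following, again assuming TensorFlow 2.x with eager execution. All sizes, the `input_x` values, and the name `selected` are hypothetical. Each index is a concrete `[batch, position]` pair, since a scatter index cannot contain None.

```python
import tensorflow as tf

# Hypothetical sizes and inputs for illustration.
vocab_size, embedding_size = 10, 3
batch_size, sequence_length = 2, 4

W = tf.Variable(tf.random.uniform([vocab_size, embedding_size], -1.0, 1.0))
input_x = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8]])

# Copy the looked-up slice into its own variable,
# shape (batch_size, sequence_length, embedding_size).
selected = tf.Variable(tf.nn.embedding_lookup(W, input_x))

# One [batch, position] pair per embedding vector to overwrite;
# here positions 1 and 2 of every sequence.
indices = tf.constant(
    [[b, i] for b in range(batch_size) for i in (1, 2)])
updates = tf.zeros([indices.shape[0], embedding_size])

# Updates the copy in place; W itself is left untouched.
selected.scatter_nd_update(indices, updates)
```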
Caveat: in both cases, you're blocking the gradients to train the embedding matrix, so double check that overwriting the learned value is really what you want.
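One way to see part of this effect is to check the gradient that reaches the embedding matrix when some looked-up positions are replaced by zeros. The sketch below assumes TensorFlow 2.x with eager execution and uses a mask to stand in for the overwrite; the sizes, ids, and the sum-based loss are hypothetical.

```python
import tensorflow as tf

# Hypothetical 4 x 2 embedding matrix and a single sequence of ids.
W = tf.Variable(tf.ones([4, 2]))
ids = tf.constant([[0, 1, 2]])

# Position 1 is zeroed out; positions 0 and 2 are kept.
mask = tf.constant([1.0, 0.0, 1.0])[tf.newaxis, :, tf.newaxis]

with tf.GradientTape() as tape:
    emb = tf.nn.embedding_lookup(W, ids)   # shape (1, 3, 2)
    loss = tf.reduce_sum(emb * mask)       # toy loss for illustration

# The gradient of a lookup is sparse; convert it to a dense tensor
# with the same shape as W to inspect it row by row.
grad = tf.convert_to_tensor(tape.gradient(loss, W))
```

In this toy setup, the row of W looked up at the zeroed position receives a gradient of zero, so it would not be trained by this loss.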
Upvotes: 1