mgxs

How to replace some rows of a 2-D tensor in TensorFlow

Here is an example of my question.

import tensorflow as tf

origin_embeddings = tf.constant([[1.1, 2.2, 3.3],
                                 [4.4, 5.5, 6.6],
                                 [7.7, 8.8, 9.9],
                                 [10.10, 11.11, 12.12]])  # shape (4, 3)

indice_updated_embeddings = tf.constant([[1.0, 2.0, 3.0],
                                         [4.0, 5.0, 6.0]])  # shape (2, 3)

indices = tf.constant([1, 3], dtype=tf.int32)

# Desired result:
new_embeddings = tf.constant([[1.1, 2.2, 3.3],
                              [1.0, 2.0, 3.0],   # row 1 <- indice_updated_embeddings[0]
                              [7.7, 8.8, 9.9],
                              [4.0, 5.0, 6.0]])  # row 3 <- indice_updated_embeddings[1]

I want to create a new embeddings matrix whose rows at the positions in indices come from indice_updated_embeddings, and whose remaining rows are copied from origin_embeddings.

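In NumPy this is simple:

origin_embeddings[indices] = indice_updated_embeddings  # option 1: modify in place
new_embeddings = origin_embeddings.copy()               # option 2: build a new matrix
new_embeddings[indices] = indice_updated_embeddings

I want the TensorFlow equivalent of option 2, i.e. something like NumPy's 'put' that returns a new tensor and replaces whole rows. (np.put itself writes in place into the flattened array and returns None, so it does not do this directly.)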
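As a side note on the NumPy baseline, here is a small sketch contrasting np.put, which indexes the flattened array, with row assignment on a copy. The variable names (origin, updates, flat_put, row_put) are illustrative only.

```python
import numpy as np

origin = np.array([[1.1, 2.2, 3.3],
                   [4.4, 5.5, 6.6],
                   [7.7, 8.8, 9.9],
                   [10.10, 11.11, 12.12]])
updates = np.array([[1.0, 2.0, 3.0],
                    [4.0, 5.0, 6.0]])
indices = np.array([1, 3])

# np.put treats the array as flat: only flat positions 1 and 3 change,
# receiving the first two values of the flattened `updates` (1.0 and 2.0).
flat_put = origin.copy()
result = np.put(flat_put, indices, updates)  # modifies flat_put, returns None

# Row replacement that leaves `origin` untouched: copy, then assign whole rows.
row_put = origin.copy()
row_put[indices] = updates
print(row_put)
```

So the behaviour the question is after corresponds to the copy-then-assign pattern, not to np.put on a 2-D array.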

I have tried using tf.where, but I could not find a straightforward way to do it.

Answers (1)

javidcf

If you are not using an old version of TensorFlow, you can do that with tf.tensor_scatter_nd_update:

import tensorflow as tf

origin_embeddings = tf.constant([[ 1.1 ,  2.2 ,  3.3 ], 
                                 [ 4.4 ,  5.5 ,  6.6 ], 
                                 [ 7.7 ,  8.8 ,  9.9 ],
                                 [10.10, 11.11, 12.12]])  # shape (4, 3)
indice_updated_embeddings = tf.constant([[1.0, 2.0, 3.0],
                                         [4.0, 5.0, 6.0]])  # shape (2, 3)
indices = tf.constant([1, 3], dtype=tf.int32)

new_embeddings = tf.tensor_scatter_nd_update(origin_embeddings,
                                             tf.expand_dims(indices, 1),
                                             indice_updated_embeddings)
tf.print(new_embeddings)
# [[1.1 2.2 3.3]
#  [1 2 3]
#  [7.7 8.8 9.9]
#  [4 5 6]]
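Why tf.expand_dims(indices, 1) is needed: the last dimension of the indices argument says how many leading axes each index addresses. A sketch on a small zero tensor (the names t, rows and elems are only for illustration) shows that length-1 index rows replace whole rows, while length-2 index rows replace single elements.

```python
import tensorflow as tf

t = tf.zeros([3, 3])

# Index depth 1: each index [i] addresses a whole row, so each update is a row.
rows = tf.tensor_scatter_nd_update(t, [[0], [2]], tf.ones([2, 3]))

# Index depth 2: each index [i, j] addresses one element, so each update is a scalar.
elems = tf.tensor_scatter_nd_update(t, [[0, 0], [2, 1]], [5.0, 7.0])

tf.print(rows)
tf.print(elems)
```

In the answer above, indices has shape (2,), so expanding it to (2, 1) turns it into two depth-1 indices, matching the two full rows in indice_updated_embeddings.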

In older versions where tf.tensor_scatter_nd_update is not available, you can do it like this:

import tensorflow as tf

origin_embeddings = tf.constant([[1.1, 2.2, 3.3],
                                 [4.4, 5.5, 6.6],
                                 [7.7, 8.8, 9.9],
                                 [10.10, 11.11, 12.12]])  # shape (4, 3)
indice_updated_embeddings = tf.constant([[1.0, 2.0, 3.0],
                                         [4.0, 5.0, 6.0]])  # shape (2, 3)
indices = tf.constant([1, 3], dtype=tf.int32)

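indices2 = tf.expand_dims(indices, 1)  # shape (2, 1): one row index per update
s = tf.shape(origin_embeddings)
# Boolean column of shape (4, 1), True at the rows to be replaced
mask = tf.scatter_nd(indices2, tf.ones_like(indices2, tf.bool), [s[0], 1])
# Shape (4, 3): the update rows placed at `indices`, zeros elsewhere
updates = tf.scatter_nd(indices2, indice_updated_embeddings, s)
# Take a row from `updates` where mask is True, otherwise from the original
new_embeddings = tf.where(mask, updates, origin_embeddings)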
tf.print(new_embeddings)
# [[1.1 2.2 3.3]
#  [1 2 3]
#  [7.7 8.8 9.9]
#  [4 5 6]]
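Since the question mentions trying tf.where, here is a rough sketch of a tf.where-only variant that avoids scatter operations. It assumes TensorFlow 2.x (where tf.where broadcasts a (n, 1) condition across columns) and a non-empty indices tensor; replace_rows_with_where is a hypothetical helper name, not a TensorFlow API.

```python
import tensorflow as tf

def replace_rows_with_where(origin, updates, indices):
    """Hypothetical helper: copy of `origin` with rows `indices` taken from `updates`.

    Assumes TF 2.x broadcasting in tf.where and a non-empty `indices`.
    """
    n = tf.shape(origin)[0]
    # hits[i, j] is True when row i of `origin` equals indices[j]; shape (n, m).
    hits = tf.equal(tf.range(n)[:, None], indices[None, :])
    # Row mask kept 2-D so it can broadcast across columns; shape (n, 1).
    mask = tf.reduce_any(hits, axis=1, keepdims=True)
    # Position of each row's match in `indices` (0 when unmatched; masked out below).
    pos = tf.argmax(tf.cast(hits, tf.int32), axis=1)
    aligned = tf.gather(updates, pos)  # shape (n, ncols)
    return tf.where(mask, aligned, origin)

origin_embeddings = tf.constant([[1.1, 2.2, 3.3],
                                 [4.4, 5.5, 6.6],
                                 [7.7, 8.8, 9.9],
                                 [10.10, 11.11, 12.12]])
indice_updated_embeddings = tf.constant([[1.0, 2.0, 3.0],
                                         [4.0, 5.0, 6.0]])
indices = tf.constant([1, 3], dtype=tf.int32)

via_where = replace_rows_with_where(origin_embeddings, indice_updated_embeddings, indices)
via_scatter = tf.tensor_scatter_nd_update(origin_embeddings,
                                          tf.expand_dims(indices, 1),
                                          indice_updated_embeddings)
tf.print(via_where)
```

The comparison matrix hits has n × m entries, so for many rows and many indices the scatter-based approaches above are likely the more economical choice.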
