phoenixwing

Reputation: 337

What is the TensorFlow equivalent of this numpy array permutation?

I need to permute the elements of a tensor in TF according to a given indexing. Given two arrays, a (values) and b (indices), I need to compute a new array that places each element of a at the position given by the corresponding index in b, shifted so that the smallest index maps to position 0. Positions that receive no value should be filled with NA (or an equivalent).

For example,

a = [10, 20, 30]  
b = [-1,  0,  3]  
output = [ 10,  20,   NA,   NA,  30] 

I need to code the equivalent of what happens to the following numpy arrays but for TF tensors.

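import numpy as np

a = np.array([10, 20, 30])
b = np.array([-1, 0, 3])
mini = abs(np.min(b))  # offset that shifts the smallest index to 0 (assumes min(b) <= 0)
maxi = abs(np.max(b))
output = np.zeros(maxi + mini + 1)  # empty slots hold 0 here rather than NA
for ai, bi in zip(a, b):
    output[bi + mini] = ai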

How do I do this with TensorFlow tensors?
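
To make the target behaviour precise, here is a vectorized NumPy sketch of the loop above that fills empty slots with NaN instead of 0. It assumes NaN is an acceptable stand-in for NA (which forces a float output), and `permute_with_nan` is a hypothetical helper name used only for illustration.

```python
import numpy as np

def permute_with_nan(a, b):
    """Place a[i] at position b[i] - min(b); unfilled slots become NaN."""
    a = np.asarray(a, dtype=float)  # float so that NaN is representable
    b = np.asarray(b)
    offset = b.min()                # shift so the smallest index maps to 0
    out = np.full(b.max() - offset + 1, np.nan)
    out[b - offset] = a             # fancy-index assignment replaces the loop
    return out

result = permute_with_nan([10, 20, 30], [-1, 0, 3])
print(result)
```

Unlike the loop, this subtracts the minimum directly rather than adding its absolute value, so it also behaves sensibly when all indices are positive.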

Upvotes: 2

Views: 2018

Answers (2)

phoenixwing

Reputation: 337

I found a way of achieving this, and I'm posting my answer here in case it helps anyone else.
The scatter_nd function in TensorFlow is very handy in this situation. The following code permutes the elements of the input tensor I according to the transformation given in tensor T: scatter_nd builds a new tensor with each element of I written at its shifted index, and every position that receives no value is left at 0.

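import tensorflow as tf  # TensorFlow 1.x (graph mode with sessions)

sess = tf.InteractiveSession()
I = tf.constant([10, 20, 30])       # values
T = tf.constant([-1, 0, 3])         # target indices
T = T - tf.reduce_min(T)            # shift so the smallest index is 0
T_shape = int(T.get_shape()[0])
T = tf.reshape(T, [T_shape, 1])     # scatter_nd expects indices of shape [N, 1]
O_shape = tf.reduce_max(T) + 1      # length of the output
O = tf.scatter_nd(T, I, [O_shape])  # positions that are not written stay 0
print(sess.run([I, T, O]))
sess.close()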

This code performs the following task:
Given

Input = [10, 20, 30]
Transformation = [-1, 0, 3]

Computes

Output = [10, 20,  0,  0, 30]
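
If the empty positions should hold NA rather than 0, one possible variation is to start from a NaN-filled tensor and overwrite the known positions. The sketch below assumes TensorFlow 2.x with eager execution (so no session is needed) and a float output; `scatter_with_nan` is a hypothetical helper name, not part of TensorFlow.

```python
import tensorflow as tf

def scatter_with_nan(values, positions):
    """Scatter values to shifted positions; unfilled slots become NaN."""
    values = tf.cast(values, tf.float32)  # float so that NaN is representable
    positions = tf.convert_to_tensor(positions)
    shifted = positions - tf.reduce_min(positions)  # smallest index maps to 0
    size = tf.reduce_max(shifted) + 1
    base = tf.fill(tf.stack([size]), float("nan"))  # NaN-filled starting tensor
    # Indices must have shape [num_updates, 1] to address a rank-1 tensor.
    return tf.tensor_scatter_nd_update(base, tf.expand_dims(shifted, 1), values)

out = scatter_with_nan([10, 20, 30], [-1, 0, 3]).numpy()
print(out)
```

The difference from the code above is only the starting tensor: tf.scatter_nd always begins from zeros, while tf.tensor_scatter_nd_update writes into whatever tensor it is given.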

Upvotes: 2

Geoffrey Irving

Reputation: 6613

African or European?

  1. If you know that the indices are strictly increasing, tf.sparse_to_dense does what you want.

  2. If the indices are distinct but not in increasing order, you can use tf.sparse_reorder to fix the order and then use tf.sparse_tensor_to_dense.

  3. If there are duplicates and you want matching values to add, use tf.unsorted_segment_sum.

  4. If there are duplicates and you want the last entry to win (corresponding exactly to your Python loop), use tf.dynamic_stitch.
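
As a rough illustration of the sparse route in options 1 and 2, here is a sketch that assumes TensorFlow 2.x with eager execution, where tf.sparse_reorder and tf.sparse_tensor_to_dense are available as tf.sparse.reorder and tf.sparse.to_dense. The fill value -1 is an arbitrary choice for marking empty slots, and the input data is made up for the example.

```python
import tensorflow as tf

values = tf.constant([30, 10, 20])
positions = tf.constant([3, -1, 0])  # distinct, but not sorted

# Shift so the smallest index maps to 0; SparseTensor needs int64 indices.
shifted = tf.cast(positions - tf.reduce_min(positions), tf.int64)
size = tf.reduce_max(shifted) + 1

sparse = tf.sparse.SparseTensor(
    indices=tf.expand_dims(shifted, 1),  # shape [N, 1] for a rank-1 tensor
    values=values,
    dense_shape=tf.stack([size]),
)
sparse = tf.sparse.reorder(sparse)  # sort entries into increasing index order
dense = tf.sparse.to_dense(sparse, default_value=-1)  # -1 marks empty slots

print(dense.numpy())
```

When the indices are already strictly increasing, the tf.sparse.reorder call can be dropped.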

Apologies for the zoo of options. The ops were all added for different reasons, so the overall design is not particularly clean.
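
To make the difference between the two duplicate-handling options concrete, here is a small sketch, again assuming TensorFlow 2.x with eager execution and made-up data in which index 2 appears twice.

```python
import tensorflow as tf

values = tf.constant([10, 20, 30, 40])
positions = tf.constant([0, 2, 2, 1])  # index 2 appears twice

# Duplicates add: slot 2 receives 20 + 30.
summed = tf.math.unsorted_segment_sum(values, positions, num_segments=3)

# Duplicates overwrite: values are merged in order, so the later one is kept.
last_wins = tf.dynamic_stitch([positions], [values])

print(summed.numpy())
print(last_wins.numpy())
```

Here every output slot is written at least once. With gaps in the indices, unsorted_segment_sum leaves the empty slots at 0, so NA filling would still need a separate step.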

Upvotes: 2
