Sreeragh A R

Reputation: 3021

Modify a tensor

Suppose I have a tensor like this

[
    [ 6 -2 -2 -2 -1 -2  -3 -3 -6 -6]
    [ 1 -6 -7 -7 -7 -7  -7 -6 -6 -6]
    [ 5 -3 -3 -4 -4 -4  -4 -3 -3 -3]
 ]

I need to apply the following operation to each row: if the first element is the largest value in the row but is less than 4, swap the first and second elements of that row. The resulting tensor would be:

[
    [ 6 -2 -2 -2 -1 -2  -3 -3 -6 -6]
    [ -6 1 -7 -7 -7 -7  -7 -6 -6 -6] #elements swapped
    [ 5 -3 -3 -4 -4 -4  -4 -3 -3 -3]
 ]

I am working in Python with the TensorFlow module. Please help.

Upvotes: 0

Views: 670

Answers (1)

mrry

Reputation: 126184

The general recipe for a problem like this is to use tf.map_fn() to build a new tensor by applying a function to each row. Let's start with how to express the condition for a single row:

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)

row = tf.placeholder(tf.int32, shape=[10])

condition = tf.logical_and(
    tf.equal(row[0], tf.reduce_max(row)),
    tf.less(row[0], 4))

sess = tf.Session()

print(sess.run(condition, feed_dict={row: [6, -2, -2, -2, -1, -2, -3, -3, -6, -6]}))
print(sess.run(condition, feed_dict={row: [1, -6, -7, -7, -7, -7, -7, -6, -6, -6]}))
print(sess.run(condition, feed_dict={row: [5, -3, -3, -4, -4, -4, -4, -3, -3, -3]}))

# Prints the following:
# False
# True
# False

Now that we have a condition, we can use tf.cond() to build a conditional expression that swaps the first two elements if the condition is true:

def swap_first_two(x):
  swapped_first_two = tf.stack([x[1], x[0]])
  rest = x[2:]
  return tf.concat([swapped_first_two, rest], 0)

maybe_swapped = tf.cond(condition, lambda: swap_first_two(row), lambda: row)

print(sess.run(maybe_swapped, feed_dict={row: [6, -2, -2, -2, -1, -2, -3, -3, -6, -6]}))
print(sess.run(maybe_swapped, feed_dict={row: [1, -6, -7, -7, -7, -7, -7, -6, -6, -6]}))
print(sess.run(maybe_swapped, feed_dict={row: [5, -3, -3, -4, -4, -4, -4, -3, -3, -3]}))

# Prints the following:
# [ 6 -2 -2 -2 -1 -2 -3 -3 -6 -6]
# [-6  1 -7 -7 -7 -7 -7 -6 -6 -6]
# [ 5 -3 -3 -4 -4 -4 -4 -3 -3 -3]
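
To sanity-check the expected rows without running a session, the same rule can be written in plain Python. This is only a reference sketch: the helper name swap_if_small_max is hypothetical and is not part of TensorFlow or of the code above.

```python
def swap_if_small_max(row):
    """Hypothetical helper: return a copy of `row` with its first two
    items swapped when the first item is the row maximum and is below 4."""
    if len(row) >= 2 and row[0] == max(row) and row[0] < 4:
        return [row[1], row[0]] + list(row[2:])
    return list(row)


rows = [
    [6, -2, -2, -2, -1, -2, -3, -3, -6, -6],
    [1, -6, -7, -7, -7, -7, -7, -6, -6, -6],
    [5, -3, -3, -4, -4, -4, -4, -3, -3, -3],
]

# Apply the rule row by row, mirroring what tf.cond() does for one row.
expected = [swap_if_small_max(r) for r in rows]
for r in expected:
    print(r)
```

Only the second row meets both parts of the condition, so it is the only one that changes.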

Finally, we put it all together by wrapping the computation of maybe_swapped in a function, and passing it to tf.map_fn():

matrix = tf.constant([
    [6, -2, -2, -2, -1, -2, -3, -3, -6, -6],
    [1, -6, -7, -7, -7, -7, -7, -6, -6, -6],
    [5, -3, -3, -4, -4, -4, -4, -3, -3, -3],
])

def row_function(row):
  condition = tf.logical_and(
      tf.equal(row[0], tf.reduce_max(row)),
      tf.less(row[0], 4))

  def swap_first_two(x):
    swapped_first_two = tf.stack([x[1], x[0]])
    rest = x[2:]
    return tf.concat([swapped_first_two, rest], 0)

  maybe_swapped = tf.cond(condition, lambda: swap_first_two(row), lambda: row)

  return maybe_swapped

result = tf.map_fn(row_function, matrix)

print(sess.run(result))

# Prints the following:
# [[ 6 -2 -2 -2 -1 -2 -3 -3 -6 -6]
#  [-6  1 -7 -7 -7 -7 -7 -6 -6 -6]
#  [ 5 -3 -3 -4 -4 -4 -4 -3 -3 -3]]
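
As an alternative sketch, the same rule can be applied to all rows at once, without tf.map_fn(), by computing one boolean per row and selecting with tf.where(). This assumes TensorFlow 2.x with eager execution, which differs from the tf.Session style used above; the names first, swapped and mask are illustrative only.

```python
import tensorflow as tf

matrix = tf.constant([
    [6, -2, -2, -2, -1, -2, -3, -3, -6, -6],
    [1, -6, -7, -7, -7, -7, -7, -6, -6, -6],
    [5, -3, -3, -4, -4, -4, -4, -3, -3, -3],
])

first = matrix[:, 0]

# One boolean per row: the first element is the row maximum and is below 4.
condition = tf.logical_and(
    tf.equal(first, tf.reduce_max(matrix, axis=1)),
    tf.less(first, 4))

# A copy of the matrix with the first two columns exchanged in every row.
swapped = tf.concat([matrix[:, 1:2], matrix[:, 0:1], matrix[:, 2:]], axis=1)

# Expand the per-row flag to the full matrix shape, then take each row
# from `swapped` where the flag is true and from `matrix` otherwise.
mask = tf.broadcast_to(condition[:, tf.newaxis], tf.shape(matrix))
result = tf.where(mask, swapped, matrix)

print(result.numpy())
```

This builds the swapped version of every row up front, so it does a little redundant work, but it avoids the per-row loop that tf.map_fn() implies.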

Upvotes: 6
