Reputation: 3021
Suppose I have a tensor like this:
[
[ 6 -2 -2 -2 -1 -2 -3 -3 -6 -6]
[ 1 -6 -7 -7 -7 -7 -7 -6 -6 -6]
[ 5 -3 -3 -4 -4 -4 -4 -3 -3 -3]
]
I have to perform the following operation on each row: if the first element is the largest element in the row but its value is less than 4, swap the first and second elements of the row. The resulting tensor will be:
[
[ 6 -2 -2 -2 -1 -2 -3 -3 -6 -6]
[ -6 1 -7 -7 -7 -7 -7 -6 -6 -6] #elements swapped
[ 5 -3 -3 -4 -4 -4 -4 -3 -3 -3]
]
I am working in Python with the TensorFlow module. Please help.
Upvotes: 0
Views: 670
Reputation: 126184
The general recipe for a problem like this is to use tf.map_fn() to create a new tensor with the appropriate values by applying a function to each row. Let's start with how to express the condition for a single row:
# TensorFlow 1.x graph-mode API (tf.placeholder, tf.Session).
import tensorflow as tf

row = tf.placeholder(tf.int32, shape=[10])
# True when the first element equals the row maximum and is less than 4.
condition = tf.logical_and(
    tf.equal(row[0], tf.reduce_max(row)),
    tf.less(row[0], 4))
sess = tf.Session()
print(sess.run(condition, feed_dict={row: [6, -2, -2, -2, -1, -2, -3, -3, -6, -6]}))
print(sess.run(condition, feed_dict={row: [1, -6, -7, -7, -7, -7, -7, -6, -6, -6]}))
print(sess.run(condition, feed_dict={row: [5, -3, -3, -4, -4, -4, -4, -3, -3, -3]}))
# Prints the following:
# False
# True
# False
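To make the condition's edge cases easier to see, here is a minimal plain-Python sketch of the same predicate. It does not use TensorFlow, and `first_is_small_max` and its `threshold` parameter are hypothetical names chosen for illustration. Note that, like tf.equal(row[0], tf.reduce_max(row)), the equality test also holds when the first element ties with a later element for the maximum.

```python
def first_is_small_max(row, threshold=4):
    """Plain-list mirror of the TensorFlow condition (illustrative helper)."""
    # row[0] == max(row): the first element is, or ties for, the largest.
    # row[0] < threshold: and its value is below the threshold.
    return row[0] == max(row) and row[0] < threshold


rows = [
    [6, -2, -2, -2, -1, -2, -3, -3, -6, -6],
    [1, -6, -7, -7, -7, -7, -7, -6, -6, -6],
    [5, -3, -3, -4, -4, -4, -4, -3, -3, -3],
]
results = [first_is_small_max(r) for r in rows]
print(results)  # [False, True, False]
```

A row such as [3, 3, 1] also satisfies the predicate, because the first element ties for the maximum and is below 4.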
Now that we have a condition, we can use tf.cond() to build a conditional expression that swaps the first two elements if the condition is true:
def swap_first_two(x):
    # Rebuild the row as [x[1], x[0]] followed by the untouched remainder.
    swapped_first_two = tf.stack([x[1], x[0]])
    rest = x[2:]
    return tf.concat([swapped_first_two, rest], 0)

maybe_swapped = tf.cond(condition, lambda: swap_first_two(row), lambda: row)
print(sess.run(maybe_swapped, feed_dict={row: [6, -2, -2, -2, -1, -2, -3, -3, -6, -6]}))
print(sess.run(maybe_swapped, feed_dict={row: [1, -6, -7, -7, -7, -7, -7, -6, -6, -6]}))
print(sess.run(maybe_swapped, feed_dict={row: [5, -3, -3, -4, -4, -4, -4, -3, -3, -3]}))
# Prints the following:
# [ 6 -2 -2 -2 -1 -2 -3 -3 -6 -6]
# [-6 1 -7 -7 -7 -7 -7 -6 -6 -6]
# [ 5 -3 -3 -4 -4 -4 -4 -3 -3 -3]
Finally, we put it all together by wrapping the computation of maybe_swapped in a function and passing it to tf.map_fn():
matrix = tf.constant([
    [6, -2, -2, -2, -1, -2, -3, -3, -6, -6],
    [1, -6, -7, -7, -7, -7, -7, -6, -6, -6],
    [5, -3, -3, -4, -4, -4, -4, -3, -3, -3],
])

def row_function(row):
    condition = tf.logical_and(
        tf.equal(row[0], tf.reduce_max(row)),
        tf.less(row[0], 4))

    def swap_first_two(x):
        swapped_first_two = tf.stack([x[1], x[0]])
        rest = x[2:]
        return tf.concat([swapped_first_two, rest], 0)

    maybe_swapped = tf.cond(condition, lambda: swap_first_two(row), lambda: row)
    return maybe_swapped

# Apply row_function to each row of the matrix.
result = tf.map_fn(row_function, matrix)
print(sess.run(result))
# Prints the following:
# [[ 6 -2 -2 -2 -1 -2 -3 -3 -6 -6]
# [-6 1 -7 -7 -7 -7 -7 -6 -6 -6]
# [ 5 -3 -3 -4 -4 -4 -4 -3 -3 -3]]
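As a cross-check on the expected output, the whole per-row rule can be sketched in plain Python on nested lists. This is only a reference for the semantics, not a replacement for the TensorFlow graph. `swap_if_small_max` and `threshold` are hypothetical names introduced here.

```python
def swap_if_small_max(row, threshold=4):
    """Return a copy of row, with its first two elements swapped when the
    first element is the row maximum and is below threshold."""
    if row[0] == max(row) and row[0] < threshold:
        # Same structure as tf.stack([x[1], x[0]]) followed by tf.concat.
        return [row[1], row[0]] + row[2:]
    return list(row)


matrix = [
    [6, -2, -2, -2, -1, -2, -3, -3, -6, -6],
    [1, -6, -7, -7, -7, -7, -7, -6, -6, -6],
    [5, -3, -3, -4, -4, -4, -4, -3, -3, -3],
]
expected = [swap_if_small_max(r) for r in matrix]
for r in expected:
    print(r)
```

Only the second row changes, which matches the output of sess.run(result) above.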
Upvotes: 6