shapeare

Reputation: 4233

Tensorflow: how to apply a function to the last dimension of a tensor

I have an array. I want to create a mask based on the values in the last dimension of this array. In Numpy, I could do:

import numpy as np
room = np.array([
    [[0, 0, 1], [1, 0, 0], [1, 0, 0]],
    [[1, 0, 0], [0, 0, 1], [1, 0, 0]],
    [[1, 0, 0], [1, 0, 0], [0, 0, 1]]
])
mask = np.apply_along_axis(lambda x: [1, 1, 1] if (x == [0, 0, 1]).all() else [0, 0, 0], axis=-1, arr=room)
result = mask * room
print(result)

In the above code, room is a (3, 3, 3) array based on which I created the mask. The mask created is also a (3, 3, 3) array and it will be used to multiply with other arrays to mask out unwanted elements.
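For reference, the same mask can also be sketched in NumPy without apply_along_axis, by comparing along the last axis and broadcasting the result. This is only an illustrative equivalent of the snippet above; the names target and match are introduced here and are not part of the original code.

```python
import numpy as np

room = np.array([
    [[0, 0, 1], [1, 0, 0], [1, 0, 0]],
    [[1, 0, 0], [0, 0, 1], [1, 0, 0]],
    [[1, 0, 0], [1, 0, 0], [0, 0, 1]]
])
target = np.array([0, 0, 1])  # hypothetical name for the vector to match

# True where the whole last-axis vector equals target; keepdims gives shape (3, 3, 1)
match = (room == target).all(axis=-1, keepdims=True)

# Broadcast the per-vector flag across the last axis to get a (3, 3, 3) mask of 0s and 1s
mask = np.broadcast_to(match, room.shape).astype(room.dtype)

result = mask * room
print(result)
```

Only the vectors equal to [0, 0, 1] survive the multiplication; every other vector becomes [0, 0, 0].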

But I am having trouble achieving the same thing with TensorFlow. I tried the following code:

import tensorflow as tf
room = tf.constant([
    [[0, 0, 1], [1, 0, 0], [1, 0, 0]],
    [[1, 0, 0], [0, 0, 1], [1, 0, 0]],
    [[1, 0, 0], [1, 0, 0], [0, 0, 1]]
])
room = tf.reshape(room, shape=(9, -1))
mask = tf.map_fn(lambda x: [1, 1, 1] if x == [0, 0, 1] else [0, 0, 0], room)

but it ended with the following error:

ValueError: The two structures don't have the same number of elements. First structure: <dtype: 'int32'>, second structure: [0, 0, 0].

Upvotes: 3

Views: 2212

Answers (1)

nessuno

Reputation: 27070

map_fn has a dtype parameter that lets you specify the type (and structure) of the output when it differs from that of the input elements.

However, this is not the problem.

You're mixing Python conditionals into a TensorFlow operation: Python statements (like the if) are executed outside the graph, whereas you want to define a graph that executes the desired op.

Let's dig into your problem:

  1. You want to unroll the room variable along the first dimension: map_fn is OK for that.
  2. You want to check if the current row (a vector of 3 elements) equals [0, 0, 1].

To do this you need to use a tensorflow condition, tf.cond(pred, true_fn, false_fn).

Note that pred must be a scalar. So, using TensorFlow operations only, check whether the current row equals the desired row and reduce the elementwise result to a single scalar.

If this is true, just return the constant value [1,1,1], otherwise [0,0,0].

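# For each row of room (reshaped to (9, 3) as in the question), pick a constant inside the graph.
mask = tf.map_fn(lambda row: tf.cond(
    # Elementwise comparison cast to int32: the product is 1 only if all three elements match.
    tf.equal(
        tf.reduce_prod(tf.cast(tf.equal(row, tf.constant([0,0,1])), tf.int32)), 1),
    lambda: tf.constant([1,1,1]),
    lambda: tf.constant([0,0,0])), room)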
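As an alternative to map_fn and tf.cond, the same check can be sketched with whole-tensor operations, so neither a per-row loop nor a reshape is needed. This is a minimal sketch, not the approach above: it assumes TensorFlow 2.x with eager execution (in 1.x the tensors would have to be evaluated in a session), and the names target and match are introduced here for illustration.

```python
import tensorflow as tf

room = tf.constant([
    [[0, 0, 1], [1, 0, 0], [1, 0, 0]],
    [[1, 0, 0], [0, 0, 1], [1, 0, 0]],
    [[1, 0, 0], [1, 0, 0], [0, 0, 1]]
])
target = tf.constant([0, 0, 1])  # hypothetical name for the vector to match

# Compare every last-axis vector with target (broadcast), then reduce over
# the last axis; keepdims=True leaves a boolean tensor of shape (3, 3, 1).
match = tf.reduce_all(tf.equal(room, target), axis=-1, keepdims=True)

# Cast the flag to the input dtype and broadcast it across the last axis,
# giving a (3, 3, 3) mask of 0s and 1s.
mask = tf.cast(match, room.dtype) * tf.ones_like(room)

result = mask * room
print(result)
```

Because the comparison and the reduction are both TensorFlow ops, everything stays inside the graph and no Python if is involved.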

Upvotes: 2
