Reputation: 4233
I have an array. I want to create a mask based on the values in the last dimension of this array. In Numpy, I could do:
import numpy as np
room = np.array([
[[0, 0, 1], [1, 0, 0], [1, 0, 0]],
[[1, 0, 0], [0, 0, 1], [1, 0, 0]],
[[1, 0, 0], [1, 0, 0], [0, 0, 1]]
])
mask = np.apply_along_axis(lambda x: [1, 1, 1] if (x == [0, 0, 1]).all() else [0, 0, 0], axis=-1, arr=room)
result = mask * room
print(result)
In the above code, room is a (3, 3, 3) array based on which I created the mask. The mask created is also a (3, 3, 3) array and it will be used to multiply with other arrays to mask out unwanted elements.
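For comparison, the same mask can probably be built without apply_along_axis by comparing against the target row and reducing over the last axis. This is only a sketch under the assumption that the target row is always [0, 0, 1]; the names target and match are illustrative and not part of the code above.

```python
import numpy as np

room = np.array([
    [[0, 0, 1], [1, 0, 0], [1, 0, 0]],
    [[1, 0, 0], [0, 0, 1], [1, 0, 0]],
    [[1, 0, 0], [1, 0, 0], [0, 0, 1]]
])

# Hypothetical name: the row that should be kept by the mask.
target = np.array([0, 0, 1])

# Elementwise comparison broadcasts target over the last axis;
# all(..., keepdims=True) collapses each row to one boolean, shape (3, 3, 1).
match = (room == target).all(axis=-1, keepdims=True)

# Expand back to the full (3, 3, 3) shape so the mask can multiply other arrays.
mask = np.broadcast_to(match, room.shape).astype(room.dtype)
result = mask * room
print(result)
```

Keeping the reduced axis with keepdims=True is what lets the (3, 3, 1) boolean array broadcast back against room.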
But I am having trouble achieving the same thing with TensorFlow. I have tried the following code,
import tensorflow as tf
room = tf.constant([
[[0, 0, 1], [1, 0, 0], [1, 0, 0]],
[[1, 0, 0], [0, 0, 1], [1, 0, 0]],
[[1, 0, 0], [1, 0, 0], [0, 0, 1]]
])
room = tf.reshape(room, shape=(9, -1))
mask = tf.map_fn(lambda x: [1, 1, 1] if x == [0, 0, 1] else [0, 0, 0], room)
but it ended with the following error:
ValueError: The two structures don't have the same number of elements. First structure: <dtype: 'int32'>, second structure: [0, 0, 0].
Upvotes: 3
Views: 2212
Reputation: 27070
map_fn has the parameter dtype, which lets you specify the type and structure of the output when they differ from those of x.
However, this is not the problem.
You're mixing Python conditions into a TensorFlow operation: Python constructs (like the if) are executed outside the graph, whereas you want to define a graph that executes the desired op.
Let's dig into your problem:
1. You want to loop over the room variable along the first dimension: map_fn is OK for that.
2. For each row, you want to check whether it equals [0, 0, 1].
To do this you need to use a TensorFlow condition, tf.cond(pred, true_fn, false_fn).
Note that pred must be a scalar. Thus, using TensorFlow operations only, let's check whether the current row equals your desired row and reduce the result to a single scalar.
If this is true, just return the constant value [1,1,1], otherwise [0,0,0].
mask = tf.map_fn(
    lambda row: tf.cond(
        # pred: scalar that is true only when every element of row matches [0, 0, 1]
        tf.equal(
            tf.reduce_prod(tf.cast(tf.equal(row, tf.constant([0, 0, 1])), tf.int32)), 1),
        lambda: tf.constant([1, 1, 1]),
        lambda: tf.constant([0, 0, 0])),
    room)
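As a possible alternative, the row check can likely be done on the whole tensor at once, without map_fn or tf.cond. The sketch below assumes TensorFlow 2.x with eager execution and works on the original (3, 3, 3) room rather than the reshaped one; the names target and match are illustrative only.

```python
import tensorflow as tf

room = tf.constant([
    [[0, 0, 1], [1, 0, 0], [1, 0, 0]],
    [[1, 0, 0], [0, 0, 1], [1, 0, 0]],
    [[1, 0, 0], [1, 0, 0], [0, 0, 1]]
])

# Hypothetical name: the row that should be kept by the mask.
target = tf.constant([0, 0, 1])

# tf.equal broadcasts target over the last axis; reduce_all collapses each
# row to a single boolean while keeping the axis, giving shape (3, 3, 1).
match = tf.reduce_all(tf.equal(room, target), axis=-1, keepdims=True)

# Cast to the input dtype and broadcast back to the full (3, 3, 3) shape.
mask = tf.cast(match, room.dtype) * tf.ones_like(room)
result = mask * room
print(result)
```

This replaces the per-row tf.cond with a single reduction over the last axis, so no scalar predicate is needed.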
Upvotes: 2