Reputation: 31
I'm trying to create a tensor containing every integer point within a given range. For example:
min_x = 5
max_x = 7
min_y = 3
max_y = 5
points = get_points(min_x, max_x, min_y, max_y)
print(points) # [[5, 3], [5, 4], [5, 5], [6, 3], [6, 4], [6, 5], [7, 3], [7, 4], [7, 5]]
I need to do this inside a TensorFlow function (i.e. one decorated with @tf.function).
Also, all the inputs to get_points need to be tensors.
Thanks, I'm new to TensorFlow, as you can tell.
Reputation: 4475
You can use tf.meshgrid to build the x and y coordinate grids, flatten both grids, and then stack them along the last dimension:
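import tensorflow as tf

# The question requires tensor inputs, so wrap the bounds in tf.constant
min_x = tf.constant(5)
max_x = tf.constant(7)
min_y = tf.constant(3)
max_y = tf.constant(5)

@tf.function
def get_points(min_x, max_x, min_y, max_y):
    # x[i, j] and y[i, j] hold the coordinates of each grid point
    x, y = tf.meshgrid(tf.range(min_x, max_x + 1), tf.range(min_y, max_y + 1))
    # Flatten both grids to 1-D and pair them up as (x, y) rows of shape (N, 2)
    return tf.stack([tf.reshape(x, [-1]), tf.reshape(y, [-1])], axis=-1)

res = get_points(min_x, max_x, min_y, max_y)
res.numpy()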
# array([[5, 3],
# [6, 3],
# [7, 3],
# [5, 4],
# [6, 4],
# [7, 4],
# [5, 5],
# [6, 5],
# [7, 5]], dtype=int32)
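Note that with tf.meshgrid's default indexing="xy", the points come out with x varying fastest ([5, 3], [6, 3], [7, 3], ...), which is not the order shown in the question. A minimal sketch, assuming you want exactly the question's order: passing indexing="ij" makes the first grid axis follow x, so flattening yields x-major order. The name get_points_ij is only illustrative.

```python
import tensorflow as tf

@tf.function
def get_points_ij(min_x, max_x, min_y, max_y):
    # indexing="ij": x[i, j] = xs[i], y[i, j] = ys[j], so flattening is x-major
    x, y = tf.meshgrid(tf.range(min_x, max_x + 1),
                       tf.range(min_y, max_y + 1),
                       indexing="ij")
    # Flatten both grids and pair them as (x, y) rows of shape (N, 2)
    return tf.stack([tf.reshape(x, [-1]), tf.reshape(y, [-1])], axis=-1)

points = get_points_ij(tf.constant(5), tf.constant(7), tf.constant(3), tf.constant(5))
print(points.numpy().tolist())
# [[5, 3], [5, 4], [5, 5], [6, 3], [6, 4], [6, 5], [7, 3], [7, 4], [7, 5]]
```

Reshaping to 1-D before stacking (instead of reshaping to (-1, 1) and squeezing afterwards) also keeps the result at shape (N, 2) when the range contains only a single point.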