freak11

Reputation: 391

Finding the maximum value in an individual batch in TensorFlow

Suppose you have the code below. I want to find the maximum value of each batch in the TensorFlow dataset and then add it to the dataset, something like map(lambda x: x + 1 + max(x)). How can I implement this? My attempt gives an error message.

import tensorflow as tf

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
dataset = dataset.map(lambda x: x + 1)
list(dataset.as_numpy_iterator())  # ==> [ 2, 3, 4, 5, 6 ]
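A likely source of the error: each element produced by Dataset.range is a scalar int64 tensor, so inside map there is nothing for a per-element max to reduce over. The values first have to be grouped, for example with batch. A quick check of the element shapes, assuming TensorFlow 2.x with eager execution (the names batched and dataset are illustrative):

```python
import tensorflow as tf

# Unbatched: each element is a scalar (rank-0) int64 tensor.
dataset = tf.data.Dataset.range(1, 6)
print(dataset.element_spec)

# Batched: each element is a 1-D int64 tensor holding up to 5 values,
# so a reduction such as tf.reduce_max has an axis to work over.
batched = dataset.batch(5)
print(batched.element_spec)
print([b.numpy().tolist() for b in batched])
```

The batched element shape is reported as (None,) because the last batch may be shorter unless drop_remainder=True is passed to batch.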

Upvotes: 0

Views: 127

Answers (1)

Nicolas Gervais

Reputation: 36624

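Batch the dataset first so that each element passed to map is a 1-D tensor, then compute the batch maximum with tf.reduce_max and append it to the end of the batch with tf.concat:

import tensorflow as tf

# Group 1..25 into batches of 5 so each dataset element is a 1-D tensor.
dataset = tf.data.Dataset.range(1, 25 + 1).batch(5)
# Append each batch's maximum as an extra element at the end of that batch.
dataset = dataset.map(lambda x: tf.concat([x, [tf.reduce_max(x, axis=0)]], axis=0))

for i in dataset:
    print(i)

Output:

tf.Tensor([1 2 3 4 5 5], shape=(6,), dtype=int64)
tf.Tensor([ 6  7  8  9 10 10], shape=(6,), dtype=int64)
tf.Tensor([11 12 13 14 15 15], shape=(6,), dtype=int64)
tf.Tensor([16 17 18 19 20 20], shape=(6,), dtype=int64)
tf.Tensor([21 22 23 24 25 25], shape=(6,), dtype=int64)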
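The code above appends each batch's maximum. If the goal is instead the arithmetic x + 1 + max(x) from the question, a minimal sketch (assuming TensorFlow 2.x; the names per_batch, global_max and shifted are illustrative) swaps tf.concat for plain addition, since tf.reduce_max of a batch is a scalar that broadcasts over the batch. The second half covers the other reading of the question, a single maximum over the whole dataset, computed with Dataset.reduce:

```python
import tensorflow as tf

# Per-batch variant: add 1 plus the batch maximum to every element of the batch.
# tf.reduce_max(x) returns a scalar, which broadcasts across the batch.
per_batch = (
    tf.data.Dataset.range(1, 25 + 1)
    .batch(5)
    .map(lambda x: x + 1 + tf.reduce_max(x))
)
for batch in per_batch:
    print(batch.numpy())

# Whole-dataset variant: compute one global maximum with Dataset.reduce,
# then add it to every element in a second pass over the data.
dataset = tf.data.Dataset.range(1, 6)
global_max = dataset.reduce(
    tf.constant(tf.int64.min, dtype=tf.int64),  # start below any int64 value
    lambda state, x: tf.maximum(state, x),
)
shifted = dataset.map(lambda x: x + 1 + global_max)
print(list(shifted.as_numpy_iterator()))
```

The global variant reads the dataset twice (once to reduce, once to map), which is fine for small in-memory ranges but worth keeping in mind for expensive input pipelines.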

Upvotes: 1
