Bipin M

Reputation: 33

tf.data.Dataset - behavior of map() and cache() methods

Questions regarding TensorFlow Datasets:
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache

  1. How does the map function actually work? The print(rand) in mapfn() prints just one value, but print(x) prints values as expected.
  2. Why does the map function behave differently from Python's built-in map()?
    • dataset.map(mapfn) prints only 1 value
    • map(mapfn, numbers) prints 4 values
  3. When I get the same result for x and y below, what is the purpose of using dataset.cache()?
import tensorflow as tf
from random import random
from math import ceil

def mapfn(x):
    rand = ceil(5*random())
    print(rand)
    return x**rand

dataset = tf.data.Dataset.range(50)
dataset = dataset.map(mapfn)
# dataset = dataset.cache()

x = list(dataset.as_numpy_iterator())
print(x)

y = list(dataset.as_numpy_iterator())
print(y)

vs

def mapfn(n):
    rand = ceil(5*random())
    print(rand)
    return n**rand
  
numbers = [1, 2, 3, 4]
result = map(mapfn, numbers)
print(list(result))

Upvotes: 3

Views: 1079

Answers (1)

Laplace Ricky

Reputation: 1687

When you pass mapfn into dataset.map(), mapfn is converted into a TensorFlow graph, and print() does not work as expected in graph mode: it only runs during the tracing stage, so if mapfn is traced once, it prints only once.

To print debug messages properly in graph mode, use tf.print() instead.

If cache() is attached after dataset.map(mapfn), it will cache the mapped values, and the cached values will be used from then on (there must be enough memory to hold all of them).

In other words, after the first full pass over the dataset, mapfn is never called again.

See Example:

import tensorflow as tf

ds = tf.data.Dataset.range(3)

def mapfn(x):
  tf.print('I am called')
  return tf.pow(x, 2)  # mapfn needs to be graph-mode compatible

ds = ds.map(mapfn)
print('First loop:')
for x in ds:
  print(x)
print()
print('Second loop:')
for x in ds:
  print(x)
print()

ds = ds.cache()
print('After cache():')
print('First loop:')
for x in ds:
  print(x)
print()
print('Second loop:')
for x in ds:
  print(x)
print()

'''
First loop:
I am called
tf.Tensor(0, shape=(), dtype=int64)
I am called
tf.Tensor(1, shape=(), dtype=int64)
I am called
tf.Tensor(4, shape=(), dtype=int64)

Second loop:
I am called
tf.Tensor(0, shape=(), dtype=int64)
I am called
tf.Tensor(1, shape=(), dtype=int64)
I am called
tf.Tensor(4, shape=(), dtype=int64)

After cache():
First loop:
I am called
tf.Tensor(0, shape=(), dtype=int64)
I am called
tf.Tensor(1, shape=(), dtype=int64)
I am called
tf.Tensor(4, shape=(), dtype=int64)

Second loop:
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64)
'''
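As a side note, the questioner's original mapfn can be made graph-compatible by drawing the random exponent with a TensorFlow op instead of Python's random module, and printing with tf.print. A minimal sketch (the minval/maxval choice of 1–5 mirrors ceil(5*random()) in the question):

```python
import tensorflow as tf

def mapfn(x):
    # Draw the exponent with a TF op so a fresh value is produced per
    # element at execution time, not once at tracing time.
    rand = tf.random.uniform([], minval=1, maxval=6, dtype=tf.int64)
    tf.print('exponent:', rand)  # executes for every element
    return x ** rand

dataset = tf.data.Dataset.range(5).map(mapfn)
print(list(dataset.as_numpy_iterator()))
```

Because the random op now runs per element on every pass, iterating twice generally gives different values for x and y — and adding .cache() after the map freezes the first pass's results, which is exactly the purpose asked about in question 3.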

Upvotes: 4
