Reputation: 33
Questions wrt. Tensorflow Datasets
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map
https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache
Why does print(rand) in mapfn() print just one value, while print(x) prints values as expected?
Why does dataset.map(mapfn) print only 1 value, while map(mapfn, numbers) prints 4 values?
What about dataset.cache()?

import tensorflow as tf
from random import random
from math import ceil
def mapfn(x):
    rand = ceil(5*random())
    print(rand)
    return x**rand
dataset = tf.data.Dataset.range(50)
dataset = dataset.map(mapfn)
# dataset = dataset.cache()
x = list(dataset.as_numpy_iterator())
print(x)
y = list(dataset.as_numpy_iterator())
print(y)
vs
def mapfn(n):
    rand = ceil(5*random())
    print(rand)
    return n**rand
numbers = [1, 2, 3, 4]
result = map(mapfn, numbers)
print(list(result))
Upvotes: 3
Views: 1079
Reputation: 1687
When you pass mapfn into dataset.map(), mapfn is converted into a TensorFlow graph, and print() does not work as expected in graph mode. print() runs only during the tracing stage, so if mapfn is traced once, it prints only once.
To print debug messages properly in graph mode, use tf.print() instead.
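The same tracing behavior affects random(): it is ordinary Python, so it runs while mapfn is traced and its result is baked into the graph as a constant. Below is a minimal sketch of one way to get a fresh exponent per element, assuming tf.random.uniform is an acceptable replacement for ceil(5*random()). The trace_calls list is a hypothetical helper added only to show how rarely the Python body actually runs.

```python
import tensorflow as tf

trace_calls = []  # hypothetical helper: grows each time Python executes mapfn's body

def mapfn(x):
    # Plain Python side effect: executed only while the function is being traced.
    trace_calls.append(1)
    # Graph op: evaluated again for every element, giving an integer in [1, 5].
    exponent = tf.random.uniform([], minval=1, maxval=6, dtype=tf.int64)
    return tf.pow(x, exponent)

dataset = tf.data.Dataset.range(50).map(mapfn)

first = list(dataset.as_numpy_iterator())
second = list(dataset.as_numpy_iterator())

# The dataset yields 50 elements per pass, but the Python body ran far fewer times.
print("elements per pass:", len(first))
print("Python-level calls to mapfn:", len(trace_calls))
```

Because the exponent is now drawn inside the graph, two passes over the uncached dataset will generally produce different values for the same input.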
If cache() is attached after dataset.map(mapfn), it caches the mapped values during the first full pass, and later passes read the cached values instead (there must be enough memory to hold all of them).
In other words, after the first loop over the dataset, mapfn is never called again.
See Example:
import tensorflow as tf

ds = tf.data.Dataset.range(3)

def mapfn(x):
    tf.print('I am called')  # runs for every element, even in graph mode
    return tf.pow(x, 2)  # mapfn needs to be graph-mode compatible

ds = ds.map(mapfn)
print('First loop:')
for x in ds:
    print(x)
print()
print('Second loop:')
for x in ds:
    print(x)
print()

# From here on, mapped values are stored during the first full pass
ds = ds.cache()
print('After cache():')
print('First loop:')
for x in ds:
    print(x)
print()
print('Second loop:')
for x in ds:
    print(x)
print()
'''
First loop:
I am called
tf.Tensor(0, shape=(), dtype=int64)
I am called
tf.Tensor(1, shape=(), dtype=int64)
I am called
tf.Tensor(4, shape=(), dtype=int64)
Second loop:
I am called
tf.Tensor(0, shape=(), dtype=int64)
I am called
tf.Tensor(1, shape=(), dtype=int64)
I am called
tf.Tensor(4, shape=(), dtype=int64)
After cache():
First loop:
I am called
tf.Tensor(0, shape=(), dtype=int64)
I am called
tf.Tensor(1, shape=(), dtype=int64)
I am called
tf.Tensor(4, shape=(), dtype=int64)
Second loop:
tf.Tensor(0, shape=(), dtype=int64)
tf.Tensor(1, shape=(), dtype=int64)
tf.Tensor(4, shape=(), dtype=int64)
'''
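The effect of cache() is easiest to see when the mapped function is random. The following is a small sketch, assuming an in-memory cache and a full first pass; random_power is a hypothetical name used only for this illustration.

```python
import tensorflow as tf

def random_power(x):
    # Graph-level randomness: a new exponent in [1, 5] each time this op runs.
    exponent = tf.random.uniform([], minval=1, maxval=6, dtype=tf.int64)
    return tf.pow(x, exponent)

# cache() placed after map(): the mapped values are stored during the first pass.
cached = tf.data.Dataset.range(2, 50).map(random_power).cache()

first_pass = list(cached.as_numpy_iterator())   # fills the cache
second_pass = list(cached.as_numpy_iterator())  # served from the cache

# Both passes should match, because random_power is not re-run on the second one.
print(first_pass == second_pass)
```

If the random values should change on every epoch, one option is to place cache() before the random map() instead, so that only the deterministic part of the pipeline is cached.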
Upvotes: 4