Ray Tayek

Reputation: 10003

How to use map with tuples in a TensorFlow 2 dataset?

I am trying to map a tuple to a tuple in a dataset in TF 2 (see the code below). My output (also below) shows that the map function is called only once, and I cannot seem to get at the tuple.

How do I get at the "a", "b", "c" from the input parameter, which is:

tuple Tensor("args_0:0", shape=(3,), dtype=string)
type <class 'tensorflow.python.framework.ops.Tensor'>

Edit: it seems that using Dataset.from_tensor_slices produces the data all at once. This would explain why map is only called once, so I probably need to make the dataset in some other way.

    from __future__ import absolute_import, division, print_function, unicode_literals
    from timeit import default_timer as timer
    print('import tensorflow')
    start = timer()
    import tensorflow as tf
    end = timer()
    print('Elapsed time: ' + str(end - start),"for",tf.__version__)
    import numpy as np
    def map1(tuple):
        print("<<<")
        print("tuple",tuple)
        print("type",type(tuple))
        print("shape",tuple.shape)
        print("tuple 0",tuple[0])
        print("type 0",type(tuple[0]))
        print("shape 0",tuple.shape[0])
        # how do i get "a","b","c" from the input parameter?
        print(">>>")
        return ("1","2","3")
    l=[]
    l.append(("a","b","c"))
    l.append(("d","e","f"))
    print(l)
    ds=tf.data.Dataset.from_tensor_slices(l)
    print("ds",ds)
    print("start mapping")
    result = ds.map(map1)
    print("end mapping")


    $ py mapds.py
    import tensorflow
    Elapsed time: 12.002168990751619 for 2.0.0
    [('a', 'b', 'c'), ('d', 'e', 'f')]
    ds <TensorSliceDataset shapes: (3,), types: tf.string>
    start mapping
    <<<
    tuple Tensor("args_0:0", shape=(3,), dtype=string)
    type <class 'tensorflow.python.framework.ops.Tensor'>
    shape (3,)
    tuple 0 Tensor("strided_slice:0", shape=(), dtype=string)
    type 0 <class 'tensorflow.python.framework.ops.Tensor'>
    shape 0 3
    >>>
    end mapping

Upvotes: 3

Views: 5133

Answers (1)

Kaushik Roy

Reputation: 1685

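The value or values returned by the map function (map1) determine the structure of each element in the returned dataset. [Ref]
In your case, result is a tf.data dataset and there is nothing wrong with your code. Dataset.map traces map1 once to build a graph, which is why the print statements run only once and the argument is a symbolic Tensor rather than concrete values.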
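
As a minimal sketch of how the return value shapes the elements, element_spec can be compared before and after map. The names rows and to_tuple are illustrative and not part of the question's code; the data is the same two rows.

```python
import tensorflow as tf

rows = [("a", "b", "c"), ("d", "e", "f")]
ds = tf.data.Dataset.from_tensor_slices(rows)

def to_tuple(row):
    # row is a symbolic string tensor of shape (3,); indexing it
    # yields three scalar string tensors, returned here as a tuple.
    return row[0], row[1], row[2]

mapped = ds.map(to_tuple)

# Before map: one string tensor of shape (3,) per element.
print(ds.element_spec)
# After map: a tuple of three scalar string tensors per element.
print(mapped.element_spec)

# Iterating in eager mode gives concrete values for every row.
elements = [tuple(t.numpy().decode() for t in item) for item in mapped]
print(elements)
```

Because to_tuple returns three tensors, each element of mapped is a 3-tuple, even though the function body itself is only traced once.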

To check whether every tuple is mapped correctly, you can iterate over every sample of your dataset as follows:
[Updated Code]

    import tensorflow as tf

    def map1(row):
        # Inside tf.py_function the argument is an eager tensor, so .numpy() works.
        print(row[0].numpy().decode("utf-8"))  # print the first element of the tuple
        return ("1", "2", "3")

    rows = [("a", "b", "c"), ("d", "e", "f")]
    ds = tf.data.Dataset.from_tensor_slices(rows)
    # Wrap map1 in tf.py_function so it runs eagerly for every element.
    ds = ds.map(lambda tpl: tf.py_function(map1, [tpl], [tf.string, tf.string, tf.string]))

    for sample in ds:
        print(sample[0].numpy().decode(), sample[1].numpy().decode(), sample[2].numpy().decode())

Output:

    a
    1 2 3
    d
    1 2 3
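
If the per-element work can be expressed with TensorFlow ops, tf.py_function may not be needed at all. The following is only a sketch under that assumption: join_row is a hypothetical function, and joining with "-" is an arbitrary stand-in for whatever processing is actually wanted.

```python
import tensorflow as tf

rows = [("a", "b", "c"), ("d", "e", "f")]
ds = tf.data.Dataset.from_tensor_slices(rows)

def join_row(row):
    # This runs inside the traced graph, so it uses TensorFlow ops
    # instead of .numpy(). tf.unstack splits the shape-(3,) tensor
    # into three scalar string tensors.
    first, second, third = tf.unstack(row)
    return tf.strings.join([first, second, third], separator="-")

# Concrete values become available when the dataset is iterated.
joined = [s.numpy().decode() for s in ds.map(join_row)]
print(joined)
```

Staying with graph ops lets tf.data run the function without calling back into Python, at the cost of being limited to what TensorFlow ops can express.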

Hope it will help.

Upvotes: 2
