zainul muttaqin

Reputation: 23

Error when using tf.data.Dataset.from_generator

I am trying to build a TensorFlow dataset with tf.data.Dataset.from_generator. I am quite sure my Python generator works correctly on its own, but whenever I pass it to from_generator I get an error. This is the code I use to create the dataset:

def dataset_generator(X, Y):
    for idx in range(X.shape[0]):
        img = X[idx, :, :, :]
        labels = Y[idx, :]
        yield img, labels

import tensorflow as tf
ds_generator = dataset_generator(X_data, Y_data)
ds = tf.data.Dataset.from_generator(ds_generator, output_signature=(tf.TensorSpec(shape=[None, 720, 720, 3], dtype=tf.int32), tf.TensorSpec(shape=[None, 30], dtype=tf.float16)))

When I run it, it always produces this error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-63-af75191f4a28> in <module>
      1 import tensorflow as tf
      2 ds_generator = dataset_generator(X_data, Y_data)
----> 3 ds = tf.data.Dataset.from_generator(ds_generator, output_signature=(tf.TensorSpec(shape=[None, 720, 720, 3], dtype=tf.int32), tf.TensorSpec(shape=[None, 30], dtype=tf.float16)))

~/.local/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)

~/.local/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in from_generator(generator, output_types, output_shapes, args, output_signature)

TypeError: `generator` must be callable.

Upvotes: 1

Views: 3307

Answers (1)

Edwin Cheong

Reputation: 979

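Hi, the problem is that from_generator expects the generator function itself (a callable), not the generator object you get by calling it. Pass dataset_generator itself and supply its arguments through the args parameter.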
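To see why the error says `generator` must be callable, here is a minimal sketch in plain Python, without TensorFlow. The names X_small, Y_small, and make_gen are hypothetical stand-ins for illustration, and the generator indexes plain lists rather than arrays.

```python
from functools import partial


def dataset_generator(X, Y):
    # Yield one (sample, label) pair per index.
    for idx in range(len(X)):
        yield X[idx], Y[idx]


# Hypothetical toy data; plain lists are enough to show the point.
X_small = [[0, 1], [2, 3]]
Y_small = [0, 1]

gen_function = dataset_generator                  # the function itself
gen_object = dataset_generator(X_small, Y_small)  # the result of calling it

is_function_callable = callable(gen_function)  # a function can be called
is_object_callable = callable(gen_object)      # a generator object cannot

# A zero-argument callable that creates a fresh generator on every call,
# so the data can be iterated more than once.
make_gen = partial(dataset_generator, X_small, Y_small)
first_pass = list(make_gen())
second_pass = list(make_gen())
```

A generator object can only be consumed once, while a callable can be invoked again to restart the iteration. The answer's code that follows takes the same route: it passes the function and lets from_generator call it with the values given in args.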

import tensorflow as tf
import numpy as np

# Gen Function
def dataset_generator(X, Y):
    for idx in range(X.shape[0]):
        img = X[idx, :, :, :]
        labels = Y[idx, :]
        yield img, labels

# Created random data for testing
X_data = np.random.randn(100, 720, 720, 3).astype(np.float32)
Y_data = tf.one_hot(np.random.randint(0, 30, (100, )), 30)

# Build the dataset: pass the generator function itself (not a call to it)
# and supply its arguments through args
ds = tf.data.Dataset.from_generator(
    dataset_generator,
    args=(X_data, Y_data), 
    output_types=(tf.float32, tf.uint8)
)

# Get output
next(iter(ds.batch(10).take(1)))
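If you would rather keep output_signature, as in the question, the following is a minimal sketch of one way to do it, assuming TensorFlow 2.4 or later. The small 8x8 images and the rng variable are assumptions made to keep the example light; with the real data the image spec would be (720, 720, 3). Each TensorSpec describes a single yielded element, so there is no leading None dimension; the batch dimension only appears after calling batch.

```python
import numpy as np
import tensorflow as tf


def dataset_generator(X, Y):
    # Yield one (image, labels) pair per sample.
    for idx in range(X.shape[0]):
        yield X[idx], Y[idx]


# Small random stand-in data (assumed shapes, much smaller than 720x720).
rng = np.random.default_rng(0)
X_data = rng.integers(0, 256, size=(4, 8, 8, 3), dtype=np.int32)
Y_data = rng.random((4, 30)).astype(np.float16)

ds = tf.data.Dataset.from_generator(
    # A zero-argument callable; the arrays are captured by the lambda.
    lambda: dataset_generator(X_data, Y_data),
    output_signature=(
        tf.TensorSpec(shape=(8, 8, 3), dtype=tf.int32),  # one image
        tf.TensorSpec(shape=(30,), dtype=tf.float16),    # one label row
    ),
)

# Batching adds the leading dimension.
img_batch, label_batch = next(iter(ds.batch(2)))
```

Wrapping the call in a lambda keeps the arrays as ordinary Python objects, whereas values passed through args are converted to tensors first. Either approach gives from_generator the callable it requires.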

Upvotes: 2
