Johnson

Reputation: 133

How to find the len() of a tf.Dataset

I have started using tf.data.Dataset to load data into Keras models, as it appears to be much faster than Keras' ImageDataGenerator and much more memory-efficient than training on arrays.

One thing I can't get my head around is that I can't seem to find a way to access the len() of the dataset. Keras' ImageDataGenerator has an attribute called n which I used to use for this purpose. This makes my code very ugly, as I need to hard-code the length in various parts of the script (e.g. to find out how many iterations an epoch has).

Any ideas how I can work around this issue?

An example script:

# Generator
def make_mnist_train_generator(batch_size):
    (x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()

    x_train = x_train.reshape((-1, 28, 28, 1))
    x_train = x_train.astype(np.float32) / 255.

    y_train = tf.keras.utils.to_categorical(y_train, 10)

    ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    ds = ds.shuffle(buffer_size=len(x_train))
    ds = ds.repeat()
    ds = ds.batch(batch_size=batch_size)
    ds = ds.prefetch(buffer_size=1)

    return ds


model = ...  # create a tf.keras model

batch_size = 256
gen = make_mnist_train_generator(batch_size)

# Training
model.fit(gen, epochs=50, steps_per_epoch=60000//batch_size+1)  # Hard-coded dataset size

Upvotes: 3

Views: 1485

Answers (1)

Djib2011

Reputation: 7432

tl;dr

Unfortunately, tf.data.Dataset behaves like a generator, and there is no inherent way of finding its size.
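To see why this is a hard limitation rather than a missing feature, here is a plain-Python sketch (no TensorFlow involved) of the same problem: a generator defines no __len__, and the only generic way to count its elements is to iterate, which consumes it.

```python
# Plain-Python illustration: generators don't support len(),
# and counting elements exhausts the stream.
def number_stream():
    yield from range(10)

g = number_stream()

try:
    len(g)
except TypeError:
    print("generators define no __len__")

# The only generic way to count is to iterate, which consumes the generator.
n = sum(1 for _ in g)
print(n)  # 10
```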

But...

Generally speaking, when you use .from_tensor_slices() you can know the size from the argument you pass to this method, in your case x_train. Your only issue is that you are creating the dataset inside a function.

A neat hack you can do to bypass this issue is to add a __len__ attribute on your own! The easiest way I've found to do this is:

ds.__class__ = type(ds.__class__.__name__, (ds.__class__,), {'__len__': lambda self: len(x_train)})
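To make the mechanics of that one-liner clearer, here is a minimal sketch of the same trick on a hypothetical Stream class (a stand-in for the dataset object, not a real TensorFlow type): type() builds a dynamic subclass carrying a __len__ that closes over the known length, and reassigning __class__ swaps the instance over to it.

```python
# Hypothetical iterable with no __len__ of its own, standing in
# for the tf.data.Dataset object.
class Stream:
    def __init__(self, items):
        self._items = items
    def __iter__(self):
        return iter(self._items)

data = [1, 2, 3, 4]
s = Stream(data)

# Dynamically create a subclass of the instance's class that adds
# __len__, then point the instance at it. The lambda closes over
# the data we know the length of.
s.__class__ = type(s.__class__.__name__, (s.__class__,),
                   {'__len__': lambda self: len(data)})

print(len(s))  # 4
```

The subclass keeps the original class name and inherits everything else, so the object behaves exactly as before except that len() now works.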

In your case it would look something like this:

def make_mnist_train_generator(batch_size):
    (x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()

    x_train = x_train.reshape((-1, 28, 28, 1))
    x_train = x_train.astype(np.float32) / 255.

    y_train = tf.keras.utils.to_categorical(y_train, 10)

    ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    ds = ds.shuffle(buffer_size=len(x_train))
    ds = ds.repeat()
    ds = ds.batch(batch_size=batch_size)
    ds = ds.prefetch(buffer_size=1)

    ds.__class__ = type(ds.__class__.__name__, (ds.__class__,), {'__len__': lambda self: len(x_train)})

    return ds


gen = make_mnist_train_generator(batch_size)

model.fit(gen, epochs=50, steps_per_epoch=len(gen)//batch_size+1)  # No more hard-coded size

Why do this?

I've done this in the past and it's surprisingly useful. There are many reasons why you'd want your generator to support len(). Some examples are:

  • if you want to have the generator in a separate module and import it
  • if the generator is meant to be used by someone else who doesn't know what data was used to create it

Upvotes: 2
