Reputation: 133
I have started using tf.data.Dataset as a way to load data into keras models, as it appears to be much faster than keras' ImageDataGenerator and much more memory efficient than training on arrays.
One thing I can't get my head around is that I can't seem to find a way to access the len() of the dataset. The iterators produced by Keras' ImageDataGenerator have an attribute called n which I used to use for this purpose. This makes my code very ugly, as I need to hard-code the length in various parts of the script (e.g. to find out how many iterations an epoch has). Any ideas how I can work around this issue?
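For example, the n I mean is known up front on the iterator that flow() returns (a toy illustration with dummy arrays):

import numpy as np
from tensorflow.keras.preprocessing.image import ImageDataGenerator

x = np.zeros((100, 28, 28, 1), dtype=np.float32)
y = np.zeros((100, 10), dtype=np.float32)

it = ImageDataGenerator().flow(x, y, batch_size=32)
print(it.n)  # 100 -- total number of samples, known without iterating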
An example script:
# Generator
import numpy as np
import tensorflow as tf

def make_mnist_train_generator(batch_size):
    (x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.reshape((-1, 28, 28, 1))
    x_train = x_train.astype(np.float32) / 255.
    y_train = tf.keras.utils.to_categorical(y_train, 10)
    ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    ds = ds.shuffle(buffer_size=len(x_train))
    ds = ds.repeat()
    ds = ds.batch(batch_size=batch_size)
    ds = ds.prefetch(buffer_size=1)
    return ds
model = ... # create a tf.keras model
batch_size = 256
gen = make_mnist_train_generator(batch_size)
# Training
model.fit(gen, epochs=50, steps_per_epoch=60000//batch_size+1) # Hard coded size of generator
Upvotes: 3
Views: 1485
Reputation: 7432
Unfortunately, tf.data.Dataset is a generator and there is no inherent way of finding its size.
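For instance, a quick sanity check (the exact error message varies by TF version) shows that calling len() on a repeated dataset like yours raises a TypeError:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices([1, 2, 3]).repeat()
len(ds)  # TypeError -- the length of a repeated dataset can't be determined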
Generally speaking, when you use .from_tensor_slices() you already know the size from the argument you pass to this method, in your case x_train. Your only problem is that you are creating the dataset inside a function.
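The simplest workaround is to just return the size alongside the dataset (a minimal sketch; the function name here is just illustrative):

import numpy as np
import tensorflow as tf

def make_mnist_train_generator_with_size(batch_size):  # illustrative variant of your function
    (x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.reshape((-1, 28, 28, 1)).astype(np.float32) / 255.
    y_train = tf.keras.utils.to_categorical(y_train, 10)
    ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    ds = ds.shuffle(buffer_size=len(x_train)).repeat().batch(batch_size).prefetch(1)
    return ds, len(x_train)  # hand the caller the sample count as well

gen, n_train = make_mnist_train_generator_with_size(256)
# model.fit(gen, epochs=50, steps_per_epoch=n_train // 256 + 1)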
A neater hack to bypass this issue, though, is to add a __len__ method to the dataset yourself! The easiest way I've found to do this is:
ds.__class__ = type(ds.__class__.__name__, (ds.__class__,), {'__len__': lambda self: len(x_train)})
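A quick toy check: the patched object now reports a length, and since the new class subclasses the dataset's original class, isinstance checks still pass, so Keras keeps treating it as a regular Dataset:

import tensorflow as tf

ds = tf.data.Dataset.from_tensor_slices(list(range(10)))
# Swap the instance's class for a dynamically created subclass that adds __len__
ds.__class__ = type(ds.__class__.__name__, (ds.__class__,), {'__len__': lambda self: 10})

print(len(ds))                          # 10
print(isinstance(ds, tf.data.Dataset))  # True -- still a Dataset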
In your case it would look something like this:
def make_mnist_train_generator(batch_size):
    (x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.reshape((-1, 28, 28, 1))
    x_train = x_train.astype(np.float32) / 255.
    y_train = tf.keras.utils.to_categorical(y_train, 10)
    ds = tf.data.Dataset.from_tensor_slices((x_train, y_train))
    ds = ds.shuffle(buffer_size=len(x_train))
    ds = ds.repeat()
    ds = ds.batch(batch_size=batch_size)
    ds = ds.prefetch(buffer_size=1)
    # Dynamically subclass so the dataset reports its own length:
    ds.__class__ = type(ds.__class__.__name__, (ds.__class__,), {'__len__': lambda self: len(x_train)})
    return ds
gen = make_mnist_train_generator(batch_size)
model.fit(gen, epochs=50, steps_per_epoch=len(gen)//batch_size+1)  # No more hard-coded size
I've done this in the past and it's surprisingly useful. There are many reasons why you'd want your generator to have a len(); the most obvious, as above, is deriving steps_per_epoch instead of hard-coding it.
Upvotes: 2