cloudy

Reputation: 86

tf.data.Dataset - delete cache?

Is it possible to delete the in-memory cache that's built after calling tf.data.Dataset.cache()?

Here's what I'd like to do. The augmentation for the dataset is very costly, so the current code is more or less:

```python
data = (
    tf.data.Dataset(...)
    .map(<expensive_augmentation>)
    .cache()
    # .shuffle().batch() etc.
)
```

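However, this means that every iteration over data sees the same augmented versions of the samples: the augmentation only runs during the first full pass, and every later pass is served from the cache. What I'd like instead is to use the cache for a couple of epochs and then start over, or equivalently something like Dataset.map(<augmentation>).fleeting_cache().repeat(8), where fleeting_cache() is an imaginary cache that would only live for those 8 repetitions. Is this possible to achieve?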
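A minimal sketch that reproduces the behavior described above, assuming TensorFlow 2.x with eager execution. The toy augment() function (random noise added to Dataset.range elements) is a hypothetical stand-in for the real expensive augmentation:

```python
import tensorflow as tf

# Hypothetical stand-in for an expensive, random augmentation:
# add uniform noise to each element.
def augment(x):
    return tf.cast(x, tf.float32) + tf.random.uniform([])

dataset = tf.data.Dataset.range(4).map(augment).cache()

epoch_1 = [float(v) for v in dataset]  # first full pass runs augment() and fills the cache
epoch_2 = [float(v) for v in dataset]  # second pass is read back from the cache

print(epoch_1 == epoch_2)  # True: the same "augmented" values every epoch
```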

Upvotes: 6

Views: 2061

Answers (1)

AAudibert

Reputation: 1273

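The in-memory cache's lifetime is tied to the dataset object that cache() was called on, so you can discard it by re-creating the dataset. The new dataset starts with an empty cache (so its first pass re-runs the augmentation), and the old cache is released once the old dataset is no longer referenced: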

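```python
import tensorflow as tf

def create_dataset():
  dataset = tf.data.Dataset(...)  # your source dataset
  dataset = dataset.map(<expensive_augmentation>)
  dataset = dataset.cache()  # each new dataset gets its own, initially empty, cache
  dataset = dataset.shuffle(...)  # shuffling after cache() still reorders every epoch
  dataset = dataset.batch(...)
  return dataset

# num_epochs and train() stand in for your own training setup.
for epoch in range(num_epochs):
  # Re-create the dataset, and so drop the old cache, every 8 epochs.
  if epoch % 8 == 0:
    dataset = create_dataset()
  for batch in dataset:
    train(batch)
```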
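The re-creation loop can also be packaged as a small generator. This is a sketch assuming TensorFlow 2.x with eager execution; epochs_with_cache_refresh() and make_toy_dataset() are hypothetical names, and the random-noise map stands in for the real augmentation. Each call to make_dataset() builds a new dataset with its own empty cache, so the epochs inside one refresh window replay the same cached values:

```python
import tensorflow as tf

def epochs_with_cache_refresh(make_dataset, num_epochs, refresh_every):
    """Yield (epoch, dataset) pairs, rebuilding the cached dataset every
    `refresh_every` epochs so the expensive map runs again."""
    dataset = None
    for epoch in range(num_epochs):
        if epoch % refresh_every == 0:
            # Rebinding drops this generator's reference to the old dataset;
            # its cache is freed once no other references remain.
            dataset = make_dataset()
        yield epoch, dataset

def make_toy_dataset():
    # Hypothetical stand-in for an expensive augmentation: add random noise.
    ds = tf.data.Dataset.range(4)
    ds = ds.map(lambda x: tf.cast(x, tf.float32) + tf.random.uniform([]))
    return ds.cache()

values_per_epoch = []
for epoch, ds in epochs_with_cache_refresh(make_toy_dataset, num_epochs=4, refresh_every=2):
    values_per_epoch.append([float(v) for v in ds])

# Epochs 0-1 share one cache, epochs 2-3 share another.
print(values_per_epoch[0] == values_per_epoch[1])  # True
print(values_per_epoch[2] == values_per_epoch[3])  # True
```

Passing refresh_every=8 reproduces the "drop the cache every 8 epochs" schedule from the loop above. Keeping shuffle() after cache() in the real pipeline means the sample order still changes every epoch even while the augmented values are reused.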

Upvotes: 1
