HuckleberryFinn

Reputation: 1529

Obtaining total number of records from .tfrecords file in Tensorflow

Is it possible to obtain the total number of records from a .tfrecords file? Related to this, how does one generally keep track of the number of epochs that have elapsed while training models? While it is possible for us to specify the batch_size and num_of_epochs, I am not sure if it is straightforward to obtain values such as the current epoch, the number of batches per epoch, etc. - just so that I could have more control over how the training is progressing. Currently, I'm just using a dirty hack to compute this, as I know beforehand how many records there are in my .tfrecords file and the size of my minibatches. Appreciate any help.
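For reference, the arithmetic behind that kind of hack is straightforward. A minimal sketch (the names `num_records`, `batch_size`, and `global_step` are placeholders; `global_step` is assumed to count minibatches since training began):

```python
import math

def training_progress(num_records, batch_size, global_step):
    """Derive epoch/batch position from a global step counter."""
    # Number of minibatches needed to cover the dataset once
    # (the last batch may be smaller than batch_size)
    batches_per_epoch = math.ceil(num_records / batch_size)
    # Epochs fully elapsed so far
    current_epoch = global_step // batches_per_epoch
    # Position of the current batch within its epoch
    batch_in_epoch = global_step % batches_per_epoch
    return batches_per_epoch, current_epoch, batch_in_epoch

print(training_progress(num_records=1000, batch_size=32, global_step=70))
# -> (32, 2, 6)
```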

Upvotes: 33

Views: 26124

Answers (5)

bhargav ram

Reputation: 21

As tf.enable_eager_execution() is no longer valid, use:

tf.compat.v1.enable_eager_execution()

sum(1 for _ in tf.data.TFRecordDataset(FILENAMES))

Upvotes: 1

Russell

Reputation: 1451

As per the deprecation warning on tf_record_iterator, we can also use eager execution to count records.

#!/usr/bin/env python
from __future__ import print_function

import tensorflow as tf
import sys

assert len(sys.argv) == 2, \
    "USAGE: {} <file_glob>".format(sys.argv[0])

tf.enable_eager_execution()

input_pattern = sys.argv[1]

# Expand glob if there is one
input_files = tf.io.gfile.glob(input_pattern)

# Create the dataset
data_set = tf.data.TFRecordDataset(input_files)

# Count the records
records_n = sum(1 for record in data_set)

print("records_n = {}".format(records_n))

Upvotes: 6

BiBi

Reputation: 7908

As tf.io.tf_record_iterator is being deprecated, the great answer of Salvador Dali should now read

tf.enable_eager_execution()
sum(1 for _ in tf.data.TFRecordDataset(file_name))

Upvotes: 12

Salvador Dali

Reputation: 222889

No, it is not possible. TFRecord does not store any metadata about the data being stored inside. This file represents a sequence of (binary) strings. The format is not random access, so it is suitable for streaming large amounts of data but not suitable if fast sharding or other non-sequential access is desired.

If you want, you can store this metadata manually or use a record_iterator to get the number (you will need to iterate through all the records that you have):

sum(1 for _ in tf.python_io.tf_record_iterator(file_name))
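Since the on-disk format is just a sequence of length-prefixed binary strings, the count can also be recovered without TensorFlow at all. A minimal sketch, assuming a standard uncompressed TFRecord file (`count_tfrecords` is an illustrative name; the CRC fields are skipped, not verified):

```python
import struct

def count_tfrecords(path):
    """Count records by walking the TFRecord framing.

    Each record on disk is: an 8-byte little-endian payload length,
    a 4-byte CRC of the length, the payload itself, and a 4-byte CRC
    of the payload.
    """
    n = 0
    with open(path, "rb") as f:
        while True:
            header = f.read(8)          # uint64 payload length
            if len(header) < 8:
                break                   # end of file
            (length,) = struct.unpack("<Q", header)
            # Skip length-CRC, payload, and payload-CRC (relative seek)
            f.seek(4 + length + 4, 1)
            n += 1
    return n
```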

If you want to know the current epoch, you can do this either from tensorboard or by printing the number from the loop.

Upvotes: 26

drpng

Reputation: 1637

To count the number of records, you should be able to use tf.python_io.tf_record_iterator.

c = 0
for fn in tf_records_filenames:
    for record in tf.python_io.tf_record_iterator(fn):
        c += 1

To just keep track of the model training, tensorboard comes in handy.

Upvotes: 33
