krishnab

Reputation: 10100

Tensorflow: Count number of examples in a TFRecord file -- without using deprecated `tf.python_io.tf_record_iterator`

Please read the post before marking it as a duplicate:

I was looking for an efficient way to count the number of examples in a TFRecord file of images. Since a TFRecord file does not store any metadata about its own contents, the user has to loop through the file in order to compute this count.

There are a few questions on Stack Overflow that already address this. The problem is that all of them seem to use the DEPRECATED tf.python_io.tf_record_iterator command, so this is not a stable solution. Here is a sample of the existing posts:

Obtaining total number of records from .tfrecords file in Tensorflow

Number of examples in each tfrecord

Number of examples in each tfrecord

So I was wondering if there was a way to count the number of records using the new Dataset API.

Upvotes: 7

Views: 5271

Answers (3)

James Adams

Reputation: 8747

The following works for me using TensorFlow version 2.1 (using the code found in this answer):

import os

import tensorflow as tf


def count_tfrecord_examples(
        tfrecords_dir: str,
) -> int:
    """
    Counts the total number of examples in a collection of TFRecord files.

    :param tfrecords_dir: directory that is assumed to contain only TFRecord files
    :return: the total number of examples in the collection of TFRecord files
        found in the specified directory
    """

    count = 0
    for file_name in os.listdir(tfrecords_dir):
        tfrecord_path = os.path.join(tfrecords_dir, file_name)
        count += sum(1 for _ in tf.data.TFRecordDataset(tfrecord_path))

    return count
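
For what it's worth, a minimal usage sketch (the directory path below is a hypothetical example, and eager execution must be enabled so the datasets can be iterated in a plain Python loop):

# Hypothetical directory assumed to contain only TFRecord files.
num_examples = count_tfrecord_examples("/data/tfrecords/train")
print(f"Found {num_examples} examples")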

Upvotes: 0

krishnab

Reputation: 10100

I got the following code to work without the deprecated command. Hopefully this will help others.

Using the Dataset API, I set up an iterator and then loop over it. I'm not sure this is the fastest approach, but it works. MAKE SURE THE BATCH SIZE AND REPEAT ARE SET TO 1, otherwise the code will return the number of batches rather than the number of examples in the dataset.

import tensorflow as tf

# _parse_image_function is the user's feature-parsing function for these records.
count_test = tf.data.TFRecordDataset('testing.tfrecord')
count_test = count_test.map(_parse_image_function)
count_test = count_test.repeat(1)
count_test = count_test.batch(1)
test_counter = count_test.make_one_shot_iterator()

c = 0
for ex in test_counter:
    c += 1
print(f"There are {c} testing records")

This seemed to work reasonably well even on a relatively large file.
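
For reference, under TensorFlow 2.x with eager execution the same count can be sketched without building an iterator, since the dataset is directly iterable; the map and batch steps are not needed just to count records. This is an untested variant of the code above:

import tensorflow as tf

# A TFRecordDataset is directly iterable in eager mode, so counting
# the raw serialized records is just a generator sum.
count_test = tf.data.TFRecordDataset('testing.tfrecord')
c = sum(1 for _ in count_test)
print(f"There are {c} testing records")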

Upvotes: 0

Maosi Chen

Reputation: 1491

There is a reduce method on the tf.data.Dataset class, and its documentation gives an example of counting records with it:

import numpy as np
import tensorflow as tf

# generate the dataset (batch size and repeat must be 1; avoid dataset
# manipulations such as map and shard before counting)
ds = tf.data.Dataset.range(5)
# count the examples with reduce
cnt = ds.reduce(np.int64(0), lambda x, _: x + 1)

## produces 5
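
A quick sketch of the same reduce pattern applied to an actual TFRecord file (the 'testing.tfrecord' path is a placeholder; .numpy() converts the resulting scalar tensor to a Python int under eager execution):

import numpy as np
import tensorflow as tf

# Count the records in a TFRecord file with the same reduce trick.
ds = tf.data.TFRecordDataset('testing.tfrecord')
num_records = ds.reduce(np.int64(0), lambda x, _: x + 1).numpy()
print(f"{num_records} records in testing.tfrecord")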

I don't know whether this method is faster than @krishnab's for loop.

Upvotes: 10
