SantoshGupta7

Reputation: 6197

How to solve data fetch bottle neck for TPU inference?

This is what my inference setup looks like:

autotune = tf.data.experimental.AUTOTUNE

with strategy.scope():
    model = LoadModel()
    raw_dataset = tf.data.TFRecordDataset(tfRecordAddress)
    train_dataset = raw_dataset.map(_parse_example, num_parallel_calls=autotune)
    train_dataset = train_dataset.padded_batch(batch_size, padding_values=(1, 1, b'-'), padded_shapes=(512, 512, 1))
    # train_dataset = train_dataset.repeat()
    train_dataset = train_dataset.prefetch(autotune)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)

def per_core_inference_fn(inputIds, attnIds):
    return model.inference((inputIds, attnIds))

@tf.function
def inference_fn(inputIds, attnIds):
    return strategy.run(per_core_inference_fn, args=(inputIds,attnIds))

results = []
for x in train_dataset:
    t0 = time.time()
    results.append(inference_fn(x[0], x[1]))
    t1 = time.time()
    print('time is :', t1-t0)

With huge batch sizes, the inference itself is blazing fast, something like .0003 seconds per call. However, fetching the next batch in for x in train_dataset: takes a long time, around 60-80 seconds.

As far as I can tell, I am doing the inference correctly, but somehow the TPU's host CPU is running into a huge bottleneck retrieving the batches.

I did not see this bottleneck during training, so it looks like model.fit is doing something I'm not.

Upvotes: 0

Views: 149

Answers (1)

Allen Wang

Reputation: 301

I have a feeling that this bottleneck occurs specifically because of the for x in train_dataset loop. The 60-80 seconds between batches suggests to me that prefetch is not working as expected. In custom training loop (CTL) code, I typically see the entire loop wrapped in a tf.function, as in this example.

Could you try modifying your code similarly? You can also try capturing a TPU profile (https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_profile) instead of using time.time() for benchmarking.
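To illustrate the idea, here is a minimal sketch of what "wrapping the loop in a tf.function" could look like for your setup. The key change is that batches are pulled from the iterator inside the traced function, so tf.data prefetching can overlap batch preparation with device execution. I'm using tf.distribute.get_strategy() and a toy per_core_inference_fn as stand-ins for your TPUStrategy and model.inference, so treat the names as placeholders:

```python
import tensorflow as tf

# Stand-in for your TPUStrategy; swap in your real strategy object.
strategy = tf.distribute.get_strategy()

# Toy stand-in for model.inference: doubles input_ids and adds attn_ids.
def per_core_inference_fn(input_ids, attn_ids):
    return input_ids * 2 + attn_ids

# Toy dataset shaped like your (inputIds, attnIds) pairs.
dataset = tf.data.Dataset.from_tensor_slices(
    (tf.ones((8, 4), tf.int32), tf.ones((8, 4), tf.int32))
).batch(4).prefetch(tf.data.experimental.AUTOTUNE)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def run_inference(iterator, steps):
    # `steps` is a Python int, so this loop is unrolled at trace time,
    # which lets us collect the per-step results in a Python list.
    # The crucial part: next(iterator) happens inside the tf.function,
    # not in an eager Python loop.
    outputs = []
    for _ in range(steps):
        batch = next(iterator)
        outputs.append(
            strategy.run(per_core_inference_fn, args=(batch[0], batch[1]))
        )
    return outputs

iterator = iter(dist_dataset)
results = run_inference(iterator, steps=2)
```

With a real TPUStrategy you would likely run a fixed number of steps per call (retracing happens per distinct `steps` value), and the per-replica results would need strategy.experimental_local_results or similar to gather, but the structure above is the pattern the linked CTL examples use.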

Upvotes: 1

Related Questions