Kurt

Reputation: 306

bert_vocab.bert_vocab_from_dataset taking too long

I'm following this tutorial (https://colab.research.google.com/github/tensorflow/text/blob/master/docs/guide/subwords_tokenizer.ipynb#scrollTo=kh98DvoDz7Jn) to generate a vocabulary from a custom dataset. In the tutorial, it takes around 2 minutes for this code to complete:

from tensorflow_text.tools.wordpiece_vocab import bert_vocab_from_dataset as bert_vocab

bert_vocab_args = dict(
    # The target vocabulary size
    vocab_size = 8000,
    # Reserved tokens that must be included in the vocabulary
    reserved_tokens=reserved_tokens,
    # Arguments for `text.BertTokenizer`
    bert_tokenizer_params=bert_tokenizer_params,
    # Arguments for `wordpiece_vocab.wordpiece_tokenizer_learner_lib.learn`
    learn_params={},
)

pt_vocab = bert_vocab.bert_vocab_from_dataset(
    train_pt.batch(1000).prefetch(2),
    **bert_vocab_args
)

On my dataset it takes a lot longer. I tried increasing the batch size as well as decreasing the vocabulary size, to no avail. Is there any way to make this go faster?

Upvotes: 0

Views: 416

Answers (1)

rebo

Reputation: 11

I ran into the same issue. This is how I resolved it:

First I checked the number of elements in the dataset:

import tensorflow_datasets as tfds

examples, metadata = tfds.load('my_dataset', as_supervised=True, with_info=True)
print(metadata)
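
If you only need the element count, the split sizes can also be read directly from the metadata (a small sketch; 'my_dataset' stands in for your own dataset name):

num_train = metadata.splits['train'].num_examples  # number of examples in the training split
print(num_train)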

In my case, the dataset contained more than 5 million elements, which explains why building the vocabulary took so long.

The Portuguese vocabulary in the TensorFlow example is built from roughly 50,000 elements, so I selected 1% of my dataset:

train_tokenize, metadata = tfds.load('my_dataset', split='train[:1%]', as_supervised=True, with_info=True)
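
If the data is not loaded through tfds, an equivalent approach is to cap the number of elements with tf.data's take() (a sketch, assuming train_pt is the tf.data.Dataset from the question):

train_pt_small = train_pt.take(50000)  # keep only the first 50,000 sentences for vocabulary generation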

I then used this smaller dataset to build the vocabulary, which took about 2 minutes:

# Split the (en, ol) sentence pairs into two single-language datasets
train_en_tokenize = train_tokenize.map(lambda en, ol: en)
train_ol_tokenize = train_tokenize.map(lambda en, ol: ol)

ol_vocab = bert_vocab.bert_vocab_from_dataset(
    train_ol_tokenize.batch(1000).prefetch(2),
    **bert_vocab_args
)
en_vocab = bert_vocab.bert_vocab_from_dataset(
    train_en_tokenize.batch(1000).prefetch(2),
    **bert_vocab_args
)

where ol stands for the 'other language' I am developing the model for.
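
Once the vocabularies are built, they can be written to plain-text files, one token per line, the same way the tutorial's write_vocab_file helper does:

def write_vocab_file(filepath, vocab):
    # One token per line, so text.BertTokenizer can load the file as its vocabulary
    with open(filepath, 'w') as f:
        for token in vocab:
            print(token, file=f)

write_vocab_file('ol_vocab.txt', ol_vocab)
write_vocab_file('en_vocab.txt', en_vocab)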

Upvotes: 1
