Johnpac

Reputation: 95

KeyError while fine-tuning T5 for summarization with HuggingFace

I am trying to fine-tune the T5 transformer for summarization, but I am receiving a KeyError:

KeyError: 'Indexing with integers (to access backend Encoding for a given batch index) is not available when using Python based tokenizers'

The code I am using is basically this:

model_name = '...'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
model.to(device)

(...)

df_dataset = df_dataset[['summary','document']]
df_dataset.document = 'summarize: ' + df_dataset.document

X = list(df_dataset['document'])
y = list(df_dataset['summary'])
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
X_train_tokenized = tokenizer(X_train, padding=True, truncation=True, max_length=512)
y_train_tokenized = tokenizer(y_train, padding=True, truncation=True, max_length=512)
X_val_tokenized = tokenizer(X_val, padding=True, truncation=True, max_length=512)
y_val_tokenized = tokenizer(y_val, padding=True, truncation=True, max_length=512)

# Create torch dataset
class Dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels) 

training_set = Dataset(X_train_tokenized, y_train_tokenized)
validation_set = Dataset(X_val_tokenized, y_val_tokenized)

# Define Trainer
args = TrainingArguments(
    output_dir="output",
    evaluation_strategy="steps",
    eval_steps=500,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=VALID_BATCH_SIZE,
    num_train_epochs=TRAIN_EPOCHS,
    save_steps=3000,
    load_best_model_at_end = True,    
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=training_set,
    eval_dataset=validation_set,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

trainer.train()

And the full error:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-29-f31e4c5cde21> in <module>
      1 # Train pre-trained model
----> 2 trainer.train()

c:\programdata\anaconda3\envs\summa\lib\site-packages\transformers\trainer.py in train(self, resume_from_checkpoint, trial, **kwargs)
   1099             self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
   1100 
-> 1101             for step, inputs in enumerate(epoch_iterator):
   1102 
   1103                 # Skip past any already trained steps if resuming training

c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    515             if self._sampler_iter is None:
    516                 self._reset()
--> 517             data = self._next_data()
    518             self._num_yielded += 1
    519             if self._dataset_kind == _DatasetKind.Iterable and \

c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
    555     def _next_data(self):
    556         index = self._next_index()  # may raise StopIteration
--> 557         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    558         if self._pin_memory:
    559             data = _utils.pin_memory.pin_memory(data)

c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

<ipython-input-24-67979e648b75> in __getitem__(self, idx)
      7     def __getitem__(self, idx):
      8         item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
----> 9         item['labels'] = torch.tensor(self.labels[idx])
     10         return item
     11 

c:\programdata\anaconda3\envs\summa\lib\site-packages\transformers\tokenization_utils_base.py in __getitem__(self, item)
    232             return self._encodings[item]
    233         else:
--> 234             raise KeyError(
    235                 "Indexing with integers (to access backend Encoding for a given batch index) "
    236                 "is not available when using Python based tokenizers"

KeyError: 'Indexing with integers (to access backend Encoding for a given batch index) is not available when using Python based tokenizers'

And if I change the line:

tokenizer = T5Tokenizer.from_pretrained(model_name)

To:

tokenizer = T5TokenizerFast.from_pretrained(model_name)

the error changes to:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-28-f31e4c5cde21> in <module>
      1 # Train pre-trained model
----> 2 trainer.train()

c:\programdata\anaconda3\envs\summa\lib\site-packages\transformers\trainer.py in train(self, resume_from_checkpoint, trial, **kwargs)
   1099             self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
   1100 
-> 1101             for step, inputs in enumerate(epoch_iterator):
   1102 
   1103                 # Skip past any already trained steps if resuming training

c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
    515             if self._sampler_iter is None:
    516                 self._reset()
--> 517             data = self._next_data()
    518             self._num_yielded += 1
    519             if self._dataset_kind == _DatasetKind.Iterable and \

c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
    555     def _next_data(self):
    556         index = self._next_index()  # may raise StopIteration
--> 557         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    558         if self._pin_memory:
    559             data = _utils.pin_memory.pin_memory(data)

c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

<ipython-input-23-67979e648b75> in __getitem__(self, idx)
      7     def __getitem__(self, idx):
      8         item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
----> 9         item['labels'] = torch.tensor(self.labels[idx])
     10         return item
     11 

RuntimeError: Could not infer dtype of tokenizers.Encoding

Any idea what is wrong?

Upvotes: 8

Views: 9037

Answers (1)

leevanoetz

Reputation: 199

This is because the tokenizer returns a BatchEncoding object rather than a plain dict, so indexing it with an integer reaches for the backend Encoding of that batch index instead of your label lists.
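
You can see this directly (a minimal check; t5-small is just an example checkpoint):

from transformers import T5TokenizerFast

tokenizer = T5TokenizerFast.from_pretrained("t5-small")
batch = tokenizer(["summarize: some text"], padding=True)
print(type(batch))            # transformers.tokenization_utils_base.BatchEncoding
print(batch.keys())           # dict_keys(['input_ids', 'attention_mask'])
print(type(batch[0]))         # tokenizers.Encoding, which torch.tensor cannot convert
print(batch["input_ids"][0])  # plain list of token ids

With the slow T5Tokenizer, batch[0] raises the KeyError from the question instead; with the fast tokenizer it returns a tokenizers.Encoding, which explains the RuntimeError. Indexing by string key first and by integer second avoids both.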

You have to amend the __getitem__ method of your dataset class along the lines of:

import torch

class ForT5Dataset(torch.utils.data.Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        # inputs/targets are dicts of lists, so take the length of one
        # of the lists rather than of the dict itself
        return len(self.targets["input_ids"])

    def __getitem__(self, index):
        # Index the plain Python list under "input_ids"; this works with
        # both slow and fast tokenizers
        input_ids = torch.tensor(self.inputs["input_ids"][index]).squeeze()
        target_ids = torch.tensor(self.targets["input_ids"][index]).squeeze()

        # Only input_ids and labels are returned here; if you need the
        # padding mask, add "attention_mask" the same way
        return {"input_ids": input_ids, "labels": target_ids}

and pass the .data attribute of each BatchEncoding when initializing, like: train_ds = ForT5Dataset(train_in.data, train_out.data).
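
Applied to the variables in the question, that would look something like this (a sketch reusing the tokenized splits and the Trainer setup from the question):

training_set = ForT5Dataset(X_train_tokenized.data, y_train_tokenized.data)
validation_set = ForT5Dataset(X_val_tokenized.data, y_val_tokenized.data)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=training_set,
    eval_dataset=validation_set,
)
trainer.train()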

Upvotes: 3
