Reputation: 95
I am trying to fine-tune the T5 transformer for summarization, but I am receiving a KeyError:
KeyError: 'Indexing with integers (to access backend Encoding for a given batch index) is not available when using Python based tokenizers'
The code I am using is basically this:
# Load the tokenizer and the T5 model from the same checkpoint name.
# NOTE(review): T5Tokenizer is the Python-based tokenizer that the KeyError message refers to.
model_name = '...'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
# `device` is defined outside this snippet.
model.to(device)
(...)
# Keep only the two columns used for training and prepend the task prefix
# 'summarize: ' to every document.
df_dataset = df_dataset[['summary', 'document']]
df_dataset.document = 'summarize: ' + df_dataset.document

# Documents are the model inputs (X), summaries are the targets (y);
# 20% of the pairs are held out for validation.
X, y = list(df_dataset['document']), list(df_dataset['summary'])
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2)

# Every split is tokenized with the same settings: padding and truncation
# enabled, at most 512 tokens per sequence.
tokenize_kwargs = {'padding': True, 'truncation': True, 'max_length': 512}
X_train_tokenized = tokenizer(X_train, **tokenize_kwargs)
y_train_tokenized = tokenizer(y_train, **tokenize_kwargs)
X_val_tokenized = tokenizer(X_val, **tokenize_kwargs)
y_val_tokenized = tokenizer(y_val, **tokenize_kwargs)
# Create torch dataset
# Map-style dataset that converts tokenizer output into per-example tensors.
# Both `encodings` and `labels` are passed the objects returned by calling the
# tokenizer on a list of texts (see how the class is instantiated).
class Dataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
# Every field of `encodings` is indexed by example position and turned into a tensor.
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
# NOTE(review): this is the failing line in both tracebacks. `self.labels` is
# the tokenizer output itself, not a list of label sequences, so integer
# indexing raises the KeyError with T5Tokenizer and yields a
# tokenizers.Encoding (which torch.tensor cannot convert) with T5TokenizerFast.
# Indexing `self.labels['input_ids'][idx]` instead selects one label sequence.
item['labels'] = torch.tensor(self.labels[idx])
return item
def __len__(self):
# NOTE(review): presumably `len()` of the dict-like tokenizer output counts its
# fields rather than the examples - confirm; `len(self.labels['input_ids'])`
# would count examples.
return len(self.labels)
# Wrap the tokenized inputs and targets of each split in the Dataset class.
training_set = Dataset(X_train_tokenized, y_train_tokenized)
validation_set = Dataset(X_val_tokenized, y_val_tokenized)

# Define Trainer
# Evaluation is step-based (every 500 steps), checkpoints are written every
# 3000 steps, and the best checkpoint is reloaded when training finishes.
args = TrainingArguments(
    output_dir="output",
    evaluation_strategy="steps",
    eval_steps=500,
    per_device_train_batch_size=TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=VALID_BATCH_SIZE,
    num_train_epochs=TRAIN_EPOCHS,
    save_steps=3000,
    load_best_model_at_end=True,
)

# Early stopping with a patience of 3.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=training_set,
    eval_dataset=validation_set,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)

trainer.train()
And the full error:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-29-f31e4c5cde21> in <module>
1 # Train pre-trained model
----> 2 trainer.train()
c:\programdata\anaconda3\envs\summa\lib\site-packages\transformers\trainer.py in train(self, resume_from_checkpoint, trial, **kwargs)
1099 self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
1100
-> 1101 for step, inputs in enumerate(epoch_iterator):
1102
1103 # Skip past any already trained steps if resuming training
c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
515 if self._sampler_iter is None:
516 self._reset()
--> 517 data = self._next_data()
518 self._num_yielded += 1
519 if self._dataset_kind == _DatasetKind.Iterable and \
c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
555 def _next_data(self):
556 index = self._next_index() # may raise StopIteration
--> 557 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
558 if self._pin_memory:
559 data = _utils.pin_memory.pin_memory(data)
c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
<ipython-input-24-67979e648b75> in __getitem__(self, idx)
7 def __getitem__(self, idx):
8 item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
----> 9 item['labels'] = torch.tensor(self.labels[idx])
10 return item
11
c:\programdata\anaconda3\envs\summa\lib\site-packages\transformers\tokenization_utils_base.py in __getitem__(self, item)
232 return self._encodings[item]
233 else:
--> 234 raise KeyError(
235 "Indexing with integers (to access backend Encoding for a given batch index) "
236 "is not available when using Python based tokenizers"
KeyError: 'Indexing with integers (to access backend Encoding for a given batch index) is not available when using Python based tokenizers'
And if I change the line:
tokenizer = T5Tokenizer.from_pretrained(model_name)
To:
tokenizer = T5TokenizerFast.from_pretrained(model_name)
the error changes to:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-28-f31e4c5cde21> in <module>
1 # Train pre-trained model
----> 2 trainer.train()
c:\programdata\anaconda3\envs\summa\lib\site-packages\transformers\trainer.py in train(self, resume_from_checkpoint, trial, **kwargs)
1099 self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
1100
-> 1101 for step, inputs in enumerate(epoch_iterator):
1102
1103 # Skip past any already trained steps if resuming training
c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
515 if self._sampler_iter is None:
516 self._reset()
--> 517 data = self._next_data()
518 self._num_yielded += 1
519 if self._dataset_kind == _DatasetKind.Iterable and \
c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
555 def _next_data(self):
556 index = self._next_index() # may raise StopIteration
--> 557 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
558 if self._pin_memory:
559 data = _utils.pin_memory.pin_memory(data)
c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
c:\programdata\anaconda3\envs\summa\lib\site-packages\torch\utils\data\_utils\fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
<ipython-input-23-67979e648b75> in __getitem__(self, idx)
7 def __getitem__(self, idx):
8 item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
----> 9 item['labels'] = torch.tensor(self.labels[idx])
10 return item
11
RuntimeError: Could not infer dtype of tokenizers.Encoding
Any idea of what is wrong?
Upvotes: 8
Views: 9037
Reputation: 199
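This is because the tokenizer does not return a list of examples: it returns a dict-like object keyed by field name (for example `input_ids`), so it cannot be indexed with an integer.
You have to amend the `__getitem__` method of your dataset class along the lines of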
class ForT5Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenized source texts with tokenized targets.

    Args:
        inputs: Dict-like tokenizer output for the source texts. Must contain
            ``"input_ids"``; ``"attention_mask"`` is used when present.
        targets: Dict-like tokenizer output for the target texts. Only
            ``"input_ids"`` is read.
        pad_token_id: Optional id of the padding token. When given, padded
            label positions are replaced with -100 so the loss ignores them.
            The default ``None`` leaves the labels untouched.
    """

    def __init__(self, inputs, targets, pad_token_id=None):
        self.inputs = inputs
        self.targets = targets
        self.pad_token_id = pad_token_id

    def __len__(self):
        # Count examples, not the fields of the dict-like tokenizer output:
        # len(targets) would be the number of keys.
        return len(self.targets["input_ids"])

    @staticmethod
    def _as_1d(values):
        """Return *values* as a 1-D tensor, even for a single-token sequence."""
        # reshape(-1) drops a leading batch axis like squeeze() did, but never
        # collapses a length-1 sequence into a 0-d scalar.
        return torch.tensor(values).reshape(-1)

    def __getitem__(self, index):
        """Return the tensors for one example.

        The result always has ``"input_ids"`` and ``"labels"``, plus
        ``"attention_mask"`` when the inputs provide one.
        """
        item = {"input_ids": self._as_1d(self.inputs["input_ids"][index])}
        # Forward the mask so padded input positions are not attended to.
        if "attention_mask" in self.inputs:
            item["attention_mask"] = self._as_1d(
                self.inputs["attention_mask"][index]
            )
        labels = self._as_1d(self.targets["input_ids"][index])
        if self.pad_token_id is not None:
            # masked_fill is out-of-place, so the stored targets stay intact.
            labels = labels.masked_fill(labels == self.pad_token_id, -100)
        item["labels"] = labels
        return item
and pass the `data` attribute of each tokenizer output when initializing, like:
# NOTE(review): `train_in` / `train_out` are presumably the tokenizer outputs for the
# source and target texts, and `.data` their underlying dict - confirm.
train_ds = ForT5Dataset(train_in.data, train_out.data)
.
Upvotes: 3