Reputation: 31
I am quite new to the BERT language model. I am using the Hugging Face transformers library, and I'm encountering an error when encoding the inputs. The goal of the model is to classify fake news.
First, I downloaded the dataset and turned it into a pandas DataFrame with three columns: index, tweet, and label. The pretrained `AutoTokenizer` for `bert-large-uncased` is used to encode the input.
```python
from transformers import AutoTokenizer

TOKENIZER = AutoTokenizer.from_pretrained("bert-large-uncased")
```
The following function is used:
```python
import numpy as np


def bert_encode(data, maximum_len):
    input_ids = []
    attention_masks = []
    for i in range(len(data.tweet)):
        # data.tweet[i] looks up the row whose index *label* is i
        encoded = TOKENIZER.encode_plus(data.tweet[i],
                                        add_special_tokens=True,
                                        max_length=maximum_len,
                                        padding="max_length",  # replaces the deprecated pad_to_max_length=True
                                        return_attention_mask=True,
                                        truncation=True)
        input_ids.append(encoded['input_ids'])
        attention_masks.append(encoded['attention_mask'])
    return np.array(input_ids), np.array(attention_masks)
```
The function is applied to the data to get the training input IDs and attention masks:
```python
train_input_ids, train_attention_masks = bert_encode(train, 600)
test_input_ids, test_attention_masks = bert_encode(test, 600)
```
However, calling the function raises `KeyError: 3`. The exact error message is provided below:
```
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   2897             try:
-> 2898                 return self._engine.get_loc(casted_key)
   2899             except KeyError as err:

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.Int64HashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.Int64HashTable.get_item()

KeyError: 3

The above exception was the direct cause of the following exception:

KeyError                                  Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   2898                 return self._engine.get_loc(casted_key)
   2899             except KeyError as err:
-> 2900                 raise KeyError(key) from err
   2901
   2902             if tolerance is not None:

KeyError: 3
```
Any insight on how to debug this is welcome.
Upvotes: 3
Views: 2413
Reputation: 1895
Print the index of each DataFrame:
```python
print(train.index)
print(test.index)
```
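A quick way to check this programmatically, as a sketch assuming `train` is a pandas DataFrame, is to compare its index with a fresh `RangeIndex` of the same length. The helper name `has_positional_index` and the toy data are hypothetical:

```python
import pandas as pd


def has_positional_index(df: pd.DataFrame) -> bool:
    """Return True if df's index is exactly 0, 1, ..., len(df) - 1."""
    return bool(df.index.equals(pd.RangeIndex(len(df))))


# Hypothetical toy data: row label 3 is missing, as can happen after a split.
train = pd.DataFrame(
    {"tweet": ["a", "b", "c", "d"], "label": [0, 1, 0, 1]},
    index=[0, 1, 2, 4],
)
print(train.index)
print(has_positional_index(train))  # False
```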
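Sometimes the index is not a contiguous `0, 1, ..., n-1` range, for example because you filtered rows, split one DataFrame into train and test, or combined tables from different sources. Since `data.tweet[i]` looks up a row by its index label rather than by its position, a missing label such as 3 raises `KeyError: 3`. You can fix this by resetting the index: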
```python
train.reset_index(drop=True, inplace=True)
test.reset_index(drop=True, inplace=True)
```
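To see why a gap in the index produces exactly this error, here is a minimal pandas-only reproduction. It is a sketch, not the asker's data: the toy DataFrame and the split (moving row 3 into `test`) are assumptions standing in for however the real train/test split was made, and `collect_tweets` is a hypothetical helper that mimics the access pattern of `bert_encode`.

```python
import pandas as pd

# Hypothetical toy dataset with the same columns as in the question.
df = pd.DataFrame({
    "tweet": ["t0", "t1", "t2", "t3", "t4"],
    "label": [0, 1, 0, 1, 0],
})

# Simulate a split: row 3 goes to test, the rest stays in train.
test = df.iloc[[3]]
train = df.drop(test.index)  # train.index is now [0, 1, 2, 4]


def collect_tweets(data):
    # Same access pattern as bert_encode: data.tweet[i] looks up
    # the row whose index *label* is i, not the i-th row.
    return [data.tweet[i] for i in range(len(data.tweet))]


try:
    collect_tweets(train)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 3

# After resetting, the labels are 0..n-1 again and the loop works.
train = train.reset_index(drop=True)
print(collect_tweets(train))  # ['t0', 't1', 't2', 't4']
```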
If you need to keep the original index for `train` and `test`, do the `reset_index` step before splitting into `train` and `test`.
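Another option is to make `bert_encode` independent of the index altogether by iterating over the column values instead of looking rows up by label (`data.tweet.iloc[i]` would also work, since `.iloc` is positional). The sketch below keeps the question's `encode_plus` arguments but takes the tokenizer as a parameter, which is an assumption for testability; `FakeTokenizer` is a hypothetical stand-in so the example runs without downloading `bert-large-uncased`.

```python
import numpy as np
import pandas as pd


def bert_encode(data, maximum_len, tokenizer):
    """Encode data.tweet without relying on the DataFrame index."""
    input_ids = []
    attention_masks = []
    # Iterating over the values visits rows by position, so gaps in
    # the index (e.g. after a train/test split) do not matter.
    for text in data.tweet:
        encoded = tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=maximum_len,
            padding="max_length",
            return_attention_mask=True,
            truncation=True,
        )
        input_ids.append(encoded["input_ids"])
        attention_masks.append(encoded["attention_mask"])
    return np.array(input_ids), np.array(attention_masks)


class FakeTokenizer:
    """Hypothetical stand-in for a Hugging Face tokenizer (testing only)."""

    def encode_plus(self, text, max_length, **kwargs):
        # "Token ids" are just word lengths, truncated then zero-padded.
        ids = [len(word) for word in text.split()][:max_length]
        pad = max_length - len(ids)
        return {
            "input_ids": ids + [0] * pad,
            "attention_mask": [1] * len(ids) + [0] * pad,
        }


# Toy frame whose index has a gap at 3, like a post-split train set.
train = pd.DataFrame({"tweet": ["fake news", "real", "so fake", "hi"]},
                     index=[0, 1, 2, 4])
ids, masks = bert_encode(train, 4, FakeTokenizer())
print(ids.shape, masks.shape)  # (4, 4) (4, 4)
```

With the real tokenizer you would call `bert_encode(train, maximum_len, TOKENIZER)`; this avoids the `KeyError` even if you forget to reset the index.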
Upvotes: 1