Reputation: 1640
I'm trying to finetune the pretrained Transformer-XL model transfo-xl-wt103 for a language modeling task. Therefore, I use the model class TransfoXLLMHeadModel.
To iterate over my dataset I use the LMOrderedIterator from the file tokenization_transfo_xl.py, which yields a tensor with the data and its target for each batch (along with the sequence length).
Let's assume the following data with batch_size = 1 and bptt = 8:
data = tensor([[1,2,3,4,5,6,7,8]])
target = tensor([[2,3,4,5,6,7,8,9]])
mems # from the previous output
My question is: I currently pass this data into the model like this:
output = model(input_ids=data, labels=target, mems=mems)
Is this correct?
I am wondering because the documentation says for the labels parameter:

labels (torch.LongTensor of shape (batch_size, sequence_length), optional, defaults to None): Labels for language modeling. Note that the labels are shifted inside the model, i.e. you can set lm_labels = input_ids
So what is it about the parameter lm_labels? I only see labels defined in the forward method.
And when the labels "are shifted" inside the model, does this mean I have to pass data twice (as labels in addition to input_ids, instead of passing target) because it's shifted internally? But how does the model then know the next token to predict?
I also read through this bug and the fix in this pull request, but I don't quite understand how to treat the model now (before vs. after the fix).
Thanks in advance for some help!
Edit: Link to issue on Github
Upvotes: 0
Views: 662
Reputation: 36
That does sound like a typo from another model's convention. You do have to pass data twice, once to input_ids and once to labels (in your case, [1, ..., 8] for both). The model will then attempt to predict [2, ..., 8] from [1, ..., 7]. I am not sure adding something at the beginning of the target tensor would work, as that would probably cause size mismatches later down the line.
Passing the same data twice is the default way to do this in transformers; before the aforementioned PR, TransfoXL did not shift the labels internally and you had to shift them yourself. The PR changed it to be consistent with the rest of the library and with the documentation, where you pass the same data twice.
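To make the shift concrete, here is a minimal sketch in plain Python (no transformers or torch dependency; the function name shift_for_lm is illustrative, not the library's internal code) of what "labels are shifted inside the model" means when you pass the same sequence as both input_ids and labels:

```python
def shift_for_lm(ids):
    """Given one token sequence passed as both input_ids and labels,
    return the (inputs, targets) pairing the model actually trains on:
    predict ids[1:] from ids[:-1]."""
    return ids[:-1], ids[1:]

# The same data from the question, passed once as input_ids and once as labels.
ids = [1, 2, 3, 4, 5, 6, 7, 8]
inputs, targets = shift_for_lm(ids)
print(inputs)   # [1, 2, 3, 4, 5, 6, 7]
print(targets)  # [2, 3, 4, 5, 6, 7, 8]
```

So the model never "sees" the target at the position it is predicting: position i in inputs is trained to predict position i in targets, which is the next token in the original sequence.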
Upvotes: 2