Pablo Messina

Reputation: 441

Pytorch: How to implement nested transformers: a character-level transformer for words and a word-level transformer for sentences?

I have a model in mind, but I'm having a hard time figuring out how to actually implement it in Pytorch, especially when it comes to training the model (e.g. how to define mini-batches, etc.). First of all let me quickly introduce the context:

I'm working on VQA (visual question answering), in which the task is to answer questions about images, for example:

[image: an example image with a natural-language question about it]

So, leaving aside many details, I just want to focus here on the NLP aspect/branch of the model. In order to process the natural language question, I want to use character-level embeddings (instead of traditional word-level embeddings) because they are more robust in the sense that they can easily accommodate morphological variations in words (e.g. prefixes, suffixes, plurals, verb conjugations, hyphens, etc.). But at the same time I don't want to lose the inductive bias of reasoning at the word level. Therefore, I came up with the following design:

[image: diagram of the proposed design: a character-level transformer per word, feeding a word-level transformer over the whole question]

As you can see in the picture above, I want to use transformers (or even better, universal transformers), but with a little twist. I want to use 2 transformers: the first one will process each word's characters in isolation (character-level transformer) to produce an initial word-level embedding for each word in the question. Once we have all these initial word-level embeddings, a second word-level transformer will refine them to enrich their representation with context, thus obtaining context-aware word-level embeddings.

The full model for the whole VQA task is obviously more complex, but I just want to focus here on this NLP part. So my question is basically about which PyTorch functions I should pay attention to when implementing this. For example, since I'll be using character-level embeddings, I have to define a character-level embedding matrix and perform lookups on this matrix to generate the inputs for the character-level transformer, repeat this for each word in the question, and then feed all these vectors into the word-level transformer. Moreover, words in a single question can have different lengths, and questions within a single mini-batch can have different lengths too. So in my code I have to somehow account for different lengths at both the word level and the question level simultaneously within a single mini-batch (during training), and I have no idea how to do that in PyTorch, or whether it's even possible at all.
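To make this concrete, here is roughly what I have in mind for a single, unbatched question. This is only a sketch: the dimensions, the toy character vocabulary and the mean-pooling step are placeholders, positional encodings are omitted, and it assumes a recent PyTorch where nn.TransformerEncoderLayer accepts batch_first=True.

```python
import torch
import torch.nn as nn

# Rough sketch of the intended design for a single question (no batching yet).
d_model, nhead = 128, 4
char_emb = nn.Embedding(30, d_model)   # toy character vocabulary
char_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model, nhead, batch_first=True), num_layers=2)
word_encoder = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model, nhead, batch_first=True), num_layers=2)

char2id = {c: i for i, c in enumerate("abcdefghijklmnopqrstuvwxyz")}
question = ["what", "color", "is", "the", "umbrella"]

word_vecs = []
for word in question:
    ids = torch.tensor([[char2id[c] for c in word]])   # (1, num_chars)
    states = char_encoder(char_emb(ids))               # character-level transformer
    word_vecs.append(states.mean(dim=1))               # collapse into an initial word embedding
words = torch.cat(word_vecs, dim=0).unsqueeze(0)       # (1, num_words, d_model)
context_aware = word_encoder(words)                    # word-level transformer
print(context_aware.shape)                             # torch.Size([1, 5, 128])
```

What I can't figure out is how to extend something like this to mini-batches in which both the words and the questions have variable lengths (presumably with padding and masking at both levels).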

Any tips on how to go about implementing this in Pytorch that could lead me in the right direction will be deeply appreciated.

Upvotes: 5

Views: 1487

Answers (1)

caspillaga

Reputation: 563

A way to implement what you describe in PyTorch would require adapting the Transformer encoder:

1) Define a custom tokenizer that splits words into characters and embeds them at the character level (instead of using word or word-piece embeddings)

2) Define a mask for each word (similar to the mask the original paper used to hide future tokens in the decoder), so that in the first stage each character is constrained to the context of its own word

3) Then use a traditional Transformer with that mask (effectively restricting attention to word-level context).

4) Then discard the mask and apply the Transformer again (sentence-level context). A rough sketch of this two-pass scheme is given below.
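The following is only a sketch of the masking idea, for a single question and with made-up dimensions; it assumes a recent PyTorch where nn.TransformerEncoderLayer accepts batch_first=True. Batching questions of different lengths would additionally need key-padding masks and the per-head (batch*num_heads, L, L) form of the attention mask.

```python
import torch
import torch.nn as nn

# "is it raining" as one character sequence, plus the word each character belongs to
chars   = list("isitraining")
word_id = torch.tensor([0, 0, 1, 1, 2, 2, 2, 2, 2, 2, 2])

# Stage-1 mask: a character may only attend to characters of the same word.
# In PyTorch's boolean attn_mask convention, True = attention NOT allowed.
word_mask = word_id.unsqueeze(0) != word_id.unsqueeze(1)   # (L, L)

d_model, nhead = 64, 4
char_emb = nn.Embedding(128, d_model)                      # toy character vocabulary
stage1 = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model, nhead, batch_first=True), num_layers=2)
stage2 = nn.TransformerEncoder(
    nn.TransformerEncoderLayer(d_model, nhead, batch_first=True), num_layers=2)

char_ids = torch.tensor([[ord(c) for c in chars]])         # (1, L), toy character encoding
x = char_emb(char_ids)
x = stage1(x, mask=word_mask)   # word-level context only (steps 2 and 3)
x = stage2(x)                   # full sentence-level context (step 4)
print(x.shape)                  # torch.Size([1, 11, 64]) -> still one vector per character
```

Note that the output still has one vector per character, which is exactly point 1 below.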


Things to be careful about:

1) Remember that the Transformer encoder's output length is always the same as its input length (the decoder is the part that can produce longer or shorter sequences). So after your first stage you will not have word-level embeddings (as shown in your diagram) but character-level embeddings. If you want to merge them into word-level embeddings, you will need an additional intermediate decoder step or a custom merging strategy (e.g. a learnt weighted sum, or something similar to BERT's [CLS] token); a sketch of such a pooling step is given after this list.

2) You may face efficiency issues. Remember that the Transformer is O(n^2) in the sequence length, so the longer the sequence, the more computationally expensive it is. In the original Transformer, a sentence of 10 words gives the model a 10-token sequence to deal with. With word-piece embeddings, your model will work with roughly 15-token sequences. But with character-level embeddings, I estimate that you will be dealing with ~50-token sequences, which may not be feasible for long sentences, so you may need to truncate your input (and you would lose all the long-term dependency power of attention models).

3) Are you sure that the character-level Transformer will actually add representational value? The Transformer aims to enrich embeddings based on their context (the surrounding embeddings), which is why the original implementation used word-level embeddings. BERT uses word-piece embeddings to take advantage of regularities among related words, and GPT-2 uses Byte Pair Encoding (BPE), which I don't recommend in your case because it is more suited to next-token prediction. In your case, what information do you think will be captured by the learnt character embeddings that can be effectively shared between the characters of a word? Do you think it will be richer than using a learnt embedding for each word or word-piece? My guess is that this is what you are trying to find out... right?
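Regarding the "learnt weighted sum" in point 1, here is a minimal sketch of one way to do it (score each character state with a small linear layer and softmax the scores into pooling weights). The class name and shapes are just illustrative; an extra [CLS]-like character per word would be another option.

```python
import torch
import torch.nn as nn

class LearnedWordPooling(nn.Module):
    """Collapse the character states of each word into one word embedding
    via a learnt weighted sum (attention-style pooling). Illustrative only."""
    def __init__(self, d_model):
        super().__init__()
        self.score = nn.Linear(d_model, 1)

    def forward(self, char_states, char_pad_mask):
        # char_states:   (num_words, max_chars, d_model)
        # char_pad_mask: (num_words, max_chars), True at padded character positions
        scores = self.score(char_states).squeeze(-1)           # (num_words, max_chars)
        scores = scores.masked_fill(char_pad_mask, float("-inf"))
        weights = torch.softmax(scores, dim=-1).unsqueeze(-1)  # (num_words, max_chars, 1)
        return (weights * char_states).sum(dim=1)              # (num_words, d_model)

# Toy usage: 3 words of up to 5 characters each, model dimension 64
pool = LearnedWordPooling(64)
states = torch.randn(3, 5, 64)
pad = torch.tensor([[False] * 5,
                    [False] * 3 + [True] * 2,
                    [False] * 4 + [True]])
print(pool(states, pad).shape)   # torch.Size([3, 64])
```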

Upvotes: 4
