This repository extends a basic MLM implementation to allow for efficiently conditioning on chained previous texts in a tree, e.g., a Reddit thread.
Jeevesh8/AutoRegressive-MLM
We build upon this original repo to allow for an AutoRegressive MLM model. On social media platforms like Reddit, comments form a tree, with the post at the root. The task is to condition the MLM prediction of masked tokens in any particular comment on the post and on all the comments along the shortest path from the root to that comment. We do this using an encoder-decoder-like architecture, with the following modifications (a small illustrative sketch follows the list):

- The decoder's self-attention is not masked, so it can look both forward and backward in the current comment when predicting masked tokens.
- The encoder encodes all comments and posts first.
- Each post's and comment's representation is then pooled over to get a fixed-length vector per post and comment.
- We then pick a particular post, mask tokens in it, and send this masked sequence to the decoder.
- The sequence of fixed-length vectors corresponding to the parent posts of the currently picked post is concatenated and fed in place of the encoder output that the decoder normally attends to.
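To make the data flow concrete, here is a minimal JAX sketch of the pooling and memory-building step. It is not taken from this repo: the function names, the choice of mean pooling, and all shapes are assumptions made only for illustration.

```python
# Minimal sketch (assumptions throughout; not the repo's actual code) of building the
# decoder's cross-attention "memory" from the pooled representations of ancestor
# posts/comments. Mean pooling and all names/shapes here are illustrative choices.
import jax
import jax.numpy as jnp

def mean_pool(token_reps, token_mask):
    """Average the encoder's token representations, ignoring padding positions."""
    m = token_mask[:, None]                              # [seq_len, 1]
    total = jnp.sum(token_reps * m, axis=0)              # [d_model]
    count = jnp.maximum(jnp.sum(m, axis=0), 1.0)
    return total / count

def build_decoder_memory(ancestor_reps, ancestor_masks):
    """Pool each ancestor to one fixed-length vector and stack them.

    The resulting [num_ancestors, d_model] array is fed to the decoder in place
    of the usual token-level encoder output.
    """
    pooled = [mean_pool(r, m) for r, m in zip(ancestor_reps, ancestor_masks)]
    return jnp.stack(pooled, axis=0)

# Toy example: three ancestors, 20 tokens each, d_model = 16 (random stand-ins
# for real encoder outputs).
key = jax.random.PRNGKey(0)
reps = [jax.random.normal(key, (20, 16)) for _ in range(3)]
masks = [jnp.ones((20,)) for _ in range(3)]
memory = build_decoder_memory(reps, masks)
print(memory.shape)  # (3, 16): one vector per ancestor on the root-to-comment path
```

Since the decoder's self-attention is unmasked in this setup, its attention mask over the masked comment would simply be all ones (padding positions aside).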
To try: adding positional encodings to the output of the encoder.
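One hypothetical way to realise this idea (purely illustrative, not part of the repo) is to add sinusoidal positional encodings over thread depth to the stacked pooled vectors before the decoder attends to them:

```python
# Illustrative only: sinusoidal positional encodings over thread depth, added to
# the stacked pooled vectors ("memory") from the sketch above.
import jax.numpy as jnp

def sinusoidal_positions(num_positions, d_model):
    pos = jnp.arange(num_positions)[:, None]                      # [num_positions, 1]
    i = jnp.arange(d_model)[None, :]                              # [1, d_model]
    angle = pos / jnp.power(10000.0, (2 * (i // 2)) / d_model)
    return jnp.where(i % 2 == 0, jnp.sin(angle), jnp.cos(angle))

# memory: [num_ancestors, d_model]
# memory = memory + sinusoidal_positions(*memory.shape)
```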
See here for a detailed explanation and the model architecture.
A follow-up repository of Jax-Journey. That repository provides a selection of notebooks for various NLP tasks which are completely see-through (i.e., you can see the implementation down to the basic Jax/Haiku modules, all in a single notebook). They are meant to serve as further tutorials in Jax for NLP, and as a guide to the coding style followed in this awesome article by Madison May.
These notebooks, although mostly code, also mention nuanced features that are often missed when using off-the-shelf models. Moreover, they allow you to optimize everything right down to the innermost modules. Each notebook also mentions how to adapt the model to your use case.
A basic introductory notebook consisting of the original RoBERTa-initialized version and a randomly initialized version.
Here we realise the need to restructure the code and, correspondingly, place all the code component-wise in src/. The new things we code over the original implementation are:
- The masking function for MLM, here (sketched after this list),
- A HuggingFace Tokenizers based tokenizer, here,
- A language embedding for the TLM task, here,
- Additionally, we include an option to make the transformer auto-regressive and add a mask for the same, here (also sketched below). This is needed for CLM.
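As a rough illustration of the first and last items above, here is a standard BERT-style masking function and a causal attention mask. The repo's actual code at the linked locations may differ; the 15% / 80-10-10 split, `mask_id`, and `vocab_size` are assumed values.

```python
# Hedged sketch of an MLM masking function (BERT-style 15% selection, 80/10/10 split)
# and a causal mask for CLM. All names and constants are assumptions, not the repo's code.
import jax
import jax.numpy as jnp

def mask_tokens(key, token_ids, mask_id, vocab_size, mask_prob=0.15):
    """Return (masked_ids, labels); labels are -100 at positions not selected for prediction."""
    k_select, k_split, k_rand = jax.random.split(key, 3)
    select = jax.random.uniform(k_select, token_ids.shape) < mask_prob
    labels = jnp.where(select, token_ids, -100)

    r = jax.random.uniform(k_split, token_ids.shape)
    use_mask = select & (r < 0.8)                        # 80% of selected -> [MASK]
    use_rand = select & (r >= 0.8) & (r < 0.9)           # 10% -> random token
    random_ids = jax.random.randint(k_rand, token_ids.shape, 0, vocab_size)

    masked = jnp.where(use_mask, mask_id, token_ids)
    masked = jnp.where(use_rand, random_ids, masked)     # remaining 10% left unchanged
    return masked, labels

def causal_mask(seq_len):
    """Lower-triangular mask: position i may attend only to positions <= i (used for CLM)."""
    return jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))

ids = jnp.array([5, 17, 23, 8, 42])
masked_ids, labels = mask_tokens(jax.random.PRNGKey(0), ids, mask_id=4, vocab_size=100)
print(masked_ids, labels)
print(causal_mask(4).astype(int))
```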
The final notebook can be found here.