RWKV is an RNN with transformer-level LLM performance. It can be trained directly like a GPT (parallelizable), so it combines the best of RNN and transformer: great performance, fast inference, low VRAM usage, fast training, "infinite" ctx_len, and free sentence embedding.
This branch contains my experimental attempts to achieve infinite context training in RWKV. With this implementation you can train on arbitrarily long contexts with (near) constant VRAM consumption; the increase should be, taking RWKV 7B as an example, about 2 MB per 1024/2048 tokens (depending on your chosen `ctx_len`) in the training sample, which enables training on sequences of over 1M tokens.

Directly tuning on such long sequences may still be problematic, so `ctx_len_cutoff` is provided: longer sequences are sliced into multiple pieces of the specified cutoff size and learned by the model separately. The cutoff can later be increased until no cutoff remains.
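At roughly 2 MB per 1024-2048 tokens, a 1M-token sample only adds on the order of 1-2 GB of VRAM, which is where the 1M-token figure above comes from. To make the cutoff behaviour concrete, here is a minimal Python sketch (not the repo's actual code) of how a long sample could be sliced into `ctx_len_cutoff`-sized pieces; the function name and return shape are illustrative only.

```python
# Illustrative sketch only (not the repo's implementation): slice a long token
# sequence into consecutive pieces of at most ctx_len_cutoff tokens, so each
# piece can be learned by the model separately.
from typing import List

def slice_to_cutoff(tokens: List[int], ctx_len_cutoff: int) -> List[List[int]]:
    """Split `tokens` into chunks of at most `ctx_len_cutoff` tokens."""
    if ctx_len_cutoff <= 0:
        # Treat a non-positive cutoff as "no cutoff": keep the full sequence.
        return [tokens]
    return [tokens[i:i + ctx_len_cutoff] for i in range(0, len(tokens), ctx_len_cutoff)]

# Example: a 5000-token sample with a 2048 cutoff becomes pieces of 2048, 2048 and 904 tokens.
pieces = slice_to_cutoff(list(range(5000)), 2048)
print([len(p) for p in pieces])  # [2048, 2048, 904]
```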
As a side note, the training code has been heavily refactored to use PyTorch 2.0, Lightning 2.0, and DeepSpeed, and the startup script now relies on LightningCLI, so you will see a `config.yaml` containing all the switches, mostly standard ones that Lightning processes by itself.
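For readers unfamiliar with LightningCLI, the sketch below shows the general shape of such an entrypoint; the `ToyModel` / `ToyData` classes are placeholders of my own, not the repo's actual RWKV modules, and `new_train.py` may differ in its details.

```python
# Hedged, self-contained sketch of a LightningCLI entrypoint; new_train.py wires in
# the repo's own LightningModule / LightningDataModule instead of these toy classes.
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
import lightning as L
from lightning.pytorch.cli import LightningCLI

class ToyModel(L.LightningModule):          # placeholder for the repo's RWKV module
    def __init__(self, hidden: int = 8, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.layer = nn.Linear(hidden, 1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.mse_loss(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

class ToyData(L.LightningDataModule):       # placeholder for the repo's data module
    def __init__(self, batch_size: int = 4):
        super().__init__()
        self.batch_size = batch_size

    def train_dataloader(self):
        x = torch.randn(64, 8)
        return DataLoader(TensorDataset(x, x.sum(dim=1, keepdim=True)),
                          batch_size=self.batch_size)

def main():
    # `python this_script.py fit -c config.yaml` works the same way new_train.py does:
    # every switch in the YAML maps onto trainer / model / data arguments.
    LightningCLI(ToyModel, ToyData)

if __name__ == "__main__":
    main()
```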
To use this repo, go into the `RWKV-v4neo` directory and run:

```
python3 new_train.py fit -c {your_config}.yaml
```
Remember to modify the configuration for your own needs.
See `RWKV-v4neo/config-example.yaml` for documentation on the various options.
The following features are not yet supported (they may exist in BlinkDL's original repo):
- numpy file dataset
- binidx dataset
- model init weight
- model resize weights (init from smaller to bigger model)
- world tokenizer
- Learning Rate init -> Learning Rate Final support
- helper script to add new tokens to existing model
The following virtual env setup uses conda; modify it for your use case as needed:
```
# ninja-build is required for the new trainer
sudo apt-get install ninja-build

# Virtual env, with python 3.11
conda create -n rwkv-infctx python=3.11 pip
conda activate rwkv-infctx

# Install pytorch
conda install -y pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia

# We use python -m pip, instead of pip directly, as it resolves issues with venv not loading the right pip
python -m pip install datasets transformers
python -m pip install lightning==2.0.2 deepspeed==0.9.3
python -m pip install ninja numexpr jsonargparse 'jsonargparse[signatures]'
python -m pip install lm-dataformat ftfy sentencepiece tokenizers wandb
```
Due to issues with DeepSpeed on Windows, only Linux environments are supported. WSL2 on Windows is not recommended due to heavy performance penalties (DeepSpeed offload cannot be used, and training is ~50% slower).
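After installation, you can optionally sanity-check the environment with a short Python snippet; this is just a convenience check of my own, not part of the repo:

```python
# Optional sanity check for the environment set up above: confirm that torch,
# lightning and deepspeed import cleanly and that CUDA is visible.
import torch
import lightning
import deepspeed

print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("lightning", lightning.__version__)
print("deepspeed", deepspeed.__version__)
```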
- Either init a new model (todo script), or download an existing model
- Set up the `config.yaml` file, customized for your foundation model / finetune use case
- Preload the dataset with:

```
python3 preload_dataset.py {your_config}.yaml
```
- Start the training process:

```
python3 new_train.py fit -c {your_config}.yaml
```
- Export the checkpoint after training is complete with:

```
python3 export_checkpoint.py ../path/to/checkpoint
```
- In the checkpoint folder, you should find the fp32 model named `rwkv_model.pth`
- You should probably convert this to an fp16 model (todo script)
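A minimal sketch of that missing conversion step, assuming the exported `rwkv_model.pth` is a plain PyTorch state dict; the function name and output filename are my own placeholders:

```python
# Hypothetical sketch of the fp16 conversion (the "todo script" above), assuming
# the exported rwkv_model.pth is a plain state dict of tensors.
import torch

def convert_fp32_to_fp16(src_path: str, dst_path: str) -> None:
    state_dict = torch.load(src_path, map_location="cpu")
    for name, tensor in state_dict.items():
        if torch.is_tensor(tensor) and tensor.is_floating_point():
            state_dict[name] = tensor.half()  # cast fp32 -> fp16
    torch.save(state_dict, dst_path)

if __name__ == "__main__":
    convert_fp32_to_fp16("rwkv_model.pth", "rwkv_model_fp16.pth")
```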
@TODO