RWKV is a RNN with transformer-level LLM performance. It can be directly trained like a GPT (parallelizable). So it's combining the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding.


Blealtan/RWKV-LM-LoRA

 
 


This branch contains my experimental attempts to achieve infinite context training in RWKV. With this implementation you can train on arbitrarily long context with (near) constant VRAM consumption. Taking RWKV 7B as an example, the increase should be about 2MB per 1024/2048 tokens (depending on your chosen ctx_len) in the training sample, which enables training on sequences of over 1M tokens. Tuning directly on such long sequences might be problematic, however, so ctx_len_cutoff is provided: longer sequences are sliced into multiple pieces of the specified cutoff size and learnt by the model separately. The cutoff can later be increased until no cutoff is needed.
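
As a rough illustration of the two ideas above, here is a minimal sketch, not the repo's actual implementation. The function names `split_by_cutoff` and `extra_vram_mb` are hypothetical, treating a non-positive cutoff as "no cutoff" is an assumption, and the estimate simply applies the ~2MB-per-chunk figure quoted for RWKV 7B.

```python
from math import ceil


def split_by_cutoff(tokens, ctx_len_cutoff):
    """Slice one long token sequence into pieces of at most ctx_len_cutoff tokens."""
    if ctx_len_cutoff <= 0:
        # Assumption: a non-positive cutoff means "no cutoff", keep the sample whole.
        return [list(tokens)]
    return [
        list(tokens[start:start + ctx_len_cutoff])
        for start in range(0, len(tokens), ctx_len_cutoff)
    ]


def extra_vram_mb(sample_tokens, ctx_len, mb_per_chunk=2.0):
    """Rough extra VRAM for one sample, at ~2MB per ctx_len-sized chunk (RWKV 7B figure)."""
    return ceil(sample_tokens / ctx_len) * mb_per_chunk


pieces = split_by_cutoff(list(range(10)), 4)
print([len(piece) for piece in pieces])   # [4, 4, 2]
print(extra_vram_mb(1_000_000, 1024))     # 1954.0
print(extra_vram_mb(1_000_000, 2048))     # 978.0
```

Under these assumptions a 1M-token sample adds roughly 1 to 2GB on top of the baseline, which is why such lengths become feasible.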

The training code has also been heavily refactored to use PyTorch 2.0, Lightning 2.0 and DeepSpeed (pinned to 0.9.3 in the environment setup), and the starting script now relies on LightningCLI, so you will see that config.yaml contains all the switches, most of them standard ones that Lightning processes by itself.

To use this repo, go into the RWKV-v4neo directory and run

python3 new_train.py fit -c {your_config}.yaml

Remember to modify the configuration for your own needs.

See RWKV-v4neo/config-example.yaml for documentation on the various options.

Existing limitations

The following features are not yet supported (though they may exist in BlinkDL's original repo):

  • numpy file dataset
  • binidx dataset
  • model init weight
  • model resize weights (init from smaller to bigger model)
  • world tokenizer
  • learning rate schedule (initial -> final learning rate)
  • helper script to add new tokens to existing model

Environment setup

The following sets up a virtual environment using conda; modify it for your use case.

# ninja-build is required for the new trainer
sudo apt-get install ninja-build

# Virtual env, with python 3.11
conda create -n rwkv-infctx python=3.11 pip
conda activate rwkv-infctx

# Install pytorch
conda install -y pytorch torchvision torchaudio pytorch-cuda=11.7 -c pytorch -c nvidia

# We use python -m pip, instead of pip directly, as it resolves issues with venv not loading the right pip
python -m pip install datasets transformers
python -m pip install lightning==2.0.2 deepspeed==0.9.3
python -m pip install ninja numexpr jsonargparse 'jsonargparse[signatures]'
python -m pip install lm-dataformat ftfy sentencepiece tokenizers wandb

Due to issues with deepspeed on Windows, only Linux environments are supported. WSL2 on Windows is not recommended due to heavy performance penalties (deepspeed offload cannot be used, and training is ~50% slower).
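
A quick sanity check of the environment can be sketched as follows. This is an illustrative helper, not part of the repo: `platform_status` and `version_mismatches` are hypothetical names, and detecting WSL by looking for "microsoft" in the kernel release string is a common heuristic rather than an official API. The pinned versions are the ones from the install commands above.

```python
import platform
from importlib import metadata

# Versions pinned by the pip install commands in the environment setup.
PINNED = {"lightning": "2.0.2", "deepspeed": "0.9.3"}


def platform_status(system, release):
    """Classify the platform as 'ok', 'wsl' (not recommended) or 'unsupported'."""
    if system != "Linux":
        return "unsupported"
    # Heuristic (assumption): WSL kernels report "microsoft" in their release string.
    if "microsoft" in release.lower():
        return "wsl"
    return "ok"


def version_mismatches(pinned=PINNED, lookup=metadata.version):
    """Return {package: installed_version_or_None} for packages that differ from the pins."""
    problems = {}
    for name, wanted in pinned.items():
        try:
            found = lookup(name)
        except metadata.PackageNotFoundError:
            found = None  # package is not installed at all
        if found != wanted:
            problems[name] = found
    return problems


if __name__ == "__main__":
    print("platform:", platform_status(platform.system(), platform.release()))
    print("version mismatches:", version_mismatches())
```

An empty mismatch dictionary together with a platform status of `ok` suggests the setup matches what this README describes.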

Overall training process

  • Either init a new model (todo script), or download an existing model
  • Set up the config.yaml file, customized for your foundation model / finetune use case
  • Preload the dataset using python3 preload_dataset.py {your_config}.yaml
  • Start the training process with python3 new_train.py fit -c {your_config}.yaml
  • Export the checkpoint after training is complete with python3 export_checkpoint.py ../path/to/checkpoint
  • In the checkpoint folder, you should find the fp32 model named rwkv_model.pth
  • You should probably convert this to an fp16 model (todo script)
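
The scripted steps above, plus the still-missing fp16 conversion, could be sketched roughly as below. This is not a script shipped with the repo. The helper names (`pipeline_commands`, `run_pipeline`, `to_fp16`, `convert_file`) are hypothetical, running every command from RWKV-v4neo is an assumption, and the conversion assumes rwkv_model.pth is a flat state dict of tensors.

```python
import subprocess


def pipeline_commands(config_path, checkpoint_path):
    """Build the preload, train and export commands listed above, in order."""
    return [
        ["python3", "preload_dataset.py", config_path],
        ["python3", "new_train.py", "fit", "-c", config_path],
        ["python3", "export_checkpoint.py", checkpoint_path],
    ]


def run_pipeline(config_path, checkpoint_path, workdir="RWKV-v4neo"):
    """Run each step in turn, stopping at the first failure."""
    for command in pipeline_commands(config_path, checkpoint_path):
        # Assumption: all three scripts are run from the RWKV-v4neo directory.
        subprocess.run(command, cwd=workdir, check=True)


def to_fp16(state_dict):
    """Cast floating-point tensors to half precision; leave other entries untouched."""
    return {
        name: (tensor.half() if tensor.is_floating_point() else tensor)
        for name, tensor in state_dict.items()
    }


def convert_file(src="rwkv_model.pth", dst="rwkv_model_fp16.pth"):
    """Load the exported fp32 model, cast it to fp16 and save it under a new name."""
    import torch  # imported lazily so the helpers above work without torch installed

    state_dict = torch.load(src, map_location="cpu")
    torch.save(to_fp16(state_dict), dst)
```

Keeping the fp32 file and writing the fp16 copy under a different name makes it possible to compare the two before discarding either.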

Examples of dataset configs

@TODO
