google/saxml
Saxml is an experimental system that serves Paxml, JAX, and PyTorch models for inference.
A Sax cell (aka Sax cluster) consists of an admin server and a group of model servers. The admin server keeps track of model servers, assigns published models to model servers to serve, and helps clients locate model servers serving specific published models.
The example below walks through setting up a Sax cell and starting a TPU or GPU model server in the cell.
Install the gcloud CLI and set the default account and project:

    gcloud config set account <your-email-account>
    gcloud config set project <your-project>

Create a Cloud Storage bucket:

    GSBUCKET=sax-data
    gcloud storage buckets create gs://${GSBUCKET}

Create a Compute Engine VM instance:

    gcloud compute instances create sax-admin \
      --zone=us-central1-b \
      --machine-type=e2-standard-8 \
      --boot-disk-size=200GB \
      --scopes=https://www.googleapis.com/auth/cloud-platform

Use this guide to enable the Cloud TPU API in a Google Cloud project.
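If you prefer the command line to the guide, the Cloud TPU API can also be enabled with `gcloud services enable`. A minimal sketch wrapping it, with an optional project ID argument; `enable_tpu_api` is a hypothetical helper name, not part of Sax:

```shell
# Hypothetical helper (not part of Sax): enable the Cloud TPU API.
# With no argument it uses the default project set via `gcloud config`;
# otherwise the first argument is passed as --project.
enable_tpu_api() {
  if [ -n "${1:-}" ]; then
    gcloud services enable tpu.googleapis.com --project="$1"
  else
    gcloud services enable tpu.googleapis.com
  fi
}
```

For example, `enable_tpu_api <your-project>`. Enabling an already-enabled service is harmless, so it is safe to rerun.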
Create a Cloud TPU VM instance:

    gcloud compute tpus tpu-vm create sax-tpu \
      --zone=us-central2-b \
      --accelerator-type=v4-8 \
      --version=tpu-vm-v4-base \
      --scopes=https://www.googleapis.com/auth/cloud-platform

Alternatively, or in addition to the Cloud TPU VM instance, create a Compute Engine VM instance with GPUs:

    gcloud compute instances create sax-gpu \
      --zone=us-central1-b \
      --machine-type=n1-standard-32 \
      --accelerator=count=4,type=nvidia-tesla-v100 \
      --maintenance-policy=TERMINATE \
      --boot-disk-size=200GB \
      --scopes=https://www.googleapis.com/auth/cloud-platform

Consider creating the VM instance from the "GPU-optimized Debian 10 with CUDA 11.0" image instead, so that the NVIDIA CUDA stack doesn't need to be installed manually as described below.
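The `GSBUCKET` variable set on your workstation is not carried into the SSH sessions opened on the VMs below, yet the commands run there reference `${GSBUCKET}`. A small guard, run in each VM shell after `export GSBUCKET=<your-bucket-name>`, catches a missing value before bazel starts; `require_gsbucket` is a hypothetical helper, sketched here under that assumption:

```shell
# Hypothetical helper (not part of Sax): print the SAX root derived from
# GSBUCKET, or fail with a hint if GSBUCKET is unset in this shell.
require_gsbucket() {
  if [ -z "${GSBUCKET:-}" ]; then
    echo "GSBUCKET is not set; run: export GSBUCKET=<your-bucket-name>" >&2
    return 1
  fi
  echo "gs://${GSBUCKET}/sax-root"
}
```

For example, `export GSBUCKET=sax-data && require_gsbucket` prints `gs://sax-data/sax-root`.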
SSH to the Compute Engine VM instance:

    gcloud compute ssh --zone=us-central1-b sax-admin

Inside the VM instance, clone the Sax repo and initialize the environment:

    git clone https://github.com/google/saxml.git
    cd saxml
    saxml/tools/init_cloud_vm.sh

Configure the Sax admin server. This only needs to be done once:

    bazel run saxml/bin:admin_config -- \
      --sax_cell=/sax/test \
      --sax_root=gs://${GSBUCKET}/sax-root \
      --fs_root=gs://${GSBUCKET}/sax-fs-root \
      --alsologtostderr

Start the Sax admin server:

    bazel run saxml/bin:admin_server -- \
      --sax_cell=/sax/test \
      --sax_root=gs://${GSBUCKET}/sax-root \
      --port=10000 \
      --alsologtostderr

SSH to the Cloud TPU VM instance:

    gcloud compute tpus tpu-vm ssh --zone=us-central2-b sax-tpu

Inside the VM instance, clone the Sax repo and initialize the environment:

    git clone https://github.com/google/saxml.git
    cd saxml
    saxml/tools/init_cloud_vm.sh

Start the Sax model server:

    SAX_ROOT=gs://${GSBUCKET}/sax-root \
    bazel run saxml/server:server -- \
      --sax_cell=/sax/test \
      --port=10001 \
      --platform_chip=tpuv4 \
      --platform_topology=2x2x1 \
      --alsologtostderr

You should see a log message "Joined [admin server IP:port]" from the model server, indicating that it has successfully joined the admin server.
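The model server runs in the foreground and logs to stderr (`--alsologtostderr`). If you also capture that output to a file, for example by appending `2>&1 | tee /tmp/sax_model_server.log` to the command (the log path is an assumption), a small polling helper in a second terminal can wait for the "Joined" message. `wait_for_join` is hypothetical and only greps the captured log:

```shell
# Hypothetical helper (not part of Sax): poll a captured model-server log
# until it contains the "Joined" message, or give up after a timeout.
wait_for_join() {
  local log_file="$1"
  local timeout_s="${2:-300}"   # default: wait up to 5 minutes
  local waited=0
  while [ "$waited" -lt "$timeout_s" ]; do
    if grep -q "Joined" "$log_file" 2>/dev/null; then
      grep -m 1 "Joined" "$log_file"   # print the first matching line
      return 0
    fi
    sleep 5
    waited=$((waited + 5))
  done
  echo "no \"Joined\" message in ${log_file} after ${timeout_s}s" >&2
  return 1
}
```

For example, `wait_for_join /tmp/sax_model_server.log 600` returns 0 and prints the matching line once the server has joined.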
SSH to the Compute Engine VM instance with GPUs:

    gcloud compute ssh --zone=us-central1-b sax-gpu

Install the NVIDIA GPU driver, CUDA, and cuDNN. Note that Sax requires CUDA 11 by default. To switch to CUDA 12, edit requirements-cuda.txt and replace jaxlib==0.4.7+cuda11.cudnn86 with jaxlib==0.4.7+cuda12.cudnn88.
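Once the repo is cloned, the CUDA 12 edit described above can be scripted from the `saxml` directory with sed. A sketch, assuming the file contains exactly the jaxlib pin quoted above; `switch_to_cuda12` is a hypothetical helper name:

```shell
# Hypothetical helper (not part of Sax): swap the CUDA 11 jaxlib pin for
# the CUDA 12 one. `-i.bak` works with both GNU and BSD sed; the backup
# file is removed after a successful edit.
switch_to_cuda12() {
  local req_file="${1:-requirements-cuda.txt}"
  if [ ! -f "$req_file" ]; then
    echo "no such file: $req_file" >&2
    return 1
  fi
  sed -i.bak 's/jaxlib==0\.4\.7+cuda11\.cudnn86/jaxlib==0.4.7+cuda12.cudnn88/' "$req_file" \
    && rm -f "${req_file}.bak"
}
```

Afterwards, `grep jaxlib requirements-cuda.txt` should show the `cuda12.cudnn88` wheel.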
Inside the VM instance, clone the Sax repo and initialize the environment:

    git clone https://github.com/google/saxml.git
    cd saxml
    saxml/tools/init_cloud_vm.sh

Enable the GPU-specific requirements.txt file:

    cp requirements-cuda.txt requirements.txt

Start the Sax model server:

    SAX_ROOT=gs://${GSBUCKET}/sax-root \
    bazel run saxml/server:server -- \
      --sax_cell=/sax/test \
      --port=10001 \
      --platform_chip=v100 \
      --platform_topology=4 \
      --jax_platforms=cuda \
      --alsologtostderr

You should see a log message "Joined [admin server IP:port]" from the model server, indicating that it has successfully joined the admin server.
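In this README's GPU examples, `--platform_topology` equals the number of GPUs on the VM (4 for the four V100s here, 1 for the single-A100 LLaMA example). Assuming that relationship holds for your configuration, the value can be derived from `nvidia-smi`; `gpu_count` is a hypothetical helper:

```shell
# Hypothetical helper (not part of Sax): count GPUs visible to nvidia-smi.
# With --format=csv,noheader, nvidia-smi prints one line per GPU.
gpu_count() {
  if ! command -v nvidia-smi >/dev/null 2>&1; then
    echo "nvidia-smi not found; is the NVIDIA driver installed?" >&2
    return 1
  fi
  nvidia-smi --query-gpu=name --format=csv,noheader | grep -c .
}
```

It could then be passed as `--platform_topology="$(gpu_count)"`.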
Sax comes with a command-line tool called saxutil for easy usage:

    # From the `saxml` repo root directory:
    alias saxutil='bazel run saxml/bin:saxutil -- --sax_root=gs://${GSBUCKET}/sax-root'

saxutil supports the following commands:

- saxutil help: Show general help or help about a particular command.
- saxutil ls: List all cells, all models in a cell, or a particular model.
- saxutil publish: Publish a model.
- saxutil unpublish: Unpublish a model.
- saxutil update: Update a model.
- saxutil lm.generate: Use a language model to generate suffixes from a prefix.
- saxutil lm.score: Use a language model to score a prefix and suffix.
- saxutil lm.embed: Use a language model to embed text into a vector.
- saxutil vm.generate: Use a vision model to generate images from text.
- saxutil vm.classify: Use a vision model to classify an image.
- saxutil vm.embed: Use a vision model to embed an image into a vector.

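The saxutil alias is convenient interactively, but bash does not expand aliases in non-interactive scripts by default. A function with the same expansion works in both cases. This is a sketch to define instead of the alias; the `unalias` line matters when pasting into a shell where the alias already exists, since the alias would otherwise be expanded inside the function definition:

```shell
# Remove the alias if present; an existing alias would break the definition below.
unalias saxutil 2>/dev/null || true

# Function equivalent of the saxutil alias; forwards all arguments.
# Run from the saxml repo root with GSBUCKET set.
saxutil() {
  bazel run saxml/bin:saxutil -- --sax_root="gs://${GSBUCKET}/sax-root" "$@"
}
```

Usage is unchanged, e.g. `saxutil ls /sax/test`.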
As an example, Sax comes with a Pax language model servable on a Cloud TPU VM v4-8 instance. You can use it to verify that Sax is correctly set up by publishing and using the model with a dummy checkpoint.

    saxutil publish \
      /sax/test/lm2b \
      saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2BTest \
      None \
      1

Check whether the model is loaded by looking at the "selected replica address" column of this command's output:

    saxutil ls /sax/test/lm2b

When the model is loaded, issue a query:

    saxutil lm.generate /sax/test/lm2b "Q: Who is Harry Potter's mother? A: "

The result will be printed in the terminal.
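To send several prompts in one go, a loop over a text file with one prompt per line works. A sketch calling the same saxutil binary that the alias wraps; `generate_all` and the prompts file are hypothetical. The `</dev/null` redirect is there because a command inside a `while read` loop that reads standard input could otherwise swallow the remaining prompts:

```shell
# Hypothetical helper (not part of Sax): send each non-empty line of a
# prompts file to saxutil lm.generate. Run from the saxml repo root with
# GSBUCKET set.
generate_all() {
  local model="$1" prompts_file="$2"
  local prompt
  # `|| [ -n ... ]` also handles a last line without a trailing newline.
  while IFS= read -r prompt || [ -n "$prompt" ]; do
    if [ -z "$prompt" ]; then
      continue   # skip blank lines
    fi
    # </dev/null keeps bazel from reading the remaining prompts on stdin.
    bazel run saxml/bin:saxutil -- --sax_root="gs://${GSBUCKET}/sax-root" \
      lm.generate "$model" "$prompt" </dev/null
  done <"$prompts_file"
}
```

For example, `generate_all /sax/test/lm2b prompts.txt`.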
To use a real checkpoint with the model, follow the Paxml tutorial to generate a checkpoint. The model can then be published in Sax like this:

    saxutil publish \
      /sax/test/lm2b \
      saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2B \
      gs://${GSBUCKET}/checkpoints/checkpoint_00000000 \
      1

Use the same saxutil lm.generate command as above to query the model.
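Training usually leaves several checkpoints under the same prefix. Assuming they all use the fixed-width, zero-padded naming shown above (`checkpoint_00000000`), lexicographic order matches step order, so the newest one can be picked with `sort | tail`. `latest_checkpoint` and the `checkpoints/` prefix are hypothetical:

```shell
# Hypothetical helper (not part of Sax): print the newest checkpoint_N
# directory under a Cloud Storage prefix, without the trailing slash.
latest_checkpoint() {
  local prefix="${1%/}"
  gcloud storage ls "${prefix}/" \
    | grep -E '/checkpoint_[0-9]+/?$' \
    | sed 's:/*$::' \
    | sort \
    | tail -n 1
}
```

For example, `saxutil publish /sax/test/lm2b saxml.server.pax.lm.params.lm_cloud.LmCloudSpmd2B "$(latest_checkpoint gs://${GSBUCKET}/checkpoints)" 1`.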
First, get the LLaMA PyTorch checkpoint (pytorch_vars) from Meta, then run the following script to convert it to the SAX format:

    python3 saxml/tools/convert_llama_ckpt.py --base llama_7b --pax pax_7b

For the 7B model, this script needs roughly 50-60 GB of memory. Larger models need proportionally more; for example, the 70B model needs 500-600 GB.
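The two quoted figures scale roughly linearly with model size (about 7 to 9 GB per billion parameters). A sketch that extrapolates linearly from the 7B figure, only as a rough guide for sizing the conversion VM for other model sizes; the outputs are estimates, not measurements, and `estimate_convert_mem_gb` is a hypothetical helper:

```shell
# Hypothetical helper: rough peak-memory range for convert_llama_ckpt,
# extrapolated linearly from the 50-60 GB quoted for the 7B model.
# Takes the model size in whole billions of parameters.
estimate_convert_mem_gb() {
  local billions="$1"
  local low=$(( billions * 50 / 7 ))
  local high=$(( (billions * 60 + 6) / 7 ))   # ceiling division
  echo "${low}-${high} GB"
}
```

For example, `estimate_convert_mem_gb 13` prints `92-112 GB`.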
The script loads and saves all weights in a single pass. To reduce memory usage, modify the convert() function to load and save the weights in multiple passes, handling only a subset of the weight variables in each pass.
After converting the checkpoint, the checkpoint folder should have the following structure:

    checkpoint_00000000
        metadata/
            metadata
        state/
            mdl_vars.params.lm*/
            ...
            ...
            step/

Create an empty file named commit_success.txt in the checkpoint folder and in its metadata/ and state/ subfolders. This tells SAX that the checkpoint is ready to use when it loads the model. The fully prepared checkpoint should look like this:

    checkpoint_00000000
        commit_success.txt
        metadata/
            commit_success.txt
            metadata
        state/
            commit_success.txt
            mdl_vars.params.lm*/
            ...
            ...
            step/

Now the checkpoint is fully ready.
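Creating the three marker files by hand is error-prone, so here is a small sketch that checks the expected folders exist and then creates the empty `commit_success.txt` files shown in the layout. `mark_checkpoint_ready` is a hypothetical helper, and it places markers only in the three folders listed above:

```shell
# Hypothetical helper (not part of Sax): create the empty commit_success.txt
# markers in a converted checkpoint folder and its metadata/ and state/
# subfolders. Nothing is created unless all three folders exist.
mark_checkpoint_ready() {
  local ckpt_dir="${1%/}"
  local d
  for d in "$ckpt_dir" "$ckpt_dir/metadata" "$ckpt_dir/state"; do
    if [ ! -d "$d" ]; then
      echo "missing directory: $d" >&2
      return 1
    fi
  done
  for d in "$ckpt_dir" "$ckpt_dir/metadata" "$ckpt_dir/state"; do
    touch "$d/commit_success.txt"
  done
}
```

For example, `mark_checkpoint_ready checkpoint_00000000`.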
Then start the SAX model server.
GPU server:

    SAX_ROOT=gs://${GSBUCKET}/sax-root \
    bazel run saxml/server:server -- \
      --sax_cell=/sax/test \
      --port=10001 \
      --platform_chip=a100 \
      --platform_topology=1 \
      --jax_platforms=cuda \
      --alsologtostderr

TPU server:

    SAX_ROOT=gs://${GSBUCKET}/sax-root \
    bazel run saxml/server:server -- \
      --sax_cell=/sax/test \
      --port=10001 \
      --platform_chip=tpuv4 \
      --platform_topology=2x2x1 \
      --alsologtostderr

Finally, copy the converted checkpoint to your Cloud Storage bucket and publish the model:
7B model

    saxutil publish \
      /sax/test/llama-7b \
      saxml.server.pax.lm.params.lm_cloud.LLaMA7BFP16 \
      gs://${GSBUCKET}/pax-llama/7B \
      1

70B model

    saxutil publish \
      /sax/test/llama-70b \
      saxml.server.pax.lm.params.lm_cloud.LLaMA70BFP16TPUv5e \
      gs://${GSBUCKET}/pax-llama/70B \
      1
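For the copy step mentioned above, `gcloud storage rsync --recursive` mirrors the contents of a local directory under a Cloud Storage prefix, which makes the resulting layout predictable. A sketch, assuming the local directory already has the layout SAX expects at the published path; `upload_checkpoint` is a hypothetical helper:

```shell
# Hypothetical helper (not part of Sax): mirror a local converted checkpoint
# directory to the Cloud Storage path later passed to `saxutil publish`.
upload_checkpoint() {
  local local_dir="$1" dest="$2"
  if [ ! -d "$local_dir" ]; then
    echo "not a directory: $local_dir" >&2
    return 1
  fi
  gcloud storage rsync --recursive "$local_dir" "$dest"
}
```

For example, `upload_checkpoint pax_7b gs://${GSBUCKET}/pax-llama/7B` before publishing the 7B model.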