Train a model using TPU v5e

With a smaller 256-chip footprint per Pod, TPU v5e is optimized to be a high-value product for transformer, text-to-image, and Convolutional Neural Network (CNN) training, fine-tuning, and serving. For more information about using Cloud TPU v5e for serving, see Inference using v5e.

For more information about Cloud TPU v5e hardware and configurations, see TPU v5e.

Get started

The following sections describe how to get started using TPU v5e.

Request quota

You need quota to use TPU v5e for training. There are different quota types for on-demand TPUs, reserved TPUs, and TPU Spot VMs. There are separate quotas required if you're using your TPU v5e for inference. For more information about quotas, see Quotas. To request TPU v5e quota, contact Cloud Sales.

Create a Google Cloud account and project

You need a Google Cloud account and project to use Cloud TPU. For more information, see Set up a Cloud TPU environment.

Create a Cloud TPU

The best practice is to provision Cloud TPU v5e as queued resources using the queued-resource create command. For more information, see Manage queued resources.
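For example, a queued resource request for a v5litepod-16 slice might look like the following sketch (the accelerator type and runtime version are illustrative values; the full workflow, including the environment variables used here, is shown in the tutorials later in this document):

gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
    --node-id=${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --accelerator-type=v5litepod-16 \
    --runtime-version=v2-alpha-tpuv5-lite \
    --service-account=${SERVICE_ACCOUNT}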

You can also use the Create Node API (gcloud compute tpus tpu-vm create) to provision Cloud TPU v5e. For more information, see Manage TPU resources.

For more information about available v5e configurations for training, see Cloud TPU v5e types for training.

Framework setup

This section describes the general setup process for custom model training using JAX or PyTorch with TPU v5e.

For inference setup instructions, see v5e inference introduction.

Define some environment variables:

export PROJECT_ID=your_project_ID
export ACCELERATOR_TYPE=v5litepod-16
export ZONE=us-west4-a
export TPU_NAME=your_tpu_name
export QUEUED_RESOURCE_ID=your_queued_resource_id

Setup for JAX

If you have slice shapes greater than 8 chips, you will have multiple VMs in one slice. In this case, you need to use the --worker=all flag to run the installation on all TPU VMs in a single step without using SSH to log into each one separately:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'

Command flag descriptions

Variable      Description
TPU_NAME      The user-assigned text ID of the TPU, which is created when the queued resource request is allocated.
PROJECT_ID    The Google Cloud project name. Use an existing project or create a new one at Set up your Google Cloud project.
ZONE          See the TPU regions and zones document for the supported zones.
worker        The TPU VM that has access to the underlying TPUs.
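If you only need to run a command on one TPU VM in the slice, you can target that VM by its worker index instead of using --worker=all. A minimal sketch, assuming you want worker 0:

# Run a command on a single TPU VM (worker 0) instead of all workers.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=0 \
    --command='hostname'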

You can run the following command to check the number of devices (the outputs shown here were produced with a v5litepod-16 slice). This code tests that everything is installed correctly by checking that JAX sees the Cloud TPU TensorCores and can run basic operations:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='python3 -c "import jax; print(jax.device_count()); print(jax.local_device_count())"'

The output will be similar to the following:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
16
4
16
4
16
4
16
4

jax.device_count() shows the total number of chips in the given slice. jax.local_device_count() indicates the count of chips accessible by a single VM in this slice.

# Check the number of chips in the given slice by summing the count of chips
# from all VMs through the jax.local_device_count() API call.
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='python3 -c "import jax; xs=jax.numpy.ones(jax.local_device_count()); print(jax.pmap(lambda x: jax.lax.psum(x, \"i\"), axis_name=\"i\")(xs))"'

The output will be similar to the following:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]
[16. 16. 16. 16.]

Try the JAX Tutorials in this document to get started with v5e training using JAX.

Setup for PyTorch

Note that v5e only supports the PJRT runtime, and PyTorch 2.1+ uses PJRT as the default runtime for all TPU versions.

This section describes how to start using PJRT on v5e with PyTorch/XLA, with commands for all workers.

Install dependencies

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip install mkl mkl-include
      pip install tf-nightly tb-nightly tbp-nightly
      pip install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch~=PYTORCH_VERSION torchvision torch_xla[tpu]~=PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

Replace PYTORCH_VERSION with the version of PyTorch you want to use. PYTORCH_VERSION is used to specify the same version for PyTorch/XLA. Version 2.6.0 is recommended.
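For example, with the recommended version substituted, the install line from the command above becomes the following (a sketch; adjust the version if you need a different PyTorch/XLA release):

pip install torch~=2.6.0 torchvision torch_xla[tpu]~=2.6.0 -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html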

For more information about versions of PyTorch and PyTorch/XLA, see PyTorch - Get Started and PyTorch/XLA releases.

For more information on installing PyTorch/XLA, see PyTorch/XLA installation.
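After the packages are installed, you can optionally confirm that PyTorch/XLA sees the TPU through PJRT before running a full workload. A minimal sketch, using the same gcloud SSH pattern as above; each worker should print an XLA device such as xla:0:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='PJRT_DEVICE=TPU python3 -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"'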

If you get an error when installing the wheels for torch, torch_xla, or torchvision like pkg_resources.extern.packaging.requirements.InvalidRequirement: Expected end or semicolon (after name and no valid version specifier) torch==nightly+20230222, downgrade your setuptools version with this command:

pip3 install setuptools==62.1.0

Run a script with PJRT

Note: For models that have sizable, frequent allocations, tcmalloc can significantly improve training time compared to the default malloc implementation, so the default malloc used on a TPU VM is tcmalloc. However, depending on your workload (for example, DLRM, which has very large allocations for its embedding tables), tcmalloc may cause a slowdown. In that case, you can try unsetting the following variable to use the default malloc instead:

unset LD_PRELOAD
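Whether tcmalloc is active is controlled by the LD_PRELOAD variable, so you can inspect it before changing anything. A short sketch; the tcmalloc library path shown in the comment is an assumption and may differ on your VM image:

# Show which allocator, if any, is currently preloaded.
echo $LD_PRELOAD
# Fall back to the default malloc for allocation-heavy workloads such as DLRM.
unset LD_PRELOAD
# To re-enable tcmalloc later, point LD_PRELOAD back at the tcmalloc library
# (assumed path; verify the location on your TPU VM):
# export LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4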

The following is an example using a Python script to do a calculation on a v5e VM:

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
      export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$HOME/.local/lib/
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      unset LD_PRELOAD
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      python3 -c "import torch; import torch_xla; import torch_xla.core.xla_model as xm; print(xm.xla_device()); dev = xm.xla_device(); t1 = torch.randn(3,3,device=dev); t2 = torch.randn(3,3,device=dev); print(t1 + t2)"'

This generates output similar to the following:

SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
        [-1.0731,  0.3422,  3.1445],
        [ 0.5743,  0.2379,  1.1105]], device='xla:0')
xla:0
tensor([[ 1.8611, -0.3114, -2.4208],
        [-1.0731,  0.3422,  3.1445],
        [ 0.5743,  0.2379,  1.1105]], device='xla:0')

Try the PyTorch Tutorials in this document to get started with v5e training using PyTorch.

Delete your TPU and queued resource at the end of your session. To delete a queued resource, delete the slice and then the queued resource in two steps:

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --quiet

These two steps can also be used to remove queued resource requests that are in the FAILED state.
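To find queued resource requests and their states (for example, to identify requests in the FAILED state), you can list them first; a sketch assuming the same project and zone variables as above:

gcloud compute tpus queued-resources list \
    --project=${PROJECT_ID} \
    --zone=${ZONE}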

JAX/FLAX examples

The following sections describe examples of how to train JAX and FLAX models on TPU v5e.

Train ImageNet on v5e

This tutorial describes how to train ImageNet on v5e using fake input data. If you want to use real data, refer to the README file on GitHub.

Set up

  1. Create environment variables:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Environment variable descriptions

    Variable              Description
    PROJECT_ID            Your Google Cloud project ID. Use an existing project or create a new one.
    TPU_NAME              The name of the TPU.
    ZONE                  The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    ACCELERATOR_TYPE      The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    RUNTIME_VERSION       The Cloud TPU software version.
    SERVICE_ACCOUNT       The email address for your service account. You can find it by going to the Service Accounts page in the Google Cloud console. For example: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
    QUEUED_RESOURCE_ID    The user-assigned text ID of the queued resource request.

  2. Create a TPU resource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
        --node-id=${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --runtime-version=${RUNTIME_VERSION} \
        --service-account=${SERVICE_ACCOUNT}

    Note: To use a reservation, add the --reserved flag. To use TPU Spot VMs, add the --spot flag.

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

    When the queued resource is in the ACTIVE state, the output will be similar to the following:

    state: ACTIVE
  3. Install the newest version of JAX and jaxlib:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
    Note: If you see an error suggesting you use jax>=0.4.8, you can safely ignore the message.
  4. Clone the ImageNet model and install the corresponding requirements:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command="git clone https://github.com/coolkp/flax.git && cd flax && git checkout pmap-orbax-conversion && git pull"
  5. To generate fake data, the model needs information on the dimensions of the dataset. This can be gathered from the ImageNet dataset's metadata:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command="cd flax/examples/imagenet && pip install -r requirements-cloud-tpu.txt"

Train the model

Once all the previous steps are done, you can train the model.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command="cd flax/examples/imagenet && bash ../../tests/download_dataset_metadata.sh && JAX_PLATFORMS=tpu python imagenet_fake_data_benchmark.py"

Delete the TPU and queued resource

Delete your TPU and queued resource at the end of your session.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --quiet

Hugging Face FLAX Models

Hugging Face models implemented in FLAX work out of the box on Cloud TPU v5e. This section provides instructions for running popular models.

Train ViT on Imagenette

This tutorial shows you how to train the Vision Transformer (ViT) model from Hugging Face using the Fast AI Imagenette dataset on Cloud TPU v5e.

The ViT model was the first to successfully train a Transformer encoder on ImageNet with excellent results compared to convolutional networks. For more information, see ViT overview.

Set up

  1. Create environment variables:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Environment variable descriptions

    Variable              Description
    PROJECT_ID            Your Google Cloud project ID. Use an existing project or create a new one.
    TPU_NAME              The name of the TPU.
    ZONE                  The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    ACCELERATOR_TYPE      The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    RUNTIME_VERSION       The Cloud TPU software version.
    SERVICE_ACCOUNT       The email address for your service account. You can find it by going to the Service Accounts page in the Google Cloud console. For example: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
    QUEUED_RESOURCE_ID    The user-assigned text ID of the queued resource request.

  2. Create a TPU resource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
        --node-id=${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --runtime-version=${RUNTIME_VERSION} \
        --service-account=${SERVICE_ACCOUNT}

    Note: To use a reservation, add the --reserved flag. To use TPU Spot VMs, add the --spot flag.

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

    When the queued resource is in the ACTIVE state, the output will be similar to the following:

    state: ACTIVE
  3. Install JAX and its library:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
  4. Download the Hugging Face repository and install requirements:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='git clone https://github.com/huggingface/transformers.git && cd transformers && pip install . && pip install -r examples/flax/_tests_requirements.txt && pip install --upgrade huggingface-hub urllib3 zipp && pip install tensorflow==2.19 && sed -i 's/torchvision==0.12.0+cpu/torchvision==0.22.1/' examples/flax/vision/requirements.txt && pip install -r examples/flax/vision/requirements.txt && pip install tf-keras'
  5. Download the Imagenette dataset:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='cd transformers && wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz && tar -xvzf imagenette2.tgz'

Train the model

Train the model with a pre-mapped buffer at 4GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='cd transformers && JAX_PLATFORMS=tpu python3 examples/flax/vision/run_image_classification.py --train_dir "imagenette2/train" --validation_dir "imagenette2/val" --output_dir "./vit-imagenette" --learning_rate 1e-3 --preprocessing_num_workers 32 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --model_name_or_path google/vit-base-patch16-224-in21k --num_train_epochs 3'

Delete the TPU and queued resource

Delete your TPU and queued resource at the end of your session.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --quiet

ViT benchmarking results

The training script was run on v5litepod-4, v5litepod-16, and v5litepod-64. The following table shows the throughputs with different accelerator types.

Accelerator type             v5litepod-4    v5litepod-16    v5litepod-64
Epoch                        3              3               3
Global batch size            32             128             512
Throughput (examples/sec)    263.40         429.34          470.71
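The global batch sizes in this table follow from the per-device batch size of 8 used in the training command above multiplied by the number of chips in each slice; a quick sketch of the arithmetic:

# Global batch size = per-device batch size (8) x chips per slice (4, 16, 64).
python3 -c 'print([8 * chips for chips in (4, 16, 64)])'   # [32, 128, 512]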

Train Diffusion on Pokémon

This tutorial shows you how to train the Stable Diffusion model from Hugging Face using the Pokémon dataset on Cloud TPU v5e.

The Stable Diffusion model is a latent text-to-image model that generates photo-realistic images from any text input. For more information, see the following resources:

Set up

  1. Set an environment variable for the name of your storage bucket:

    export GCS_BUCKET_NAME=your_bucket_name
  2. Set up a storage bucket for your model output:

    gcloud storage buckets create gs://GCS_BUCKET_NAME \
        --project=your_project \
        --location=us-west1
  3. Create environment variables:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west1-c
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Environment variable descriptions

    Variable              Description
    PROJECT_ID            Your Google Cloud project ID. Use an existing project or create a new one.
    TPU_NAME              The name of the TPU.
    ZONE                  The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    ACCELERATOR_TYPE      The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    RUNTIME_VERSION       The Cloud TPU software version.
    SERVICE_ACCOUNT       The email address for your service account. You can find it by going to the Service Accounts page in the Google Cloud console. For example: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
    QUEUED_RESOURCE_ID    The user-assigned text ID of the queued resource request.

  4. Create a TPU resource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
        --node-id=${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --runtime-version=${RUNTIME_VERSION} \
        --service-account=${SERVICE_ACCOUNT}

    Note: To use a reservation, add the --reserved flag. To use TPU Spot VMs, add the --spot flag.

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

    When the queued resource is in the ACTIVE state, the output will be similar to the following:

    state: ACTIVE
  5. Install JAX and its library.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='pip install "jax[tpu]==0.4.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html'
  6. Download the Hugging Face repository and install requirements.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='git clone https://github.com/RissyRan/diffusers.git && cd diffusers && pip install . && pip install -U -r examples/text_to_image/requirements_flax.txt && pip install tensorflow==2.17.1 clu && pip install tensorboard==2.17.1'

Train the model

Train the model with a pre-mapped buffer at 4GB.

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --zone=${ZONE} \
    --project=${PROJECT_ID} \
    --worker=all \
    --command="
    git clone https://github.com/google/maxdiffusion
    cd maxdiffusion
    pip3 install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    pip3 install -r requirements.txt
    pip3 install .
    pip3 install gcsfs
    export LIBTPU_INIT_ARGS=''
    python -m src.maxdiffusion.train src/maxdiffusion/configs/base_2_base.yml run_name=my_run \
    jax_cache_dir=gs://${GCS_BUCKET_NAME} activations_dtype=bfloat16 weights_dtype=bfloat16 \
    per_device_batch_size=1 precision=DEFAULT dataset_save_location=gs://${GCS_BUCKET_NAME} \
    output_dir=gs://${GCS_BUCKET_NAME}/ attention=flash"

Clean up

Delete your TPU, queued resource, and Cloud Storage bucket at the end of your session.

  1. Delete your TPU:

    gcloud compute tpus tpu-vm delete ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
  2. Delete the queued resource:

    gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --quiet
  3. Delete the Cloud Storage bucket:

    gcloud storage rm -r gs://${GCS_BUCKET_NAME}

Benchmarking results for diffusion

The training script ran on v5litepod-4, v5litepod-16, and v5litepod-64. The following table shows the throughputs.

Accelerator type             v5litepod-4    v5litepod-16    v5litepod-64
Train step                   1500           1500            1500
Global batch size            32             64              128
Throughput (examples/sec)    36.53          43.71           49.36

PyTorch/XLA

The following sections describe examples of how to train PyTorch/XLA models on TPU v5e.

Train ResNet using the PJRT runtime

PyTorch/XLA is migrating from XRT to PJRT starting with PyTorch 2.0. Here are the updated instructions to set up v5e for PyTorch/XLA training workloads.

Set up

  1. Create environment variables:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Environment variable descriptions

    Variable              Description
    PROJECT_ID            Your Google Cloud project ID. Use an existing project or create a new one.
    TPU_NAME              The name of the TPU.
    ZONE                  The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    ACCELERATOR_TYPE      The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    RUNTIME_VERSION       The Cloud TPU software version.
    SERVICE_ACCOUNT       The email address for your service account. You can find it by going to the Service Accounts page in the Google Cloud console. For example: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
    QUEUED_RESOURCE_ID    The user-assigned text ID of the queued resource request.

  2. Create a TPU resource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
        --node-id=${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --runtime-version=${RUNTIME_VERSION} \
        --service-account=${SERVICE_ACCOUNT}

    Note: To use a reservation, add the --reserved flag. To use TPU Spot VMs, add the --spot flag.

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

    When the queued resource is in the ACTIVE state, the output will be similar to the following:

    state: ACTIVE
  3. Install Torch/XLA-specific dependencies:

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip3 install mkl mkl-include
      pip3 install tf-nightly tb-nightly tbp-nightly
      pip3 install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html'

    Replace PYTORCH_VERSION with the version of PyTorch you want to use. PYTORCH_VERSION is used to specify the same version for PyTorch/XLA. Version 2.6.0 is recommended.

    For more information about versions of PyTorch and PyTorch/XLA, see PyTorch - Get Started and PyTorch/XLA releases.

    For more information on installing PyTorch/XLA, see PyTorch/XLA installation.

Train the ResNet model

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
      date
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export XLA_USE_BF16=1
      export LIBTPU_INIT_ARGS=--xla_jf_auto_cross_replica_sharding
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      git clone https://github.com/pytorch/xla.git
      cd xla/
      git checkout release-r2.6
      python3 test/test_train_mp_imagenet.py --model=resnet50 --fake_data --num_epochs=1 --num_workers=16 --log_steps=300 --batch_size=64 --profile'

Delete the TPU and queued resource

Delete your TPU and queued resource at the end of your session.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --quiet

Benchmark result

The following table shows the benchmark throughputs.

Accelerator type    Throughput (examples/second)
v5litepod-4         4240 ex/s
v5litepod-16        10,810 ex/s
v5litepod-64        46,154 ex/s

Train ViT on v5e

This tutorial covers how to run ViT on v5e using the Hugging Face repository with PyTorch/XLA on the cifar10 dataset.

Set up

  1. Create environment variables:

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-west4-a
    export ACCELERATOR_TYPE=v5litepod-16
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite
    export SERVICE_ACCOUNT=your-service-account
    export QUEUED_RESOURCE_ID=your-queued-resource-id

    Environment variable descriptions

    Variable              Description
    PROJECT_ID            Your Google Cloud project ID. Use an existing project or create a new one.
    TPU_NAME              The name of the TPU.
    ZONE                  The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    ACCELERATOR_TYPE      The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    RUNTIME_VERSION       The Cloud TPU software version.
    SERVICE_ACCOUNT       The email address for your service account. You can find it by going to the Service Accounts page in the Google Cloud console. For example: tpu-service-account@PROJECT_ID.iam.gserviceaccount.com
    QUEUED_RESOURCE_ID    The user-assigned text ID of the queued resource request.

  2. Create a TPU resource:

    gcloud compute tpus queued-resources create ${QUEUED_RESOURCE_ID} \
        --node-id=${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --accelerator-type=${ACCELERATOR_TYPE} \
        --runtime-version=${RUNTIME_VERSION} \
        --service-account=${SERVICE_ACCOUNT}

    Note: To use a reservation, add the --reserved flag. To use TPU Spot VMs, add the --spot flag.

    You will be able to SSH to your TPU VM once your queued resource is in the ACTIVE state:

    gcloud compute tpus queued-resources describe ${QUEUED_RESOURCE_ID} \
        --project=${PROJECT_ID} \
        --zone=${ZONE}

    When the queued resource is in the ACTIVE state, the output will be similar to the following:

    state: ACTIVE
  3. Install PyTorch/XLA dependencies

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command='
      sudo apt-get update -y
      sudo apt-get install libomp5 -y
      pip3 install mkl mkl-include
      pip3 install tf-nightly tb-nightly tbp-nightly
      pip3 install numpy
      sudo apt-get install libopenblas-dev -y
      pip install torch==PYTORCH_VERSION torchvision torch_xla[tpu]==PYTORCH_VERSION -f https://storage.googleapis.com/libtpu-releases/index.html -f https://storage.googleapis.com/libtpu-wheels/index.html
      pip install jax==0.4.38 jaxlib==0.4.38 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/'

    Replace PYTORCH_VERSION with the version of PyTorch you want to use. PYTORCH_VERSION is used to specify the same version for PyTorch/XLA. Version 2.6.0 is recommended.

    For more information about versions of PyTorch and PyTorch/XLA, see PyTorch - Get Started and PyTorch/XLA releases.

    For more information on installing PyTorch/XLA, see PyTorch/XLA installation.

  4. Download the Hugging Face repository and install requirements.

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
        --project=${PROJECT_ID} \
        --zone=${ZONE} \
        --worker=all \
        --command="
      git clone https://github.com/suexu1025/transformers.git vittransformers; \
      cd vittransformers; \
      pip3 install .; \
      pip3 install datasets; \
      wget https://github.com/pytorch/xla/blob/master/scripts/capture_profile.py"

Train the model

gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --worker=all \
    --command='
      export PJRT_DEVICE=TPU
      export PT_XLA_DEBUG=0
      export USE_TORCH=ON
      export TF_CPP_MIN_LOG_LEVEL=0
      export XLA_USE_BF16=1
      export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH
      export TPU_LIBRARY_PATH=$HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so
      cd vittransformers
      python3 -u examples/pytorch/xla_spawn.py --num_cores 4 examples/pytorch/image-pretraining/run_mae.py --dataset_name=cifar10 \
      --remove_unused_columns=False \
      --label_names=pixel_values \
      --mask_ratio=0.75 \
      --norm_pix_loss=True \
      --do_train=true \
      --do_eval=true \
      --base_learning_rate=1.5e-4 \
      --lr_scheduler_type=cosine \
      --weight_decay=0.05 \
      --num_train_epochs=3 \
      --warmup_ratio=0.05 \
      --per_device_train_batch_size=8 \
      --per_device_eval_batch_size=8 \
      --logging_strategy=steps \
      --logging_steps=30 \
      --evaluation_strategy=epoch \
      --save_strategy=epoch \
      --load_best_model_at_end=True \
      --save_total_limit=3 \
      --seed=1337 \
      --output_dir=MAE \
      --overwrite_output_dir=true \
      --logging_dir=./tensorboard-metrics \
      --tpu_metrics_debug=true'

Delete the TPU and queued resource

Delete your TPU and queued resource at the end of your session.

gcloud compute tpus tpu-vm delete ${TPU_NAME} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --quiet

gcloud compute tpus queued-resources delete ${QUEUED_RESOURCE_ID} \
    --project=${PROJECT_ID} \
    --zone=${ZONE} \
    --quiet

Benchmark result

The following table shows the benchmark throughputs for different accelerator types.

Accelerator type             v5litepod-4    v5litepod-16    v5litepod-64
Epoch                        3              3               3
Global batch size            32             128             512
Throughput (examples/sec)    201            657             2,844
