Run a calculation on a Cloud TPU VM using JAX
This document provides a brief introduction to working with JAX and Cloud TPU.
Note: This example shows how to run code on a v5litepod-8 (v5e) TPU, which is a single-host TPU. Single-host TPUs have only one TPU VM. To run code on TPUs with more than one TPU VM (for example, v5litepod-16 or larger), see Run JAX code on Cloud TPU slices.

Before you begin
Before running the commands in this document, you must create a Google Cloud account, install the Google Cloud CLI, and configure the gcloud command. For more information, see Set up the Cloud TPU environment.
Create a Cloud TPU VM using gcloud
Define some environment variables to make commands easier to use.
export PROJECT_ID=your-project-id
export TPU_NAME=your-tpu-name
export ZONE=us-east5-a
export ACCELERATOR_TYPE=v5litepod-8
export RUNTIME_VERSION=v2-alpha-tpuv5-lite
Environment variable descriptions
PROJECT_ID: Your Google Cloud project ID. Use an existing project or create a new one.
TPU_NAME: The name of the TPU.
ZONE: The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
ACCELERATOR_TYPE: The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
RUNTIME_VERSION: The Cloud TPU software version.

Create your TPU VM by running the following command from a Cloud Shell or your computer terminal where the Google Cloud CLI is installed.
$ gcloud compute tpus tpu-vm create $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE \
    --accelerator-type=$ACCELERATOR_TYPE \
    --version=$RUNTIME_VERSION
Connect to your Cloud TPU VM
Connect to your TPU VM over SSH by using the following command:
$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE
If you fail to connect to a TPU VM using SSH, it might be because the TPU VM doesn't have an external IP address. To access a TPU VM without an external IP address, follow the instructions in Connect to a TPU VM without a public IP address.
Install JAX on your Cloud TPU VM
(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
System check
Verify that JAX can access the TPU and can run basic operations:
Start the Python 3 interpreter:
(vm)$ python3
>>> import jax

Display the number of TPU cores available:
>>> jax.device_count()
The number of TPU cores is displayed. The number of cores shown depends on the TPU version you are using. For more information, see TPU versions.
Perform a calculation
>>> jax.numpy.add(1, 1)
The result of the addition is displayed:

Array(2, dtype=int32, weak_type=True)
Exit the Python interpreter:
>>> exit()
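The interactive checks above can also be collected into a small script and run on the VM in one step. The filename below is illustrative, not part of the official instructions:

```python
# check_tpu.py - sanity-check that JAX can see accelerator devices and
# run a basic computation. On a v5litepod-8 this reports 8 devices;
# on a machine without a TPU, JAX falls back to CPU and reports those
# devices instead.
import jax
import jax.numpy as jnp

num_devices = jax.device_count()
print(f"JAX sees {num_devices} device(s): {jax.devices()}")

result = jnp.add(1, 1)
print(f"1 + 1 = {result}")
```

Run it with `(vm)$ python3 check_tpu.py` after installing JAX as shown above.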
Running JAX code on a TPU VM
You can now run any JAX code you want. The Flax examples are a great place to start running standard ML models in JAX. For example, to train a basic MNIST convolutional network:
Install Flax examples dependencies:
(vm)$ pip install --upgrade clu
(vm)$ pip install tensorflow
(vm)$ pip install tensorflow_datasets
Install Flax:
(vm)$ git clone https://github.com/google/flax.git
(vm)$ pip install --user flax
Run the Flax MNIST training script:
(vm)$ cd flax/examples/mnist
(vm)$ python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=5
The script downloads the dataset and starts training. The script output should look like this:
I0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18
Clean up
To avoid incurring charges to your Google Cloud account for the resources used on this page, follow these steps to clean them up when you are done with your TPU VM.
Disconnect from the Cloud TPU instance, if you have not already done so:
(vm)$ exit
Your prompt should now be username@projectname, showing that you are in the Cloud Shell.
Delete your Cloud TPU:
$ gcloud compute tpus tpu-vm delete $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE
Verify that the resources have been deleted by running the following command and making sure your TPU is no longer listed. The deletion might take several minutes.
$ gcloud compute tpus tpu-vm list \
    --zone=$ZONE
Performance notes
Here are a few important details that are particularly relevant to using TPUs in JAX.
Padding
One of the most common causes of slow performance on TPUs is inadvertent padding:

- Arrays on the Cloud TPU are tiled. This entails padding one of the dimensions to a multiple of 8, and a different dimension to a multiple of 128.
- The matrix multiplication unit performs best with pairs of large matrices that minimize the need for padding.
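As a rough illustration of the tiling rule above, the following sketch computes the padded shape of a 2-D array, assuming the last dimension is padded to a multiple of 128 and the second-to-last to a multiple of 8. The helper functions are hypothetical (not part of JAX), and the exact layout is hardware- and compiler-dependent:

```python
# Hypothetical helpers illustrating TPU tiling overhead for 2-D arrays.
# Assumption: last dim pads to a multiple of 128, second-to-last to a
# multiple of 8. Real layouts may differ by hardware and compiler.
def padded_shape(rows, cols):
    pad_rows = -(-rows // 8) * 8      # round rows up to a multiple of 8
    pad_cols = -(-cols // 128) * 128  # round cols up to a multiple of 128
    return pad_rows, pad_cols

def padding_waste(rows, cols):
    """Fraction of the padded tile that holds no real data."""
    pr, pc = padded_shape(rows, cols)
    return 1.0 - (rows * cols) / (pr * pc)

# A tiny (3, 5) array pads to (8, 128): almost the entire tile is padding.
print(padded_shape(3, 5), f"{padding_waste(3, 5):.0%} wasted")
# A (256, 512) array already matches the tile multiples: no padding.
print(padded_shape(256, 512), f"{padding_waste(256, 512):.0%} wasted")
```

This is why shapes whose trailing dimensions are multiples of 128 (and 8) tend to use the hardware most efficiently.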
bfloat16 dtype
By default, matrix multiplication in JAX on TPUs uses bfloat16 with float32 accumulation. This can be controlled with the precision argument on relevant jax.numpy function calls (matmul, dot, einsum, and so on). In particular:
- precision=jax.lax.Precision.DEFAULT: uses mixed bfloat16 precision (fastest)
- precision=jax.lax.Precision.HIGH: uses multiple MXU passes to achieve higher precision
- precision=jax.lax.Precision.HIGHEST: uses even more MXU passes to achieve full float32 precision
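A minimal sketch of passing the precision argument to a matmul. On a TPU, DEFAULT and HIGHEST can produce slightly different results; on CPU both paths use float32 math, so the difference only shows up on the accelerator:

```python
import jax
import jax.numpy as jnp

# Two random float32 matrices (shapes chosen to match TPU tile multiples).
a = jax.random.normal(jax.random.PRNGKey(0), (128, 128), dtype=jnp.float32)
b = jax.random.normal(jax.random.PRNGKey(1), (128, 128), dtype=jnp.float32)

# Fastest path: bfloat16 inputs with float32 accumulation (TPU default).
fast = jnp.matmul(a, b, precision=jax.lax.Precision.DEFAULT)

# Full float32 precision via extra MXU passes (slower on TPU).
exact = jnp.matmul(a, b, precision=jax.lax.Precision.HIGHEST)

# Largest element-wise difference between the two precision modes.
print(float(jnp.max(jnp.abs(fast - exact))))
```

The same keyword is accepted by jnp.dot and jnp.einsum, so precision can be tuned per call rather than globally.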
JAX also adds the bfloat16 dtype, which you can use to explicitly cast arrays to bfloat16. For example, jax.numpy.array(x, dtype=jax.numpy.bfloat16).
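A short sketch of the explicit cast described above:

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.5, 3.25])        # float32 by default
bx = jnp.array(x, dtype=jnp.bfloat16)  # explicit cast to bfloat16

print(bx.dtype)  # bfloat16
# bfloat16 keeps float32's exponent range but has a much shorter
# mantissa, so many values round; 1.0, 2.5, and 3.25 happen to be
# exactly representable.
print(bx)
```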
What's next
For more information about Cloud TPU, see:
Last updated 2025-11-24 UTC.