Run a calculation on a Cloud TPU VM using JAX

Note: This page applies to the Cloud TPU API. For Ironwood (TPU7x), you must use Google Kubernetes Engine (GKE). For more information, see About TPUs in GKE.

This document provides a brief introduction to working with JAX and Cloud TPU.

Note: This example shows how to run code on a v5litepod-8 (v5e) TPU, which is a single-host TPU. Single-host TPUs have only one TPU VM. To run code on TPUs with more than one TPU VM (for example, v5litepod-16 or larger), see Run JAX code on Cloud TPU slices.

Before you begin

Before running the commands in this document, you must create a Google Cloud account, install the Google Cloud CLI, and configure the gcloud command. For more information, see Set up the Cloud TPU environment.

Create a Cloud TPU VM using gcloud

  1. Define some environment variables to make commands easier to use.

    export PROJECT_ID=your-project-id
    export TPU_NAME=your-tpu-name
    export ZONE=us-east5-a
    export ACCELERATOR_TYPE=v5litepod-8
    export RUNTIME_VERSION=v2-alpha-tpuv5-lite

    Environment variable descriptions

    Variable          Description
    PROJECT_ID        Your Google Cloud project ID. Use an existing project or create a new one.
    TPU_NAME          The name of the TPU.
    ZONE              The zone in which to create the TPU VM. For more information about supported zones, see TPU regions and zones.
    ACCELERATOR_TYPE  The accelerator type specifies the version and size of the Cloud TPU you want to create. For more information about supported accelerator types for each TPU version, see TPU versions.
    RUNTIME_VERSION   The Cloud TPU software version.

  2. Create your TPU VM by running the following command from a Cloud Shell or your computer terminal where the Google Cloud CLI is installed.

    $ gcloud compute tpus tpu-vm create $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE \
        --accelerator-type=$ACCELERATOR_TYPE \
        --version=$RUNTIME_VERSION

Connect to your Cloud TPU VM

Connect to your TPU VM over SSH by using the following command:

$ gcloud compute tpus tpu-vm ssh $TPU_NAME \
    --project=$PROJECT_ID \
    --zone=$ZONE

If you fail to connect to a TPU VM using SSH, it might be because the TPU VM doesn't have an external IP address. To access a TPU VM without an external IP address, follow the instructions in Connect to a TPU VM without a public IP address.

Install JAX on your Cloud TPU VM

(vm)$ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

System check

Verify that JAX can access the TPU and can run basic operations:

  1. Start the Python 3 interpreter:

    (vm)$ python3
    >>> import jax
  2. Display the number of TPU cores available:

    >>> jax.device_count()

The number of TPU cores is displayed. The number depends on the TPU version you are using. For more information, see TPU versions.

Perform a calculation

>>> jax.numpy.add(1, 1)

The result of the addition is displayed:

Array(2, dtype=int32, weak_type=True)

Exit the Python interpreter

>>> exit()

Run JAX code on a TPU VM

You can now run any JAX code you want. The Flax examples are a great place to start running standard ML models in JAX. For example, to train a basic MNIST convolutional network:

  1. Install Flax examples dependencies:

    (vm)$ pip install --upgrade clu
    (vm)$ pip install tensorflow
    (vm)$ pip install tensorflow_datasets
  2. Install Flax:

    (vm)$ git clone https://github.com/google/flax.git
    (vm)$ pip install --user flax
  3. Run the Flax MNIST training script:

    (vm)$ cd flax/examples/mnist
    (vm)$ python3 main.py --workdir=/tmp/mnist \
        --config=configs/default.py \
        --config.learning_rate=0.05 \
        --config.num_epochs=5

The script downloads the dataset and starts training. The script output should look like this:

I0214 18:00:50.660087 140369022753856 train.py:146] epoch: 1, train_loss: 0.2421, train_accuracy: 92.97, test_loss: 0.0615, test_accuracy: 97.88
I0214 18:00:52.015867 140369022753856 train.py:146] epoch: 2, train_loss: 0.0594, train_accuracy: 98.16, test_loss: 0.0412, test_accuracy: 98.72
I0214 18:00:53.377511 140369022753856 train.py:146] epoch: 3, train_loss: 0.0418, train_accuracy: 98.72, test_loss: 0.0296, test_accuracy: 99.04
I0214 18:00:54.727168 140369022753856 train.py:146] epoch: 4, train_loss: 0.0305, train_accuracy: 99.06, test_loss: 0.0257, test_accuracy: 99.15
I0214 18:00:56.082807 140369022753856 train.py:146] epoch: 5, train_loss: 0.0252, train_accuracy: 99.20, test_loss: 0.0263, test_accuracy: 99.18

Clean up

To avoid incurring charges to your Google Cloud account for the resources used on this page, follow these steps.

When you are done with your TPU VM, follow these steps to clean up your resources.

  1. Disconnect from the Cloud TPU instance, if you have not already done so:

    (vm)$ exit

    Your prompt should now be username@projectname, showing you are in the Cloud Shell.

  2. Delete your Cloud TPU:

    $ gcloud compute tpus tpu-vm delete $TPU_NAME \
        --project=$PROJECT_ID \
        --zone=$ZONE
  3. Verify the resources have been deleted by running the following command. Make sure your TPU is no longer listed. The deletion might take several minutes.

    $ gcloud compute tpus tpu-vm list \
        --zone=$ZONE

Performance notes

Here are a few important details that are particularly relevant to using TPUs in JAX.

Padding

One of the most common causes of slow performance on TPUs is introducing inadvertent padding:

  • Arrays in the Cloud TPU are tiled. This entails padding one of the dimensions to a multiple of 8, and a different dimension to a multiple of 128.
  • The matrix multiplication unit performs best with pairs of large matrices that minimize the need for padding.
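To get a feel for the tiling rule above, here is a minimal sketch in plain Python (no TPU or JAX required). It assumes, per the bullets above, that for a 2D array the second-to-last dimension pads to a multiple of 8 and the last dimension pads to a multiple of 128; the helper names are illustrative, not part of any API:

```python
# Sketch: estimate tile-padding overhead for a 2D array shape,
# assuming rows pad to a multiple of 8 and columns to a multiple
# of 128 (the tiling described above).

def round_up(n, multiple):
    """Round n up to the nearest multiple."""
    return -(-n // multiple) * multiple

def padding_overhead(rows, cols):
    """Ratio of padded (tiled) size to logical size for a 2D array."""
    padded_rows = round_up(rows, 8)
    padded_cols = round_up(cols, 128)
    return (padded_rows * padded_cols) / (rows * cols)

# A (3, 257) array pads to (8, 384): almost 4x the logical size.
print(padding_overhead(3, 257))
# A (256, 1024) array already matches the tile sizes: no padding.
print(padding_overhead(256, 1024))
```

Under this model, shapes whose trailing dimensions are already multiples of the tile sizes waste no memory, which is why large, round matrix shapes tend to perform best.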

bfloat16 dtype

By default, matrix multiplication in JAX on TPUs uses bfloat16 with float32 accumulation. This can be controlled with the precision argument on relevant jax.numpy function calls (matmul, dot, einsum, and so on). In particular:

  • precision=jax.lax.Precision.DEFAULT: uses mixed bfloat16 precision (fastest)
  • precision=jax.lax.Precision.HIGH: uses multiple MXU passes to achieve higher precision
  • precision=jax.lax.Precision.HIGHEST: uses even more MXU passes to achieve full float32 precision

JAX also adds the bfloat16 dtype, which you can use to explicitly cast arrays to bfloat16. For example, jax.numpy.array(x, dtype=jax.numpy.bfloat16).
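To see what casting to bfloat16 costs in accuracy, the conversion can be sketched in plain Python (no JAX required): bfloat16 is the top 16 bits of an IEEE float32, so it keeps float32's exponent range but only 7 explicit mantissa bits. This sketch truncates the low bits for simplicity; real conversions typically use round-to-nearest-even, so results can differ by one unit in the last place:

```python
import struct

def truncate_to_bfloat16(x):
    """Approximate bfloat16 by keeping only the top 16 bits of the
    float32 encoding of x (truncation; hardware usually rounds)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]  # float32 bit pattern
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

# float32 stores 1/3 as ~0.33333334; with only 7 mantissa bits the
# truncated bfloat16 value is noticeably coarser:
print(truncate_to_bfloat16(1 / 3))   # 0.33203125
# Values exactly representable in 7 mantissa bits pass through unchanged:
print(truncate_to_bfloat16(1.5))     # 1.5
```

This coarseness is why matmuls accumulate in float32 even when the inputs are bfloat16, and why the HIGH and HIGHEST precision settings above trade speed for extra MXU passes.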

What's next

For more information about Cloud TPU, see:

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.

Last updated 2025-11-24 UTC.