Train an LLM using JAX, Ray Train, and TPU Trillium on GKE

This tutorial shows you how to train the Llama 3 8B large language model (LLM) on Google Kubernetes Engine (GKE) using MaxText, Ray Train, and TPUs.

This tutorial provides a complete, end-to-end walkthrough, from configuring the necessary cloud infrastructure to submitting and successfully running the training workload on multi-host TPUs.

This tutorial is for Platform admins and operators and Data and AI specialists who want to learn how to train large models on a distributed, multi-host TPU slice.

Background

The combination of GKE, KubeRay, MaxText, and TPUs provides a powerful and scalable platform for large-scale model training. This section describes the key technologies used in this guide:

JAX

JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.

JAX provides an extensible system of composable function transformations, such as jax.grad, jax.jit, and jax.vmap, and uses the XLA compiler to create highly optimized code that scales efficiently on accelerators like GPUs and TPUs. The core power of JAX lies in its composability, which lets users combine these transformations to build complex, high-performance numerical programs for distributed execution.
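To make that composability concrete, the following minimal sketch (not part of the tutorial; it assumes a local `pip install jax` and uses a toy quadratic loss) composes grad, jit, and vmap:

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Toy quadratic loss over a weight vector.
    return jnp.sum((w * x) ** 2)

# Compose differentiation (grad) with XLA compilation (jit).
grad_loss = jax.jit(jax.grad(loss))

# vmap vectorizes the compiled gradient over a batch of inputs.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 1.0], [2.0, 2.0]])
print(batched_grad(w, xs))  # one gradient per batch element
```

The same composed function runs unchanged on CPU, GPU, or TPU, which is what lets MaxText scale a single training loop across a TPU slice.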

MaxText

MaxText is a high-performance, open-source large language model (LLM) implementation designed for scalability and customizability. MaxText is built on top of JAX and optimized to run efficiently on Cloud TPUs and GPUs.

TPUs

Tensor Processing Units (TPUs) are custom-designed accelerators created by Google to optimize machine learning workloads. Unlike general-purpose CPUs or parallel-processing GPUs, TPUs are highly specialized for the massive matrix and tensor computations at the foundation of deep learning, which makes them efficient at this specific task. The primary advantage of TPUs is performance at scale.
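For illustration only (this snippet is not part of the tutorial), the dense matrix work below is the kind of computation TPU matrix units accelerate; with JAX_PLATFORMS=tpu it runs on TPU hardware, and otherwise JAX falls back to whatever backend is available:

```python
import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_a, (128, 256))
b = jax.random.normal(key_b, (256, 64))

# jit compiles the matrix multiplication with XLA for the available backend.
matmul = jax.jit(lambda a, b: a @ b)
c = matmul(a, b)
print(c.shape)  # (128, 64)
```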

This tutorial uses TPU Trillium, which is the sixth generation of TPUs. For more information, see Benefits of using TPU Trillium.

KubeRay

KubeRay is a Kubernetes operator that provides a unified way to deploy, manage, and monitor Ray applications on Kubernetes. The KubeRay operator is installed and managed through the Ray on GKE add-on, which is the recommended way to deploy and manage Ray clusters on GKE.

Objectives

This tutorial shows you how to do the following:

  1. Set up a GKE cluster with a multi-host TPU node pool.
  2. Configure KubeRay to manage the distributed training environment.
  3. Build a custom Docker image that contains MaxText, Ray, and JAX dependencies.
  4. Create a Python training script that uses Ray Train's JaxTrainer to orchestrate the MaxText training loop across the TPU slice.
  5. Define a RayCluster custom resource to provision the head and worker nodes with the necessary TPU resources.
  6. Submit the training job to the RayCluster and monitor its progress.
  7. Use Cloud Storage to store model checkpoints.

Before you begin

  • Because this tutorial uses TPU Trillium (v6e), select a region or zone with availability. For more information, see Cloud TPU quotas.

Prepare your environment

In this tutorial, you use Cloud Shell. Cloud Shell comes preinstalled with the gcloud, helm, and kubectl command-line tools that are used in this tutorial.

  1. Go to the Google Cloud console.

  2. At the top of the Google Cloud console window, click the Activate Cloud Shell button.

    A Cloud Shell session opens inside a new frame in the Google Cloud console and displays a command-line prompt.

  3. Create and activate a Python virtual environment:

    python3 -m venv ray-env
    source ray-env/bin/activate
  4. Install the Ray CLI and other dependencies:

    pip install "ray[default]==2.49.1"
  5. Set the following environment variables:

    export PROJECT_ID=$(gcloud config get project)
    export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
    export GS_BUCKET=GS_BUCKET
    export KSA_NAME=KSA_NAME
    export NAMESPACE=default
    export CLUSTER_NAME=CLUSTER_NAME
    export REGION=REGION
    export ZONE=ZONE
    export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY

    Replace the following:

    • GS_BUCKET: the name of the Cloud Storage bucket.
    • KSA_NAME: the name of the Kubernetes ServiceAccount.
    • CLUSTER_NAME: the name of the new cluster.
    • REGION: the region where your TPU Trillium capacity is available.
    • ZONE: the zone where your TPU Trillium capacity is available. For more information, see TPU availability in GKE.
    • ARTIFACT_REGISTRY: the name of the Artifact Registry repository.

Create a GKE cluster

You can configure KubeRay on TPUs in a GKE Autopilot or Standard cluster. We recommend that you use an Autopilot cluster for a fully managed Kubernetes experience. To choose the GKE mode of operation that's the best fit for your workloads, see About GKE modes of operation.

Autopilot

  1. In Cloud Shell, run the following command:

    gcloud container clusters create-auto $CLUSTER_NAME \
        --enable-ray-operator \
        --location=$REGION
  2. To communicate with your cluster, configure kubectl:

    gcloud container clusters get-credentials $CLUSTER_NAME \
        --location=$REGION

Standard

  1. In Cloud Shell, create a Standard cluster that enables the Ray operator add-on by running the following command:

    gcloud container clusters create $CLUSTER_NAME \
        --addons=RayOperator,GcsFuseCsiDriver \
        --machine-type=n1-standard-16 \
        --workload-pool=$PROJECT_ID.svc.id.goog \
        --location=$ZONE

    This command also enables the GcsFuseCsiDriver add-on, which allows Pods to mount Cloud Storage buckets as local file systems. The cluster creation might take several minutes.

  2. To communicate with your cluster, configure kubectl:

    gcloud container clusters get-credentials $CLUSTER_NAME \
        --location=$ZONE
  3. Create a multi-host TPU slice node pool:

    gcloud container node-pools create v6e-16 \
        --location=$ZONE \
        --cluster=$CLUSTER_NAME \
        --machine-type=ct6e-standard-4t \
        --threads-per-core=1 \
        --tpu-topology=4x4 \
        --num-nodes=4

GKE provisions a node pool consisting of four TPU Trillium (v6e) VMs, which are configured together as a multi-host TPU slice with a 4x4 topology, ready for distributed training workloads.
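As a quick sanity check on those numbers: the topology string determines the chip count, and dividing by the chips per host (four for ct6e-standard-4t machines) gives the VM count. A small, hypothetical helper (not part of the tutorial) makes the arithmetic explicit:

```python
from math import prod

def slice_shape(topology: str, chips_per_host: int = 4):
    """Derive chip and host counts from a TPU topology string like '4x4'."""
    chips = prod(int(d) for d in topology.split("x"))
    return chips, chips // chips_per_host

chips, hosts = slice_shape("4x4")
print(chips, hosts)  # 16 chips across 4 hosts
```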

The Ray operator-enabled GKE cluster automatically installs KubeRay and the KubeRay TPU webhook in your cluster.

Configure a Cloud Storage bucket and a service account

  1. Create a Cloud Storage bucket for shared checkpoints between the multi-host TPU nodes:

    gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
  2. To enable access to the Cloud Storage bucket, create a Kubernetes ServiceAccount:

    kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
  3. To enable access to the Cloud Storage bucket, add the required IAM policy binding to the service account:

    gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
        --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
        --role "roles/storage.objectUser"
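The --member string follows the Workload Identity Federation for GKE principal format. A hypothetical helper (not part of the tutorial) that assembles it can be handy when scripting bindings for several ServiceAccounts:

```python
def wif_principal(project_number: str, project_id: str, namespace: str, ksa: str) -> str:
    """Build the Workload Identity Federation principal for a Kubernetes ServiceAccount."""
    return (
        f"principal://iam.googleapis.com/projects/{project_number}"
        f"/locations/global/workloadIdentityPools/{project_id}.svc.id.goog"
        f"/subject/ns/{namespace}/sa/{ksa}"
    )

print(wif_principal("123456789012", "my-project", "default", "my-ksa"))
```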

Create the training script

The following script uses Ray Train's JaxTrainer to run a distributed MaxText training job. The script configures the training environment for a multi-host TPU slice node pool and runs the MaxText training job on each worker node. The train_loop_per_worker function wraps the MaxText main entry point and uses Ray's distributed scheduler to execute the MaxText trainer on a multi-host TPU slice.

  1. Save the following Python script as maxtext_ray_trainer.py:

    import os
    from absl import app
    import logging
    from typing import Sequence

    import ray
    from ray.train.v2.api.config import ScalingConfig, RunConfig
    from ray.train.v2.jax import JaxTrainer


    def train_loop_per_worker(config):
        from MaxText.train import main as maxtext_main

        argv = config["argv"]
        maxtext_main(argv)


    def main(argv: Sequence[str]):
        trainer = JaxTrainer(
            train_loop_per_worker=train_loop_per_worker,
            train_loop_config={"argv": argv},
            scaling_config=ScalingConfig(
                use_tpu=True,
                num_workers=4,
                topology="4x4",
                accelerator_type="TPU-V6E",
                resources_per_worker={"TPU": 4},
                placement_strategy="SPREAD",
            ),
            run_config=RunConfig(
                name="maxtext_jaxtrainer",
                worker_runtime_env={
                    "env_vars": {
                        "JAX_PLATFORMS": "tpu",
                        "ENABLE_PJRT_COMPATIBILITY": "true",
                        "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                        "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                        "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                    }
                },
            ),
        )
        result = trainer.fit()
        logging.info("Training complete!")
        ray.shutdown()


    if __name__ == "__main__":
        app.run(main)
  2. To host the custom image, create an Artifact Registry repository:

    gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
        --repository-format=docker --location=${REGION} && \
    gcloud auth configure-docker ${REGION}-docker.pkg.dev
  3. To build an image that includes Ray and MaxText dependencies for training, create a Dockerfile:

    # Start from a Ray base image which includes the JaxTrainer API.
    # MaxText with TPU requires Python 3.12.
    FROM rayproject/ray:2.49.1-py312

    USER root
    RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray
    RUN sudo apt-get update -y \
      && sudo apt-get install --no-install-recommends -y git \
      && sudo rm -rf /var/lib/apt/lists/*

    WORKDIR /app

    # Clone the MaxText repo and build from source, installing TPU dependencies.
    RUN git clone https://github.com/AI-Hypercomputer/maxtext.git
    RUN pip install --no-cache-dir uv
    RUN cd maxtext && \
        uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
        install_maxtext_github_deps

    # Copy the Ray MaxText trainer to run on the remote container.
    COPY maxtext_ray_trainer.py .
    RUN chown -R ray:ray .

    ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
    USER ray
    Note: This tutorial uses a custom image because the dependencies are too large to install by using the Ray runtime_env feature.
  4. Build, tag, and push the Docker image to Artifact Registry:

    export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
    gcloud builds submit --tag ${DOCKER_IMAGE}

Train the model

  1. Save the following sample manifest as maxtext-tpu-cluster.yaml:

    apiVersion: ray.io/v1
    kind: RayCluster
    metadata:
      name: maxtext-tpu-cluster
    spec:
      headGroupSpec:
        rayStartParams: {}
        template:
          metadata:
            annotations:
              gke-gcsfuse/volumes: "true"
              gke-gcsfuse/cpu-limit: "0"
              gke-gcsfuse/memory-limit: "0"
              gke-gcsfuse/ephemeral-storage-limit: "0"
          spec:
            serviceAccountName: ${KSA_NAME}
            containers:
              - name: ray-head
                image: ${DOCKER_IMAGE}
                imagePullPolicy: IfNotPresent
                ports:
                  - containerPort: 6379
                    name: gcs-server
                  - containerPort: 8265
                    name: dashboard
                  - containerPort: 10001
                    name: client
                resources:
                  limits:
                    memory: "16Gi"
                  requests:
                    cpu: "8"
                    memory: "16Gi"
                volumeMounts:
                  - name: gcs-fuse-csi-ephemeral
                    mountPath: /data
                  - name: dshm
                    mountPath: /dev/shm
            volumes:
              - name: gcs-fuse-cache
                emptyDir:
                  medium: Memory
              - name: dshm
                emptyDir:
                  medium: Memory
              - name: gcs-fuse-csi-ephemeral
                csi:
                  driver: gcsfuse.csi.storage.gke.io
                  volumeAttributes:
                    bucketName: ${GS_BUCKET}
                    mountOptions: "implicit-dirs"
      workerGroupSpecs:
        - replicas: 1
          numOfHosts: 4
          groupName: tpu-group
          rayStartParams: {}
          template:
            metadata:
              annotations:
                gke-gcsfuse/volumes: "true"
                gke-gcsfuse/cpu-limit: "0"
                gke-gcsfuse/memory-limit: "0"
                gke-gcsfuse/ephemeral-storage-limit: "0"
            spec:
              serviceAccountName: ${KSA_NAME}
              containers:
                - name: ray-worker
                  image: ${DOCKER_IMAGE}
                  imagePullPolicy: IfNotPresent
                  resources:
                    limits:
                      memory: 200G
                      google.com/tpu: "4"
                    requests:
                      cpu: "8"
                      memory: 200G
                      google.com/tpu: "4"
                  env:
                    - name: JAX_PLATFORMS
                      value: tpu
                    - name: ENABLE_PJRT_COMPATIBILITY
                      value: "true"
                  volumeMounts:
                    - name: gcs-fuse-csi-ephemeral
                      mountPath: /data
                    - name: dshm
                      mountPath: /dev/shm
              volumes:
                - name: gcs-fuse-cache
                  emptyDir:
                    medium: Memory
                - name: dshm
                  emptyDir:
                    medium: Memory
                - name: gcs-fuse-csi-ephemeral
                  csi:
                    driver: gcsfuse.csi.storage.gke.io
                    volumeAttributes:
                      bucketName: ${GS_BUCKET}
                      mountOptions: "implicit-dirs"
              nodeSelector:
                cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
                cloud.google.com/gke-tpu-topology: 4x4

    The preceding RayCluster spec creates a TPU worker group with four workers (numOfHosts: 4) per replica. Each worker requests four TPU chips (google.com/tpu: "4"). The workers are scheduled on nodes that run TPU Trillium (tpu-v6e-slice) and that are part of the same colocated multi-host slice. KubeRay scales all four workers atomically, and the required JAX environment variables, as well as Pod affinities for scheduling, are bootstrapped by GKE through a mutating webhook.

  2. To configure the required values in the YAML file, create the RayCluster by using envsubst:

    envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
  3. Verify that the cluster is ready and running:

    kubectl get rayclusters maxtext-tpu-cluster

    The output should be similar to the following:

    NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
    maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
  4. To access the Ray Dashboard through the Ray head service, establish a port-forwarding session:

    kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
  5. Verify that the RayCluster is reachable from your local environment:

    ray list nodes --address http://localhost:8265

    The output should be similar to the following:

    ======== List: 2025-09-13 03:53:16.988269 ========
    Stats:
    ------------------------------
    Total: 5

    Table:
    ------------------------------
        NODE_ID                                                   NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL                  LABELS
     0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56  10.84.0.9
    (...)
  6. Submit the JaxTrainer script to the RayCluster and check that the Ray job completes successfully:

    ray job submit \
        --address http://localhost:8265 \
        -- \
        python /app/maxtext_ray_trainer.py \
        /app/maxtext/src/MaxText/configs/base.yml \
        base_output_directory=/data/ \
        dataset_type=synthetic \
        per_device_batch_size=1 \
        max_target_length=4096 \
        model_name=llama3-8b \
        steps=100 \
        ici_fsdp_parallelism=4 \
        ici_tensor_parallelism=4 \
        run_name=rayjob-8b-4096-tp4-4x4

    The preceding command submits the Python script, which calls the JaxTrainer Ray code, to the RayCluster. The ray job submit command includes some MaxText-specific arguments to pass to the model configuration.
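    Note that the ICI parallelism arguments must multiply out to the number of chips in the slice (here, 4 x 4 = 16 for the 4x4 v6e topology). MaxText validates this internally, but the check can be sketched with a small, hypothetical helper (not part of MaxText):

```python
def check_ici_mesh(fsdp: int, tensor: int, num_chips: int) -> bool:
    """The ICI mesh axes must factor the total chip count exactly."""
    return fsdp * tensor == num_chips

# The tutorial's run: 4-way FSDP x 4-way tensor parallelism on 16 v6e chips.
print(check_ici_mesh(4, 4, 16))  # True
```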

    Note: It might take a couple of minutes for the script to complete.

    In your terminal, you should see output similar to the following:

    (RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
    ------------------------------------------
    Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
    ------------------------------------------
Success: You've completed training by using Ray Train's JaxTrainer and MaxText on distributed multi-host TPU node pools.

Clean up

To avoid incurring charges to your Google Cloud account for the resources used in this tutorial, either delete the project that contains the resources, or keep the project and delete the individual resources.

  1. Delete the RayCluster:

    kubectl delete raycluster maxtext-tpu-cluster
  2. Delete the GKE cluster:

    gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
  3. Delete the Cloud Storage bucket:

    gsutil rm -r gs://${GS_BUCKET}
  4. Delete the Artifact Registry repository:

    gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet

What's next


Last updated 2025-11-24 UTC.