Train an LLM using JAX, Ray Train, and TPU Trillium on GKE
This tutorial shows you how to train the Llama 3 8B large language model (LLM) on Google Kubernetes Engine (GKE) using MaxText, Ray Train, and TPUs.
This tutorial provides a complete, end-to-end walkthrough, from configuring the necessary cloud infrastructure to submitting and successfully running the training workload on multi-host TPUs.
This tutorial is for Platform admins and operators and Data and AI specialists who want to learn how to train large models on a distributed, multi-host TPU slice.
Background
The combination of GKE, KubeRay, MaxText, and TPUs provides a powerful and scalable platform for large-scale model training. This section describes the key technologies used in this guide:
JAX
JAX is a Python library for accelerator-oriented array computation and program transformation, designed for high-performance numerical computing and large-scale machine learning.
JAX provides an extensible system of function transformations, such as `jax.grad`, `jax.jit`, and `jax.vmap`, and uses the XLA compiler to produce highly optimized code that scales efficiently on accelerators like GPUs and TPUs. The core power of JAX lies in its composability, which lets you combine these transformations to build complex, high-performance numerical programs for distributed execution.
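This composability can be seen in a tiny sketch (not part of this tutorial's workload; it runs on CPU as well as TPU). The `loss` function here is a hypothetical stand-in, not anything from MaxText:

```python
import jax
import jax.numpy as jnp

def loss(w):
    # Toy scalar function standing in for a model loss.
    return jnp.sum(w ** 2)

# Compose transformations: differentiate the function, then
# JIT-compile the gradient through XLA.
grad_fn = jax.jit(jax.grad(loss))
g = float(grad_fn(jnp.array(3.0)))  # d/dw of w^2 at w=3 is 6.0

# vmap vectorizes the same gradient over a batch without a rewrite.
gs = [float(x) for x in jax.vmap(jax.grad(lambda w: w ** 2))(jnp.array([1.0, 2.0, 3.0]))]
print(g, gs)
```

The same pattern, scaled up, is what MaxText relies on to shard the Llama 3 training step across a TPU slice.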
MaxText
MaxText is a high-performance, open-source large language model (LLM) implementation designed for scalability and customizability. MaxText is built on top of JAX and optimized to run efficiently on Cloud TPUs and GPUs.
TPUs
Tensor Processing Units (TPUs) are custom-designed accelerators created by Google to optimize machine learning workloads. Unlike general-purpose CPUs or parallel-processing GPUs, TPUs are highly specialized for the massive matrix and tensor computations at the foundation of deep learning, which makes them efficient at this specific task. The primary advantage of TPUs is performance at scale.
This tutorial uses TPU Trillium, the sixth generation of TPUs. For more information, see Benefits of using TPU Trillium.
KubeRay
KubeRay is a Kubernetes operator that provides a unified way to deploy, manage, and monitor Ray applications on Kubernetes. The KubeRay operator is installed and managed through the Ray on GKE add-on, which is the recommended way to deploy and manage Ray clusters on GKE.
Objectives
This tutorial shows you how to do the following:
- Set up a GKE cluster with a multi-host TPU node pool.
- Configure KubeRay to manage the distributed training environment.
- Build a custom Docker image that contains MaxText, Ray, and JAX dependencies.
- Create a Python training script that uses Ray Train's `JaxTrainer` to orchestrate the MaxText training loop across the TPU slice.
- Define a `RayCluster` custom resource to provision the head and worker nodes with the necessary TPU resources.
- Submit the training job to the `RayCluster` and monitor its progress.
- Use Cloud Storage to store model checkpoints.
Before you begin
- Sign in to your Google Cloud account. If you're new to Google Cloud, create an account to evaluate how our products perform in real-world scenarios. New customers also get $300 in free credits to run, test, and deploy workloads.
Install the Google Cloud CLI.
Note: If you installed the gcloud CLI previously, make sure you have the latest version by running `gcloud components update`. If you're using an external identity provider (IdP), you must first sign in to the gcloud CLI with your federated identity.
To initialize the gcloud CLI, run the following command:
```
gcloud init
```
Create or select a Google Cloud project.
Note: If you don't plan to keep the resources that you create in this procedure, create a project instead of selecting an existing project. After you finish these steps, you can delete the project, removing all resources associated with the project.
Roles required to select or create a project:
- Select a project: Selecting a project doesn't require a specific IAM role. You can select any project that you've been granted a role on.
- Create a project: To create a project, you need the Project Creator role (`roles/resourcemanager.projectCreator`), which contains the `resourcemanager.projects.create` permission. Learn how to grant roles.
Create a Google Cloud project:
```
gcloud projects create PROJECT_ID
```
Replace `PROJECT_ID` with a name for the Google Cloud project you are creating.
Select the Google Cloud project that you created:
```
gcloud config set project PROJECT_ID
```
Replace `PROJECT_ID` with your Google Cloud project name.
Verify that billing is enabled for your Google Cloud project.
Enable the required API:
Roles required to enable APIs: you need the Service Usage Admin IAM role (`roles/serviceusage.serviceUsageAdmin`), which contains the `serviceusage.services.enable` permission. Learn how to grant roles.
```
gcloud services enable container.googleapis.com
```
Grant roles to your user account. Run the following command once for each of the following IAM roles: `roles/container.admin`, `roles/iam.serviceAccountAdmin`:
```
gcloud projects add-iam-policy-binding PROJECT_ID --member="user:USER_IDENTIFIER" --role=ROLE
```
Replace the following:
- `PROJECT_ID`: Your project ID.
- `USER_IDENTIFIER`: The identifier for your user account. For example, `myemail@example.com`.
- `ROLE`: The IAM role that you grant to your user account.
- Because this tutorial uses TPU Trillium (v6e), select a region or zone with availability. For more information, see Cloud TPU quotas.
Prepare your environment
In this tutorial, you use Cloud Shell. Cloud Shell comes preinstalled with the `gcloud`, `helm`, and `kubectl` command-line tools that are used in this tutorial.
Go to the Google Cloud console.
At the top of the Google Cloud console window, click the Activate Cloud Shell button. A Cloud Shell session opens inside a new frame in the Google Cloud console and displays a command-line prompt.
Create and activate a Python virtual environment:
```
python3 -m venv ray-env
source ray-env/bin/activate
```
Install the Ray CLI and other dependencies:
```
pip install "ray[default]==2.49.1"
```
Set the following environment variables:
```
export PROJECT_ID=$(gcloud config get project)
export PROJECT_NUMBER=$(gcloud projects describe ${PROJECT_ID} --format="value(projectNumber)")
export GS_BUCKET=GS_BUCKET
export KSA_NAME=KSA_NAME
export NAMESPACE=default
export CLUSTER_NAME=CLUSTER_NAME
export REGION=REGION
export ZONE=ZONE
export ARTIFACT_REGISTRY=ARTIFACT_REGISTRY
```
Replace the following:
- `GS_BUCKET`: the name of the Cloud Storage bucket.
- `KSA_NAME`: the name of the Kubernetes ServiceAccount.
- `CLUSTER_NAME`: the name of the new cluster.
- `REGION`: the region where your TPU Trillium capacity is available.
- `ZONE`: the zone where your TPU Trillium capacity is available. For more information, see TPU availability in GKE.
- `ARTIFACT_REGISTRY`: the name of the Artifact Registry repository.
Create a GKE cluster
You can configure KubeRay on TPUs in a GKE Autopilot or Standard cluster. We recommend that you use an Autopilot cluster for a fully managed Kubernetes experience. To choose the GKE mode of operation that's the best fit for your workloads, see About GKE modes of operation.
Autopilot
In Cloud Shell, run the following command:
```
gcloud container clusters create-auto $CLUSTER_NAME \
    --enable-ray-operator \
    --machine-type=n1-standard-16 \
    --location=$REGION
```
To communicate with your cluster, configure `kubectl`:
```
gcloud container clusters get-credentials $CLUSTER_NAME \
    --location=$REGION
```
Standard
In Cloud Shell, create a Standard cluster that enables the Ray operator add-on by running the following command:
```
gcloud container clusters create $CLUSTER_NAME \
    --addons=RayOperator,GcsFuseCsiDriver \
    --machine-type=n1-standard-16 \
    --workload-pool=$PROJECT_ID.svc.id.goog \
    --location=$ZONE
```
This command also enables the `GcsFuseCsiDriver` add-on, which allows Pods to mount Cloud Storage buckets as local file systems. The cluster creation might take several minutes.
To communicate with your cluster, configure `kubectl`:
```
gcloud container clusters get-credentials $CLUSTER_NAME \
    --location=$ZONE
```
Create a multi-host TPU slice node pool:
```
gcloud container node-pools create v6e-16 \
    --location=$ZONE \
    --cluster=$CLUSTER_NAME \
    --machine-type=ct6e-standard-4t \
    --threads-per-core=1 \
    --tpu-topology=4x4 \
    --num-nodes=4
```
GKE provisions a node pool consisting of four TPU Trillium (v6e) VMs, which are configured together as a multi-host TPU slice with a 4x4 topology, ready for distributed training workloads.
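The node-pool sizing follows directly from the slice geometry. As a quick sketch in plain Python (the chips-per-VM figure comes from the ct6e-standard-4t machine type used in this tutorial):

```python
# A 4x4 TPU Trillium topology contains 4 * 4 = 16 chips in one slice.
topology = (4, 4)
total_chips = topology[0] * topology[1]

# Each ct6e-standard-4t VM exposes 4 TPU chips, so the slice spans
# total_chips / chips_per_vm hosts -- hence --num-nodes=4 above.
chips_per_vm = 4
num_nodes = total_chips // chips_per_vm

print(total_chips, num_nodes)  # 16 4
```

The same arithmetic recurs later in the `RayCluster` manifest, where the worker group uses four hosts with four chips each.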
The Ray operator-enabled GKE cluster automatically installs KubeRay and the KubeRay TPU webhook in your cluster.
Configure a Cloud Storage bucket and a service account
Create a Cloud Storage bucket for shared checkpoints between the multi-host TPU nodes:
```
gsutil mb -p ${PROJECT_ID} -c STANDARD -l ${REGION} gs://${GS_BUCKET}
```
To enable access to the Cloud Storage bucket, create a Kubernetes ServiceAccount:
```
kubectl create serviceaccount ${KSA_NAME} --namespace ${NAMESPACE}
```
To enable access to the Cloud Storage bucket, add the required IAM policy binding to the service account:
```
gcloud storage buckets add-iam-policy-binding gs://${GS_BUCKET} \
    --member "principal://iam.googleapis.com/projects/${PROJECT_NUMBER}/locations/global/workloadIdentityPools/${PROJECT_ID}.svc.id.goog/subject/ns/${NAMESPACE}/sa/${KSA_NAME}" \
    --role "roles/storage.objectUser"
```
Create the training script
The following script uses Ray Train's `JaxTrainer` to run a distributed MaxText training job. The script configures the training environment for a multi-host TPU slice node pool and runs the MaxText training job on each worker node. The `train_loop_per_worker` function wraps the MaxText main entry point and uses Ray's distributed scheduler to execute the MaxText trainer on a multi-host TPU slice.
Save the following Python script as `maxtext_ray_trainer.py`:
```python
import os
import logging
from typing import Sequence

from absl import app
import ray
from ray.train.v2.api.config import ScalingConfig, RunConfig
from ray.train.v2.jax import JaxTrainer


def train_loop_per_worker(config):
    from MaxText.train import main as maxtext_main

    argv = config["argv"]
    maxtext_main(argv)


def main(argv: Sequence[str]):
    trainer = JaxTrainer(
        train_loop_per_worker=train_loop_per_worker,
        train_loop_config={"argv": argv},
        scaling_config=ScalingConfig(
            use_tpu=True,
            num_workers=4,
            topology="4x4",
            accelerator_type="TPU-V6E",
            resources_per_worker={"TPU": 4},
            placement_strategy="SPREAD",
        ),
        run_config=RunConfig(
            name="maxtext_jaxtrainer",
            worker_runtime_env={
                "env_vars": {
                    "JAX_PLATFORMS": "tpu",
                    "ENABLE_PJRT_COMPATIBILITY": "true",
                    "TPU_SLICE_BUILDER_DUMP_CHIP_FORCE": "true",
                    "TPU_SLICE_BUILDER_DUMP_ICI": "true",
                    "XLA_FLAGS": "--xla_dump_to=/tmp/xla_dump_file --xla_dump_hlo_as_proto",
                }
            },
        ),
    )
    result = trainer.fit()
    logging.info("Training complete!")
    ray.shutdown()


if __name__ == "__main__":
    app.run(main)
```
To host the custom image, create an Artifact Registry repository:
```
gcloud artifacts repositories create ${ARTIFACT_REGISTRY} \
    --repository-format=docker --location=${REGION} && \
gcloud auth configure-docker ${REGION}-docker.pkg.dev
```
To build an image that includes Ray and MaxText dependencies for training, create a `Dockerfile`:
Note: This tutorial uses a custom image because the dependencies are too large to install by using the Ray `runtime_env` mechanism.
```dockerfile
# Start from a Ray base image which includes the JaxTrainer API.
# MaxText with TPU requires Python 3.12.
FROM rayproject/ray:2.49.1-py312

USER root
RUN groupadd -r ray 2>/dev/null || true && usermod -g ray ray
RUN sudo apt-get update -y \
    && sudo apt-get install --no-install-recommends -y git \
    && sudo rm -rf /var/lib/apt/lists/*

WORKDIR /app

# Clone the MaxText repo and build from source, installing TPU dependencies.
RUN git clone https://github.com/AI-Hypercomputer/maxtext.git
RUN pip install --no-cache-dir uv
RUN cd maxtext && \
    uv pip install --no-cache --system -e .[tpu] --resolution=lowest && \
    install_maxtext_github_deps

# Copy the Ray MaxText trainer to run on the remote container.
COPY maxtext_ray_trainer.py .
RUN chown -R ray:ray .
ENV PYTHONPATH=/app/maxtext/src:/app/maxtext:/app
USER ray
```
Build, tag, and push the Docker image to Artifact Registry:
```
export DOCKER_IMAGE=${REGION}-docker.pkg.dev/${PROJECT_ID}/${ARTIFACT_REGISTRY}/ray-maxtext:latest
gcloud builds submit --tag ${DOCKER_IMAGE}
```
Train the model
Save the following sample manifest as `maxtext-tpu-cluster.yaml`:
```yaml
apiVersion: ray.io/v1
kind: RayCluster
metadata:
  name: maxtext-tpu-cluster
spec:
  headGroupSpec:
    rayStartParams: {}
    template:
      metadata:
        annotations:
          gke-gcsfuse/volumes: "true"
          gke-gcsfuse/cpu-limit: "0"
          gke-gcsfuse/memory-limit: "0"
          gke-gcsfuse/ephemeral-storage-limit: "0"
      spec:
        serviceAccountName: ${KSA_NAME}
        containers:
        - name: ray-head
          image: ${DOCKER_IMAGE}
          imagePullPolicy: IfNotPresent
          ports:
          - containerPort: 6379
            name: gcs-server
          - containerPort: 8265
            name: dashboard
          - containerPort: 10001
            name: client
          resources:
            limits:
              memory: "16Gi"
            requests:
              cpu: "8"
              memory: "16Gi"
          volumeMounts:
          - name: gcs-fuse-csi-ephemeral
            mountPath: /data
          - name: dshm
            mountPath: /dev/shm
        volumes:
        - name: gcs-fuse-cache
          emptyDir:
            medium: Memory
        - name: dshm
          emptyDir:
            medium: Memory
        - name: gcs-fuse-csi-ephemeral
          csi:
            driver: gcsfuse.csi.storage.gke.io
            volumeAttributes:
              bucketName: ${GS_BUCKET}
              mountOptions: "implicit-dirs"
  workerGroupSpecs:
  - replicas: 1
    numOfHosts: 4
    groupName: tpu-group
    rayStartParams: {}
    template:
      metadata:
        annotations:
          gke-gcsfuse/volumes: "true"
          gke-gcsfuse/cpu-limit: "0"
          gke-gcsfuse/memory-limit: "0"
          gke-gcsfuse/ephemeral-storage-limit: "0"
      spec:
        serviceAccountName: ${KSA_NAME}
        containers:
        - name: ray-worker
          image: ${DOCKER_IMAGE}
          imagePullPolicy: IfNotPresent
          resources:
            limits:
              memory: 200G
              google.com/tpu: "4"
            requests:
              cpu: "8"
              memory: 200G
              google.com/tpu: "4"
          env:
          - name: JAX_PLATFORMS
            value: tpu
          - name: ENABLE_PJRT_COMPATIBILITY
            value: "true"
          volumeMounts:
          - name: gcs-fuse-csi-ephemeral
            mountPath: /data
          - name: dshm
            mountPath: /dev/shm
        volumes:
        - name: gcs-fuse-cache
          emptyDir:
            medium: Memory
        - name: dshm
          emptyDir:
            medium: Memory
        - name: gcs-fuse-csi-ephemeral
          csi:
            driver: gcsfuse.csi.storage.gke.io
            volumeAttributes:
              bucketName: ${GS_BUCKET}
              mountOptions: "implicit-dirs"
        nodeSelector:
          cloud.google.com/gke-tpu-accelerator: tpu-v6e-slice
          cloud.google.com/gke-tpu-topology: 4x4
```
The preceding RayCluster spec creates a TPU worker group with four workers (`numOfHosts: 4`) per replica. Each worker requests four TPU chips (`google.com/tpu: "4"`). The workers are scheduled on nodes that run TPU Trillium (`tpu-v6e-slice`) and that are part of the same colocated multi-host slice. KubeRay scales all four workers atomically, and the required JAX environment variables, as well as Pod affinities for scheduling, are bootstrapped by GKE through a mutating webhook.
To substitute the required values in the YAML file, create the RayCluster by using `envsubst`:
```
envsubst < maxtext-tpu-cluster.yaml | kubectl apply -f -
```
Verify that the cluster is ready and running:
```
kubectl get rayclusters maxtext-tpu-cluster
```
The output should be similar to the following:
```
NAME                  DESIRED WORKERS   AVAILABLE WORKERS   CPUS   MEMORY        GPUS   STATUS   AGE
maxtext-tpu-cluster   4                 4                   40     798027216Ki   0      ready    11m
```
To access the Ray Dashboard through the Ray head service, establish a port-forwarding session:
```
kubectl port-forward svc/maxtext-tpu-cluster-head-svc 8265:8265 2>&1 >/dev/null &
```
Verify that the RayCluster is reachable from your local environment:
```
ray list nodes --address http://localhost:8265
```
The output should be similar to the following:
```
======== List: 2025-09-13 03:53:16.988269 ========
Stats:
------------------------------
Total: 5

Table:
------------------------------
    NODE_ID                                                   NODE_IP    IS_HEAD_NODE    STATE    STATE_MESSAGE    NODE_NAME    RESOURCES_TOTAL    LABELS
 0  92c79d04c34b659c1e3044f7642ad3fd47eb16f290785237149fab56  10.84.0.9
(...)
```
Submit the JaxTrainer script to the RayCluster and check that the Ray job completes successfully:
```
ray job submit \
    --address http://localhost:8265 \
    -- python /app/maxtext_ray_trainer.py \
    /app/maxtext/src/MaxText/configs/base.yml \
    base_output_directory=/data/ \
    dataset_type=synthetic \
    per_device_batch_size=1 \
    max_target_length=4096 \
    model_name=llama3-8b \
    steps=100 \
    ici_fsdp_parallelism=4 \
    ici_tensor_parallelism=4 \
    run_name=rayjob-8b-4096-tp4-4x4
```
The preceding command submits the Python script, which calls the JaxTrainer Ray code, to the RayCluster. The `ray job submit` command includes some MaxText-specific arguments to pass to the model configuration.
Note: It might take a couple of minutes for the script to complete.
In your terminal, you should see output similar to the following:
```
(RayTrainWorker pid=21663, ip=10.12.3.6) completed step: 99, seconds: 1.100, TFLOP/s/device: 179.739, Tokens/s/device: 3725.218, total_weights: 65536, loss: 0.000 [repeated 3x across cluster]
------------------------------------------
Job 'raysubmit_zCrJcWnuymMQv4C3' succeeded
------------------------------------------
```
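The logged figures can be sanity-checked against the run parameters. A rough back-of-the-envelope sketch in plain Python (the device count assumes the full 4x4 slice of 16 chips; small discrepancies versus the log are expected):

```python
per_device_batch_size = 1
max_target_length = 4096   # tokens per sequence, from the submit command
num_devices = 16           # 4 hosts x 4 chips in the 4x4 slice
step_seconds = 1.100       # from the log line for step 99

# total_weights in the log is the global token count per step:
tokens_per_step = per_device_batch_size * max_target_length * num_devices
print(tokens_per_step)  # 65536, matching total_weights

# Tokens/s/device is per-device tokens divided by step time:
tokens_per_sec_per_device = per_device_batch_size * max_target_length / step_seconds
print(round(tokens_per_sec_per_device, 1))  # ~3723.6, close to the logged 3725.218
```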
Clean up
To avoid incurring charges to your Google Cloud account for the resources usedin this tutorial, either delete the project that contains the resources, or keepthe project and delete the individual resources.
Delete the RayCluster:
```
kubectl delete raycluster maxtext-tpu-cluster
```
Delete the GKE cluster:
```
gcloud container clusters delete $CLUSTER_NAME --zone=$ZONE
```
Delete the Cloud Storage bucket:
```
gsutil rm -r gs://${GS_BUCKET}
```
Delete the Artifact Registry repository:
```
gcloud artifacts repositories delete ${ARTIFACT_REGISTRY} --location=${REGION} --quiet
```
What's next
- Learn about Ray on Kubernetes.
- Learn how to Serve vLLM on GKE with TPUs.
- Learn how to Serve SDXL on GKE with TPUs.
- Learn more About TPUs in GKE.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2025-11-24 UTC.