Cloud TPU performance guide
Your first step when troubleshooting TPU performance is to profile your model. For more information on capturing a performance profile, see Profiling your model on Cloud TPU.
TPU model performance
This section describes general issues that can reduce model performance and how you can address them.
Model is input bound
TPUs perform calculations very fast. To ensure the TPU is not idle, it is important to make sure there is a steady stream of data being loaded onto the TPU. How this is done depends on how you load and preprocess your dataset. For example, you can read data files in parallel using tf.data.TFRecordDataset and the num_parallel_reads parameter.

Batch size is too small because of sharding (splitting batches across cores)
The TPU runtime splits a batch across all 8 cores of a TPU device (forexample v2-8 or v3-8). If you specify a global batch size of 128, each core receivesa batch size of 16 (128 / 8).
For optimum memory usage, use the largest batch size that fits into TPU memory. Each TPU core uses two-dimensional 8 x 128 vector registers for processing matrix multiplications. In general, your batch size should be evenly divisible by 8 or 128.
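As a quick sanity check, the per-core split and the divisibility guidance above can be sketched in Python. The helper name is illustrative only, not part of any TPU API:

```python
def check_global_batch_size(global_batch_size, num_cores=8):
    """Splits a global batch across TPU cores (e.g. a v2-8 or v3-8 device)
    and reports whether the per-core batch is evenly divisible by 8 or 128."""
    if global_batch_size % num_cores != 0:
        raise ValueError("Global batch size must divide evenly across cores.")
    per_core = global_batch_size // num_cores
    divisible = per_core % 8 == 0 or per_core % 128 == 0
    return per_core, divisible
```

For example, a global batch size of 128 gives each of the 8 cores a batch of 16, which satisfies the divisible-by-8 guidance.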
Memory management tuning
You can use the TPU_PREMAPPED_BUFFER_SIZE environment variable to fine-tune low-level runtime behaviors.
Description: TPU_PREMAPPED_BUFFER_SIZE sets the size of the host memory buffer (in bytes) that is pre-mapped and pinned for use by the TPU runtime for data transfers (for example, DMA). The default value is 4294967296 bytes. The value must be a multiple of 2^12 (4 KB = 4 * 1024 bytes = 4096 = 2^12). The following examples are valid TPU_PREMAPPED_BUFFER_SIZE values:

- 17179869184 = 2^34 = 2^22 * 2^12 (2^22 4 KB pages will be pre-mapped)
- 40000000000 = 5^10 * 2^12 (5^10 4 KB pages will be pre-mapped)

Impact: Increasing this size can potentially improve data transfer performance between the host and TPU device, especially for workloads with large tensors or frequent host-device communication. However, it also increases the amount of pinned host memory, reducing memory available for other processes.
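The 4 KB page constraint can be checked before launching a job with a short Python sketch. The helper is illustrative; the TPU runtime performs its own validation:

```python
PAGE_SIZE = 4096  # 2**12; premapped buffer sizes must be whole 4 KB pages


def is_valid_premapped_buffer_size(size_bytes: int) -> bool:
    """Returns True if size_bytes is a positive multiple of the 4 KB page size."""
    return size_bytes > 0 and size_bytes % PAGE_SIZE == 0
```

Both example values above pass this check, while the default plus one byte (4294967297) does not.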
Buffer size
If the pre-mapped buffer region isn't large enough to allocate memory during program runtime, the workload fails and returns a RESOURCE_EXHAUSTED error similar to:

"Allocating buffer from premapped region failed with: RESOURCE_EXHAUSTED: Attempting to allocate allocation_size. That was not possible. There are available_size free."

If the buffer is excessively large, TPU initialization can take much longer (potentially more than 15 seconds), making it seem as if the TPU is stuck.
To diagnose this, inspect the TPU runtime logs. These logs detail the operations being performed, including the pre-mapping of buffers. You can find the logs at /tmp/tpu_logs/tpu_driver.INFO or print them directly to the console by setting the environment variable TPU_STDERR_LOG_LEVEL=0. This setting will generate output similar to:
```
I0604 12:45:24.926233 62136 tpu_hal.cc:214] Starting premapped memory manager initialization...
I0604 12:45:29.411218 62136 system.cc:1059] tpu::System initialized, current host id: 0, logical device ids: 0
I0604 12:45:29.411244 61600 tfrt_tpu_system_state.cc:216] CreateTpuSystemState: TPU initialization is successful and it took 5.583190661s
I0604 12:45:29.411267 61600 tfrt_tpu_system_state.cc:220] CreateTpuSystemState: using TPU host premapped buffer of size: 4294967296
```

This output tells you how long it took to initialize the TPU and the size of the premapped buffer.

Usage: If the premapped buffer is too small or too large, you can manually set the buffer size using the following environment variables.
- TPU_PREMAPPED_BUFFER_SIZE: Sets the total size (in bytes) of the pre-mapped buffer region.
- TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES: Sets the maximum size of a single buffer that can be allocated from the pre-mapped region.

For example, you can:
```
export TPU_PREMAPPED_BUFFER_SIZE=4294967296
```

to set the buffer size and:

```
export TPU_PREMAPPED_BUFFER_TRANSFER_THRESHOLD_BYTES
```

to enable it. This export sets the size to the default.

Guidance: Adjust the value of TPU_PREMAPPED_BUFFER_SIZE if you suspect host-device data transfer is a bottleneck. Monitor host memory usage and model performance to find an optimal balance. The default value is typically sufficient for most use cases.
XLA compiler optimizations
XLA is a compiler for machine learning that can produce binaries for TPUs, CPUs, GPUs, and other platforms. While XLA is part of the standard TensorFlow codebase, it can also be used on PyTorch and JAX models. Models for Cloud TPU are translated to an XLA graph, which XLA then compiles to a TPU executable. For more information about XLA, see XLA: Optimizing Compiler for Machine Learning.
Padding
To use TPU memory efficiently, structure your data so that it can be tiled into 128 x 8 chunks. When the data for a matrix computation does not fill an entire 128 x 8 chunk, the XLA compiler pads tensors. There are two drawbacks to padding:
- Padded tensors under-utilize the TPU core.
- Padding increases the amount of on-chip memory storage required for a tensor and can lead to an out-of-memory error.
While padding is automatically performed by the XLA compiler when necessary, youcan determine the amount of padding performed using the memory viewer tool. You canavoid padding by picking tensor dimensions that are well suited for TPU.
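To get a feel for the cost of padding, the following Python sketch (illustrative only, not an XLA API) rounds a 2-D tensor's dimensions up to the tile boundaries and reports what fraction of the padded tensor is wasted:

```python
import math


def padded_shape(rows, cols, tile=(8, 128)):
    """Rounds each dimension up to the next multiple of the TPU tile size."""
    return (math.ceil(rows / tile[0]) * tile[0],
            math.ceil(cols / tile[1]) * tile[1])


def wasted_fraction(rows, cols):
    """Fraction of the padded tensor that is padding rather than data."""
    pr, pc = padded_shape(rows, cols)
    return 1 - (rows * cols) / (pr * pc)
```

For example, a 50 x 130 tensor is padded to 56 x 256, so more than half of the padded storage holds no data; a tensor that is already a multiple of 8 x 128 wastes nothing.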
Tensor dimensions
To achieve peak FLOPs, dimensions of matrix multiplication should be larger than the MXU size for the TPU version you are using. MXU size is 256 x 256 for v6e and 128 x 128 for versions prior to v6e. For more information, see Cloud TPU system architecture.
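As an illustration, a configuration check like the following compares matmul dimensions against the MXU edge length. The helper and the version table are hypothetical, condensed from the guidance above:

```python
# Hypothetical lookup based on the text above: v6e uses a 256 x 256 MXU,
# earlier versions use 128 x 128.
MXU_EDGE = {"v4": 128, "v5p": 128, "v6e": 256}


def exceeds_mxu_size(m, k, n, version="v6e"):
    """True if all three matmul dimensions (M x K times K x N) are larger
    than the MXU edge length for the given TPU version."""
    edge = MXU_EDGE[version]
    return min(m, k, n) > edge
```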
Batch size
The XLA compiler rounds up the sizes of tensors stored in TPU HBM memory to perform computations more efficiently. This padding happens transparently at the hardware level and does not affect results. However, in certain cases the padding can result in significantly increased memory use and execution time.
The TPU runtime lays out tensors in memory to maximize computational efficiency and minimize padding. To minimize memory overhead and maximize computational efficiency, one of the following must be true:
- The total batch size should be a multiple of 64 (8 per TPU core), and feature dimension sizes should be a multiple of 128.
- The total batch size should be a multiple of 1024 (128 per TPU core), and feature dimension sizes should be a multiple of 8.
Using a batch size of 1024 and feature dimensions that are a multiple of 128results in the best efficiency, although this may not be possible for all models.
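A quick way to check a model configuration against these two layout rules is sketched below; the function name is made up for illustration:

```python
def meets_tpu_layout_rules(global_batch_size, feature_dims):
    """Checks the two layouts described above:
    batch % 64 == 0 with every feature dimension % 128 == 0, or
    batch % 1024 == 0 with every feature dimension % 8 == 0."""
    rule_a = (global_batch_size % 64 == 0
              and all(f % 128 == 0 for f in feature_dims))
    rule_b = (global_batch_size % 1024 == 0
              and all(f % 8 == 0 for f in feature_dims))
    return rule_a or rule_b
```

A batch size of 1024 with feature dimensions that are multiples of 128 satisfies both rules at once, which matches the best-efficiency case described above.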
Note: Feature dimension refers to the hidden size of a fully-connected layer or the number of output channels in a convolution. Not all layers can conform to this rule, especially the first and last layers of the network. This is fine; most models require some amount of padding.

Fusion
Fusion is a general technique the XLA compiler uses to optimize programs. Afused operation is the combination of multiple constituent operations that areto be executed in combination.
For example, consider the following series of operations:
```
tmp = tf.add(x, y)
result = tf.multiply(tmp, z)
```

This code is roughly equivalent to the following pseudo code:
```
for (i = 0; i < element_count; i++) {
  tmp[i] = x[i] + y[i];
}
for (i = 0; i < element_count; i++) {
  result[i] = tmp[i] * z[i];
}
```

With fusion, the array accesses happen at the same time:
```
for (i = 0; i < element_count; i++) {
  result[i] = (x[i] + y[i]) * z[i];
}
```

In this example, the number of memory round trips is reduced and XLA does not need to allocate any space for 'tmp'.
Fusion is a critical optimization and benefits the Cloud TPU inseveral ways:
- It reduces memory transfers by removing the need to store intermediate results in main memory, which is slow.
- It allows greater utilization of hardware units that would otherwise be unused.
- It can reduce the memory utilization of a model, as fewer buffers need to be live at the same time.
Broadcasting
Broadcasting implicitly occurs when two tensors with different, but compatible,shapes are combined.
For example, tf.add(vector, matrix) requires the vector to be broadcast to the shape of the matrix. The result of the operation has the same shape as the matrix. For more details, see the guide to broadcasting arrays.
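Broadcasting follows the same rules in NumPy, which makes for a self-contained illustration (NumPy stands in for TensorFlow here):

```python
import numpy as np

vector = np.ones(3)        # shape (3,)
matrix = np.ones((4, 3))   # shape (4, 3)

# The vector is implicitly broadcast across the rows of the matrix;
# the result takes the matrix's shape.
result = vector + matrix
```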
While broadcasts can often be fused with their consumers, forcing a broadcastmay result in poor performance and increased memory usage.
In the following example, the broadcast implicit in the addition of a vector and matrix cannot be fused with the argmax, resulting in a materialized broadcast:
`tf.argmax(tf.add(vector, zero_matrix), axis=0)`

Performance recommendations for the Ironwood dual-chiplet architecture
The Ironwood programming model lets you access two TPU devices instead of the single logical core (also known as MegaCore) architecture used in previous generations (TPU v4 and v5p). This change improves the cost-effectiveness and efficiency of manufacturing the chip. While this represents an architectural shift, the new design ensures that you can reuse existing software models with minimal changes.
To achieve the best performance with the dual-chiplet architecture, werecommend the following approaches:
- Use tensor parallelism across chiplets: The high-bandwidth D2D interface is designed for efficient tensor parallelism. We recommend splitting tensors across the two on-chip devices.
- Utilize hierarchical collectives: To maximize communication efficiency, take advantage of the two-level network hierarchy: the ultra-fast D2D link between on-chip chiplets and the fast ICI links within a slice. When using automatic parallelism with SPMD (single program, multiple data), the XLA compiler handles this for you by automatically generating hierarchical collective operations. When manually partitioning your model, you should also design your communication patterns around this hierarchy. Prioritize communication between the two devices on the same chip before communicating with devices on other chips.
- Overlap communication with computation: To maximize hardware utilization, offload collective communication operations, such as all-reduce, to the SparseCores. These operations, which aren't bound to the matrix-multiply unit (MXU), can execute on the SparseCores concurrently while the TensorCores continue their computation. This technique can recover some of the performance benefits that were inherent to the fused operations in the previous MegaCore architecture.
- Offload to SparseCore for embeddings: In the dual-chiplet design, embedding tables can be partitioned across the HBM of both chiplets. To avoid performance degradation from this lack of shared memory, offload embedding gather operations to the SparseCore. This strategy utilizes the high-speed D2D interconnect to efficiently transfer embedding vectors between the chiplets. For more information about SparseCore and embedding models, see A deep dive into SparseCore for Large Embedding Models (LEM).
For more information about the Ironwood architecture in TPU7x, see TPU7x (Ironwood).
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2025-11-24 UTC.