Persistent compilation cache#

JAX has an optional disk cache for compiled programs. If enabled, JAX will store copies of compiled programs on disk, which can save recompilation time when running the same or similar tasks repeatedly.

Note: if the compilation cache is not on a local filesystem, etils needs to be installed.

pip install etils

Usage#

Quick start#

import jax
import jax.numpy as jnp

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_persistent_cache_enable_xla_caches", "xla_gpu_per_fusion_autotune_cache_dir")

@jax.jit
def f(x):
  return x + 1

x = jnp.zeros((2, 2))
f(x)

Setting cache directory#

The compilation cache is enabled when the cache location is set. This should be done prior to the first compilation. Set the location as follows:

(1) Using environment variable

In shell, before running the script:

export JAX_COMPILATION_CACHE_DIR="/tmp/jax_cache"

Or at the top of the Python script:

import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/tmp/jax_cache"

(2) Using jax.config.update()

jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")

(3) Using set_cache_dir()

from jax.experimental.compilation_cache import compilation_cache as cc
cc.set_cache_dir("/tmp/jax_cache")
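To check that a local cache directory is actually being populated, it can help to count the files in it and add up their sizes. The following is a minimal sketch using only the Python standard library. summarize_cache_dir is a hypothetical helper, not part of JAX, and it makes no assumption about how JAX names or lays out cache entries.

```python
import os
import tempfile


def summarize_cache_dir(cache_dir):
    """Return (number_of_files, total_bytes) for everything under cache_dir.

    Hypothetical helper, not a JAX API. A missing directory yields (0, 0).
    """
    num_files = 0
    total_bytes = 0
    for root, _dirs, files in os.walk(cache_dir):
        for name in files:
            num_files += 1
            total_bytes += os.path.getsize(os.path.join(root, name))
    return num_files, total_bytes


# Demonstration on a throwaway directory standing in for /tmp/jax_cache.
with tempfile.TemporaryDirectory() as demo_dir:
    with open(os.path.join(demo_dir, "entry"), "wb") as fh:
        fh.write(b"\0" * 128)
    demo_summary = summarize_cache_dir(demo_dir)

print(demo_summary)  # prints (1, 128)
```

In practice one would call summarize_cache_dir("/tmp/jax_cache") before and after running a jitted function and compare the two results.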

Caching thresholds#

  • jax_persistent_cache_min_compile_time_secs: A computation will only be written to the persistent cache if the compilation time is longer than the specified value. The default is 1.0 second.

  • jax_persistent_cache_min_entry_size_bytes: The minimum size (in bytes) of an entry that will be cached in the persistent compilation cache:

    • -1: disable the size restriction and prevent overrides.

    • Leave at default (0) to allow for overrides. The override will typically ensure that the minimum size is optimal for the file system being used for the cache.

    • >0: the actual minimum size desired; no overrides.

Note that both criteria need to be satisfied for a function to be cached.
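The way the two thresholds combine can be sketched in plain Python. This is an illustration of the rule described above, not JAX's actual implementation: would_be_cached is a hypothetical function, the behaviour exactly at the boundary values is an assumption, and the file-system-specific override of the default size (0) is not modeled.

```python
def would_be_cached(compile_time_secs, entry_size_bytes,
                    min_compile_time_secs=1.0, min_entry_size_bytes=0):
    """Sketch of the caching decision; hypothetical, not a JAX API."""
    # Criterion 1: compilation must take longer than the time threshold.
    slow_enough = compile_time_secs > min_compile_time_secs

    # Criterion 2: the entry must reach the minimum size.
    # -1 disables the restriction; 0 is treated here as "no restriction"
    # because the possible file-system override is not modeled.
    if min_entry_size_bytes <= 0:
        big_enough = True
    else:
        big_enough = entry_size_bytes >= min_entry_size_bytes

    # Both criteria must hold for the entry to be written.
    return slow_enough and big_enough


# A fast compilation is skipped under the default 1.0 second threshold.
fast_default = would_be_cached(0.5, 10_000)
# The same compilation is kept once the time threshold is lowered to 0.
fast_lowered = would_be_cached(0.5, 10_000, min_compile_time_secs=0)
# A slow compilation is still skipped if the entry is below the size limit.
slow_but_small = would_be_cached(2.0, 10, min_entry_size_bytes=1024)
```

This is why the quick start sets the size threshold to -1 and the time threshold to 0: with those values even a trivial function satisfies both criteria.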

Additional caching#

XLA supports additional caching mechanisms which can be enabled alongside JAX’s persistent compilation cache to further improve recompilation time.

  • jax_persistent_cache_enable_xla_caches: Possible values:

    • all: enable all XLA caching features

    • none: don’t enable any extra XLA caching features

    • xla_gpu_kernel_cache_file: only enable the kernel cache

    • xla_gpu_per_fusion_autotune_cache_dir: (default value) only enable the autotuning cache

Google Cloud#

When running on Google Cloud, the compilation cache can be placed on a Google Cloud Storage (GCS) bucket. We recommend the following configuration:

  • Create the bucket in the same region as where the workload will run.

  • Create the bucket in the same project as the workload’s VM(s). Ensure that permissions are set so that the VM(s) can write to the bucket.

  • There is no need for replication for smaller workloads. Larger workloads could benefit from replication.

  • Use “Standard” for the default storage class for the bucket.

  • Set the soft delete policy to its shortest: 7 days.

  • Set the object lifecycle to the expected duration of the workload run. For example, if the workload is expected to run for 10 days, set the object lifecycle to 10 days. That should cover restarts that occur during the entire run. Use age for the lifecycle condition and Delete for the action. See Object Lifecycle Management for details. If the object lifecycle is not set, the cache will continue to grow since there is no eviction mechanism implemented.

  • All encryption policies are supported.
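The object lifecycle recommendation above (an age condition with a Delete action) can be written down as a small configuration document. The sketch below builds it in Python. make_lifecycle_config is a hypothetical helper, and the JSON layout reflects our understanding of the Cloud Storage lifecycle configuration format, so check it against Object Lifecycle Management before applying it to a bucket.

```python
import json


def make_lifecycle_config(run_days):
    """Build a lifecycle rule deleting objects older than run_days days.

    Hypothetical helper; the layout is an assumption to verify against
    the Object Lifecycle Management documentation.
    """
    if run_days <= 0:
        raise ValueError("run_days must be a positive number of days")
    return {
        "rule": [
            {
                "action": {"type": "Delete"},   # the action
                "condition": {"age": run_days},  # the lifecycle condition
            }
        ]
    }


# A workload expected to run for 10 days, as in the example above.
lifecycle_json = json.dumps(make_lifecycle_config(10), indent=2)
print(lifecycle_json)
```

The resulting JSON can be saved to a file and applied with the Cloud Storage tooling described in Object Lifecycle Management.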

It is recommended to use Google Cloud Storage Fuse to mount the GCS bucket as a local directory. This is because when running JAX in a multi-node setup, multiple nodes might try to write to the cache simultaneously, leading to GCS rate-limit errors. GCSFuse handles this by ensuring that only one process can write to a file at a time, preventing these errors.

To set up GCSFuse, follow instructions for GCE or GKE. For better performance, enable file caching (GCE and GKE).

Once GCSFuse is configured, set the JAX cache directory to the GCSFuse mount point:

# Example assuming the GCS bucket is mounted at /gcs/my-bucket
jax.config.update("jax_compilation_cache_dir", "/gcs/my-bucket/jax-cache")

Direct GCS access:

If you choose not to use GCSFuse, you can point the cache directly to a GCS bucket.

Assuming that gs://jax-cache is the GCS bucket, set cache location as follows:

jax.config.update("jax_compilation_cache_dir", "gs://jax-cache")

How it works#

The cache key is the signature for a compiled function containing the following parameters:

  • The computation performed by the function, captured by the non-optimized HLO of the JAX function being hashed

  • The jaxlib version

  • Relevant XLA compilation flags

  • Device configuration, captured in general by the number of devices and the topology of the devices. Currently for GPUs, the topology only contains a string representation of the GPU name

  • Compression algorithm used to compress the compiled executable

  • A string produced by jax._src.cache_key.custom_hook(). This function can be reassigned to a user-defined function, so that the resulting string can be altered. By default, this function always returns an empty string.
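The idea of deriving one key from the parameters listed above can be sketched with a plain hash. This is a simplified illustration, not how JAX actually computes its key: sketch_cache_key, the use of SHA-256, the field encoding, and every example value below are assumptions made for the sketch.

```python
import hashlib


def sketch_cache_key(hlo_text, jaxlib_version, xla_flags,
                     device_description, compression,
                     custom_hook=lambda: ""):
    """Combine the listed parameters into one hex digest (illustrative only)."""
    hasher = hashlib.sha256()
    fields = (
        hlo_text,                      # computation (non-optimized HLO)
        jaxlib_version,                # jaxlib version
        " ".join(sorted(xla_flags)),   # relevant XLA compilation flags
        device_description,            # device count and topology
        compression,                   # compression algorithm
        custom_hook(),                 # user-controlled extra string
    )
    for field in fields:
        data = field.encode("utf-8")
        # Length-prefix each field so adjacent fields cannot run together.
        hasher.update(len(data).to_bytes(8, "big"))
        hasher.update(data)
    return hasher.hexdigest()


# All values below are made up for the demonstration.
base = dict(hlo_text="add(x, 1)", jaxlib_version="0.0.1",
            xla_flags=["--flag_a", "--flag_b"],
            device_description="8 devices, GPU-MODEL", compression="zstd")

key_first = sketch_cache_key(**base)
key_again = sketch_cache_key(**base)
key_other_version = sketch_cache_key(**{**base, "jaxlib_version": "0.0.2"})
key_with_hook = sketch_cache_key(**base, custom_hook=lambda: "experiment-7")
```

Identical inputs give the same key, while changing any single component, including the string returned by the hook, gives a different key and therefore a cache miss.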

Caching on multiple nodes#

The first time a program is run (when the persistent cache is cold, i.e. empty), all processes will compile, but only the process with rank 0 in the global communication group will write to the persistent cache. In subsequent runs, all processes will attempt to read from the persistent cache, so it is important for the persistent cache to be on a shared file system (e.g. NFS) or remote storage (e.g. GFS). If the persistent cache is local to rank 0, then all processes except rank 0 will compile again in subsequent runs as a result of a compilation cache miss.
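The difference between a shared and a rank-local cache can be illustrated with a small simulation that uses directories on disk. It only models the behaviour described above (every rank looks up the cache, then rank 0 alone writes). run_once and the entry name "program" are hypothetical and do not correspond to JAX internals.

```python
import os
import tempfile


def run_once(cache_dirs):
    """Simulate one run; cache_dirs[rank] is the directory that rank sees.

    Returns the list of ranks that had to compile (cache misses).
    """
    compiled = []
    # Every rank first looks for the entry in its own view of the cache.
    for rank, cache_dir in enumerate(cache_dirs):
        if not os.path.exists(os.path.join(cache_dir, "program")):
            compiled.append(rank)
    # Only rank 0 writes the compiled program back to the cache.
    if 0 in compiled:
        with open(os.path.join(cache_dirs[0], "program"), "w") as fh:
            fh.write("compiled executable")
    return compiled


# Shared cache: all four ranks see the same directory.
with tempfile.TemporaryDirectory() as shared:
    cold_shared = run_once([shared] * 4)
    warm_shared = run_once([shared] * 4)

# Local cache: each rank has its own private directory.
with tempfile.TemporaryDirectory() as parent:
    local_dirs = []
    for rank in range(4):
        rank_dir = os.path.join(parent, f"rank{rank}")
        os.mkdir(rank_dir)
        local_dirs.append(rank_dir)
    cold_local = run_once(local_dirs)
    warm_local = run_once(local_dirs)

print(warm_shared, warm_local)  # prints [] [1, 2, 3]
```

With the shared directory no rank compiles on the second run, whereas with rank-local directories every rank except rank 0 compiles again.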

Pre-compiling multi-node programs on single node#

JAX can populate the compilation cache with compiled programs for multiple nodes on a single node. Preparing the cache on a single node helps to decrease the costly compilation time on a cluster. To compile and run multi-node programs on a single node, users can create fake remote devices using the jax_mock_gpu_topology configuration option.

For instance, the snippet below instructs JAX to mock a cluster with four nodes, each node running eight processes with each process attached to one GPU.

jax.config.update("jax_mock_gpu_topology", "4x8x1")

After populating the cache with this config, users can run the program without recompilation on four nodes, eight processes per node, one GPU per process.

Important notes:

  • The process running the mocked program must have the same number of GPUs and the same GPU model as the nodes that would use the cache. For instance, a mocked topology 8x4x2 must run in a process with two GPUs.

  • When running programs with mocked topology, the results of communications with other nodes are undefined, so the outputs of JAX programs running in mocked environments will likely be incorrect.
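The topology string can be read as nodes x processes per node x GPUs per process, which is the interpretation implied by the 4x8x1 and 8x4x2 examples above. The sketch below parses such a string so that the GPU requirement for the mocking process can be checked up front. parse_mock_topology and MockTopology are hypothetical names, not JAX APIs.

```python
from typing import NamedTuple


class MockTopology(NamedTuple):
    nodes: int
    processes_per_node: int
    gpus_per_process: int

    @property
    def total_gpus(self):
        # Number of GPUs in the whole mocked cluster.
        return self.nodes * self.processes_per_node * self.gpus_per_process


def parse_mock_topology(spec):
    """Parse a string such as "4x8x1" (hypothetical helper, not a JAX API)."""
    parts = spec.split("x")
    if len(parts) != 3:
        raise ValueError(f"expected three 'x'-separated fields, got {spec!r}")
    nodes, procs, gpus = (int(part) for part in parts)
    if min(nodes, procs, gpus) < 1:
        raise ValueError(f"all fields must be positive, got {spec!r}")
    return MockTopology(nodes, procs, gpus)


topology = parse_mock_topology("4x8x1")
print(topology.total_gpus)  # prints 32

# The process that populates the cache needs this many local GPUs.
gpus_needed_locally = parse_mock_topology("8x4x2").gpus_per_process
```

Comparing gpus_needed_locally with the number of GPUs visible to the process before setting jax_mock_gpu_topology catches a mismatched topology early.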

Logging cache activity#

It can be helpful to examine what exactly is happening with the persistent compilation cache for debugging. Here are a few suggestions on how to begin.

Users can enable the logging of related source files by placing

import os
os.environ["JAX_DEBUG_LOG_MODULES"] = "jax._src.compiler,jax._src.lru_cache"

at the top of the script. Alternatively, you can change the global JAX logging level with

import os
os.environ["JAX_LOGGING_LEVEL"] = "DEBUG"
# or locally with
jax.config.update("jax_logging_level", "DEBUG")

Examining cache misses#

To examine and understand why there are cache misses, JAX includes a configuration flag that enables the logging of all cache misses (including persistent compilation cache misses) with their explanations. Currently, explanations are only implemented for tracing cache misses; the eventual goal is to explain all cache misses. This can be enabled by setting the following configuration.

jax.config.update("jax_explain_cache_misses", True)

Pitfalls#

A couple of pitfalls have been discovered so far:

  • Currently the persistent cache doesn’t work with functions that have host callbacks. In this situation, caching is completely avoided.

    • This is because the HLO contains a pointer to the callback, which changes from run to run even if the computation and compute infrastructure are exactly the same.

  • Currently the persistent cache doesn’t work with a function that uses primitives that implement their own custom_partitioning.

    • The HLO of the function contains a pointer to the custom_partitioning callback, and leads to different cache keys for the same computation across runs.

    • In this situation, caching still proceeds, but a different key is produced every time, making the cache ineffective.
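Why a pointer embedded in the HLO defeats the cache can be seen with a toy hash. In the sketch below a short string stands in for the HLO text, and two distinct function objects stand in for the same callback as it would appear in two separate runs. key_for and the text format are assumptions made for the illustration, not JAX internals.

```python
import hashlib


def key_for(program_text):
    """Toy cache key: a hash of the program text (illustrative only)."""
    return hashlib.sha256(program_text.encode("utf-8")).hexdigest()


def callback_run_1(x):
    return x


def callback_run_2(x):
    return x


program = "layernorm(x1, gamma, beta) @ x2"

# Without any address in the text, the key is reproducible.
key_plain_1 = key_for(program)
key_plain_2 = key_for(program)

# With a memory address embedded, the text (and so the key) differs
# even though the computation itself is unchanged.
key_ptr_1 = key_for(f"{program} callback@{id(callback_run_1):#x}")
key_ptr_2 = key_for(f"{program} callback@{id(callback_run_2):#x}")

print(key_plain_1 == key_plain_2, key_ptr_1 == key_ptr_2)  # prints True False
```

Entries are still written under the pointer-dependent keys, but a later run looks up a different key and never finds them.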

Working around custom_partitioning#

As mentioned, the compilation cache doesn’t work with a function that is composed of primitives that implement custom_partitioning. However, it is possible to use shard_map to circumvent custom_partitioning for those primitives that do implement it and make the compilation cache work as expected:

Let’s pretend we have a function F that implements a layernorm followed by a matrix multiplication using a primitive LayerNorm that implements custom_partitioning:

import jax

def F(x1, x2, gamma, beta):
  ln_out = LayerNorm(x1, gamma, beta)
  return ln_out @ x2

If we were to merely compile this function without shard_map, the cache key for layernorm_matmul_without_shard_map would be different every time we ran the same code:

layernorm_matmul_without_shard_map = jax.jit(F, in_shardings=(...), out_shardings=(...))(x1, x2, gamma, beta)

However, if we were to wrap the layernorm primitive in shard_map and define a function G that performs the same computation, the cache key for layernorm_matmul_with_shard_map will be the same every time despite LayerNorm implementing custom_partitioning:

import jax

def G(x1, x2, gamma, beta, mesh, ispecs, ospecs):
  # LayerNorm takes (x1, gamma, beta), as in F above.
  ln_out = jax.shard_map(LayerNorm, mesh=mesh, in_specs=ispecs, out_specs=ospecs, check_vma=False)(x1, gamma, beta)
  return ln_out @ x2

ispecs = jax.sharding.PartitionSpec(...)
ospecs = jax.sharding.PartitionSpec(...)
mesh = jax.sharding.Mesh(...)
layernorm_matmul_with_shard_map = jax.jit(G, static_argnames=['mesh', 'ispecs', 'ospecs'])(x1, x2, gamma, beta, mesh, ispecs, ospecs)

Note that the primitive that implements custom_partitioning must be wrapped in shard_map for this workaround. It is insufficient to wrap the outer function F in shard_map.

