
Benchmarking JAX code

You just ported a tricky function from NumPy/SciPy to JAX. Did that actually speed things up?

Keep in mind these important differences from NumPy when measuring the speed of code using JAX:

  1. JAX code is Just-In-Time (JIT) compiled. Most code written in JAX can be written in such a way that it supports JIT compilation, which can make it run much faster (see To JIT or not to JIT). To get maximum performance from JAX, you should apply jax.jit() on your outer-most function calls.

    Keep in mind that the first time you run JAX code, it will be slower because it is being compiled. This is true even if you don’t use jit in your own code, because JAX’s builtin functions are also JIT compiled.

  2. JAX has asynchronous dispatch. This means that you need to call .block_until_ready() to ensure that computation has actually happened (see Asynchronous dispatch).

  3. JAX by default only uses 32-bit dtypes. You may want to either explicitly use 32-bit dtypes in NumPy or enable 64-bit dtypes in JAX (see Double (64 bit) precision) for a fair comparison.

  4. Transferring data between CPUs and accelerators takes time. If you only want to measure how long it takes to evaluate a function, you may want to transfer data to the device on which you want to run it first (see Controlling data and computation placement on devices).
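Point 2 in particular is easy to get wrong: without blocking, you time only how long it takes to *dispatch* the work, not to do it. A minimal sketch of the difference, in plain Python without IPython (the function and array size here are arbitrary stand-ins):

```python
import time
import numpy as np
import jax

def f(x):
    return (x @ x.T).sum()

# Point 4: move the data to the device first.
x = jax.device_put(np.ones((500, 500), dtype=np.float32))
f_jit = jax.jit(f)
f_jit(x).block_until_ready()  # warm-up call: triggers compilation (point 1)

start = time.perf_counter()
y = f_jit(x)                   # returns almost immediately (point 2)
dispatch_time = time.perf_counter() - start

start = time.perf_counter()
y.block_until_ready()          # waits for the actual computation to finish
compute_time = time.perf_counter() - start
print(f"dispatch: {dispatch_time:.2e}s, compute: {compute_time:.2e}s")
```

Timing only the first line would report the dispatch cost, which can be orders of magnitude smaller than the computation itself on an accelerator.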

Here’s an example of how to put together all these tricks into a microbenchmark for comparing JAX versus NumPy, making use of IPython’s convenient %time and %timeit magics:

```python
import numpy as np
import jax

def f(x):  # function we're benchmarking (works in both NumPy & JAX)
    return x.T @ (x - x.mean(axis=0))

x_np = np.ones((1000, 1000), dtype=np.float32)  # same as JAX default dtype
%timeit f(x_np)  # measure NumPy runtime

# measure JAX device transfer time
%time x_jax = jax.device_put(x_np).block_until_ready()

f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime
```

When run with a GPU in Colab, we see:

  • NumPy takes 16.2 ms per evaluation on the CPU

  • JAX takes 1.26 ms to copy the NumPy arrays onto the GPU

  • JAX takes 193 ms to compile the function

  • JAX takes 485 µs per evaluation on the GPU

In this case, we see that once the data is transferred and the function is compiled, JAX on the GPU is about 30x faster for repeated evaluations.

Is this a fair comparison? Maybe. The performance that ultimately matters is for running full applications, which inevitably include some amount of both data transfer and compilation. Also, we were careful to pick large enough arrays (1000x1000) and an intensive enough computation (the @ operator is performing matrix-matrix multiplication) to amortize the increased overhead of JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use 10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs).
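The size at which JAX overtakes NumPy depends on your hardware and on the function being benchmarked, so it is worth measuring rather than assuming. A minimal sketch that sweeps a few sizes to locate the crossover point (the sizes and repetition count here are arbitrary; jit recompiles once for each new input shape):

```python
import timeit
import numpy as np
import jax

def f(x):
    return x.T @ (x - x.mean(axis=0))

f_jit = jax.jit(f)
results = {}

for n in (10, 100, 1000):
    x_np = np.ones((n, n), dtype=np.float32)
    x_jax = jax.device_put(x_np)
    f_jit(x_jax).block_until_ready()  # warm up: compile for this shape
    t_np = timeit.timeit(lambda: f(x_np), number=20) / 20
    t_jax = timeit.timeit(lambda: f_jit(x_jax).block_until_ready(), number=20) / 20
    results[n] = (t_np, t_jax)
    print(f"n={n:5d}  NumPy {t_np:.2e}s  JAX {t_jax:.2e}s")
```

On a CPU-only machine the crossover may land at a very different size than the GPU numbers quoted above, which is exactly why a sweep like this is useful.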
