Benchmarking JAX code
You just ported a tricky function from NumPy/SciPy to JAX. Did that actually speed things up?
Keep in mind these important differences from NumPy when measuring the speed of code using JAX:
JAX code is Just-In-Time (JIT) compiled. Most code written in JAX can be written in such a way that it supports JIT compilation, which can make it run much faster (see To JIT or not to JIT). To get maximum performance from JAX, you should apply jax.jit() on your outer-most function calls. Keep in mind that the first time you run JAX code, it will be slower because it is being compiled. This is true even if you don’t use jit in your own code, because JAX’s builtin functions are also JIT compiled.
JAX has asynchronous dispatch. This means that you need to call .block_until_ready() to ensure that computation has actually happened (see Asynchronous dispatch).
JAX by default only uses 32-bit dtypes. You may want to either explicitly use 32-bit dtypes in NumPy or enable 64-bit dtypes in JAX (see Double (64 bit) precision) for a fair comparison.
Transferring data between CPUs and accelerators takes time. If you only want to measure how long it takes to evaluate a function, you may want to transfer data to the device on which you want to run it first (see Controlling data and computation placement on devices).
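As a rough illustration of the asynchronous dispatch and dtype points above, the sketch below times the same jitted call with and without blocking. The function name `gram` and the 500x500 input are illustrative choices, not part of the original example, and the measured gap depends on the backend.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

# NumPy defaults to float64; cast to float32 so both libraries do the same work.
x_np = np.ones((500, 500), dtype=np.float32)
x_jax = jnp.asarray(x_np)  # a float32 input stays float32 in JAX


@jax.jit
def gram(x):  # hypothetical function, used only for this illustration
    return x.T @ x


# Warm-up call: the first invocation includes compilation time.
gram(x_jax).block_until_ready()

# Without blocking, the call may return before the computation has finished.
start = time.perf_counter()
y = gram(x_jax)
dispatch_s = time.perf_counter() - start
y.block_until_ready()  # let the pending work finish before timing again

# With blocking, the timer covers the whole computation.
start = time.perf_counter()
y = gram(x_jax).block_until_ready()
total_s = time.perf_counter() - start

print(f"dispatch only: {dispatch_s:.2e} s, with blocking: {total_s:.2e} s")
```

Only the second measurement is guaranteed to cover the full computation; the first may cover little more than the time to enqueue the work.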
Here’s an example of how to put together all these tricks into a microbenchmark for comparing JAX versus NumPy, making use of IPython’s convenient %time and %timeit magics:
import numpy as np
import jax

def f(x):  # function we're benchmarking (works in both NumPy & JAX)
    return x.T @ (x - x.mean(axis=0))

x_np = np.ones((1000, 1000), dtype=np.float32)  # same as JAX default dtype
%timeit f(x_np)  # measure NumPy runtime

# measure JAX device transfer time
%time x_jax = jax.device_put(x_np).block_until_ready()

f_jit = jax.jit(f)
%time f_jit(x_jax).block_until_ready()  # measure JAX compilation time
%timeit f_jit(x_jax).block_until_ready()  # measure JAX runtime
When run with a GPU in Colab, we see:
NumPy takes 16.2 ms per evaluation on the CPU
JAX takes 1.26 ms to copy the NumPy arrays onto the GPU
JAX takes 193 ms to compile the function
JAX takes 485 µs per evaluation on the GPU
In this case, we see that once the data is transferred and the function is compiled, JAX on the GPU is about 30x faster for repeated evaluations.
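To put the one-time costs in perspective, the short calculation below uses only the four timings reported above. It assumes, as a simplification, that the input is transferred once and reused for every evaluation; the variable names are illustrative.

```python
import math

# Timings reported above, in milliseconds.
numpy_eval_ms = 16.2   # NumPy, per evaluation on the CPU
transfer_ms = 1.26     # one-time copy of the array to the GPU
compile_ms = 193.0     # one-time JIT compilation
jax_eval_ms = 0.485    # JAX, per evaluation on the GPU

# Steady-state speedup, ignoring the one-time costs.
speedup = numpy_eval_ms / jax_eval_ms

# Number of evaluations after which JAX's total time drops below NumPy's,
# assuming a single transfer and a single compilation.
one_time_cost_ms = transfer_ms + compile_ms
saving_per_eval_ms = numpy_eval_ms - jax_eval_ms
break_even_evals = math.ceil(one_time_cost_ms / saving_per_eval_ms)

print(f"steady-state speedup: {speedup:.1f}x")
print(f"break-even after {break_even_evals} evaluations")
```

Under these assumptions the steady-state ratio is roughly 33x, and the transfer and compilation costs are recovered after 13 evaluations.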
Is this a fair comparison? Maybe. The performance that ultimately matters is for running full applications, which inevitably include some amount of both data transfer and compilation. Also, we were careful to pick large enough arrays (1000x1000) and an intensive enough computation (the @ operator is performing matrix-matrix multiplication) to amortize the increased overhead of JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use 10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs).
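The %time and %timeit magics are only available in IPython. As a minimal sketch, the same comparison can be written as a plain script with the standard library's timeit module, repeated for the 10x10 and 1000x1000 inputs discussed above. The helper `best_time` and the `results` dictionary are illustrative names, and the numbers it prints depend entirely on the hardware and backend in use.

```python
import timeit

import jax
import numpy as np


def f(x):  # same function as in the example above
    return x.T @ (x - x.mean(axis=0))


f_jit = jax.jit(f)


def best_time(fn, repeat=3, number=5):
    """Return the best per-call time in seconds over several repeats."""
    return min(timeit.repeat(fn, repeat=repeat, number=number)) / number


results = {}
for n in (10, 1000):
    x_np = np.ones((n, n), dtype=np.float32)
    x_jax = jax.device_put(x_np)      # transfer once, outside the timed region
    f_jit(x_jax).block_until_ready()  # compile for this shape before timing

    t_np = best_time(lambda: f(x_np))
    t_jax = best_time(lambda: f_jit(x_jax).block_until_ready())
    results[n] = (t_np, t_jax)
    print(f"{n}x{n}: NumPy {t_np * 1e6:.1f} µs, JAX {t_jax * 1e6:.1f} µs")
```

Because jax.jit compiles separately for each input shape, the warm-up call is repeated inside the loop so that compilation never falls inside a timed region.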
