jax-triton contains integrations between JAX and OpenAI Triton
hawkinsp/jax-triton
The jax-triton repository contains integrations between JAX and Triton.
Documentation can be found here.
This is not an officially supported Google product.
The main function of interest is jax_triton.triton_call for applying Triton functions to JAX arrays, including inside jax.jit-compiled functions. For example, we can define a kernel from the Triton tutorial:
import triton
import triton.language as tl


@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    length,
    output_ptr,
    block_size: tl.constexpr,
):
  """Adds two vectors."""
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  mask = offsets < length
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)
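To make the indexing in the kernel easier to follow, here is a minimal pure-Python sketch of what each program instance computes. It is an illustration only, not how Triton executes the kernel: the name `emulate_add_kernel` is hypothetical, the loop over `pid` stands in for the parallel launch, and the `if off < length` check plays the role of `mask`.

```python
def emulate_add_kernel(x, y, block_size):
    """Sequential emulation of the blockwise masked add (illustration only)."""
    length = len(x)
    output = [None] * length
    # Assumption: enough program instances are launched to cover every element.
    num_programs = (length + block_size - 1) // block_size
    for pid in range(num_programs):  # stands in for tl.program_id(axis=0)
        block_start = pid * block_size
        offsets = [block_start + i for i in range(block_size)]
        for off in offsets:
            if off < length:  # the mask: skip offsets past the end of the data
                output[off] = x[off] + y[off]
    return output


print(emulate_add_kernel(list(range(8)), list(range(8, 16)), block_size=8))
```

With a length that is not a multiple of `block_size`, the last program instance has offsets beyond the data, and the mask is what keeps those loads and stores from touching out-of-range memory.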
Then we can apply it to JAX arrays using jax_triton.triton_call:
import jax
import jax.numpy as jnp
import jax_triton as jt


def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  return jt.triton_call(
      x,
      y,
      x.size,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=(x.size // block_size,),
      block_size=block_size)


x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))
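The example uses `grid=(x.size // block_size,)`, which covers the whole input only when the size is a multiple of `block_size` (as it is here, with 8 elements and a block size of 8). For other sizes, a common approach is to round the grid up and rely on the kernel's mask for the partial last block. The sketch below shows the arithmetic in plain Python; `num_blocks` is a hypothetical helper name, and Triton's own `triton.cdiv` should serve the same purpose if you prefer it.

```python
def num_blocks(size: int, block_size: int) -> int:
    """Ceiling division: number of blocks needed to cover `size` elements."""
    return (size + block_size - 1) // block_size


block_size = 8
for size in (8, 16, 20):
    floor_grid = size // block_size            # what the example above uses
    ceil_grid = num_blocks(size, block_size)   # rounds up for a partial block
    print(f"size={size}: floor covers {floor_grid * block_size} elements, "
          f"ceil covers {ceil_grid * block_size}")
```

Under this assumption, the corresponding change in `add` would be `grid=(num_blocks(x.size, block_size),)`; for `size=20` the floor version would leave the last 4 elements unwritten.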
See the examples directory, especially fused_attention.py and the fused attention ipynb.
$ pip install jax-triton
Make sure you have a CUDA-compatible jaxlib installed. For example you could run:
$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
JAX-Triton and Pallas are developed at JAX and Jaxlib HEAD and close to Triton HEAD. To get a bleeding edge installation of JAX-Triton, run:
$ pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git'
This should install compatible versions of JAX and Triton.
JAX-Triton does depend on Jaxlib but it's usually a more stable dependency. You might be able to get away with using a recent jaxlib release:
$ pip install jaxlib[cuda11_pip]
$ # or
$ pip install jaxlib[cuda12_pip]
If you find there are issues with the latest Jaxlib release, you can try using a Jaxlib nightly. To install a new jaxlib, you can find a link to a CUDA 11 nightly or CUDA 12 nightly. Then install it via:
$ pip install 'jaxlib @ <link to nightly>'
or to install CUDA via pip automatically, you can do:
$ pip install 'jaxlib[cuda11_pip] @ <link to nightly>'
$ # or
$ pip install 'jaxlib[cuda12_pip] @ <link to nightly>'
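Because several install paths are described above, it can help to check which versions actually ended up in the environment. The following is a small sketch using only the standard library; `installed_versions` is a hypothetical helper, and the distribution names `jax`, `jaxlib`, `triton`, and `jax-triton` are assumed to match what pip installed in your setup.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "triton", "jax-triton")):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed under this name
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

This only reports versions; it does not verify that the reported versions are mutually compatible or that a GPU is visible to JAX.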
To develop jax-triton, you can clone the repo with:
$ git clone https://github.com/jax-ml/jax-triton.git
and do an editable install with:
$ cd jax-triton
$ pip install -e .
To run the jax-triton tests, you'll need pytest:
$ pip install pytest
$ pytest tests/