The jax-triton repository contains integrations between JAX and Triton.

Documentation can be found here.

This is not an officially supported Google product.

## Quickstart

The main function of interest is `jax_triton.triton_call`, for applying Triton functions to JAX arrays, including inside `jax.jit`-compiled functions. For example, we can define a kernel from the Triton tutorial:

```python
import triton
import triton.language as tl

@triton.jit
def add_kernel(
    x_ptr,
    y_ptr,
    length,
    output_ptr,
    block_size: tl.constexpr,
):
  """Adds two vectors."""
  pid = tl.program_id(axis=0)
  block_start = pid * block_size
  offsets = block_start + tl.arange(0, block_size)
  mask = offsets < length
  x = tl.load(x_ptr + offsets, mask=mask)
  y = tl.load(y_ptr + offsets, mask=mask)
  output = x + y
  tl.store(output_ptr + offsets, output, mask=mask)
```
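If you want to see what this blocked, masked computation does without a GPU, here is a plain-NumPy re-enactment of the same logic. This is a sketch for intuition only; the function name `add_simulated` is ours and is not part of jax-triton or Triton.

```python
import numpy as np

def add_simulated(x: np.ndarray, y: np.ndarray, block_size: int = 8) -> np.ndarray:
  """NumPy sketch of add_kernel's grid of programs, for intuition only."""
  length = x.size
  out = np.empty_like(x)
  num_programs = -(-length // block_size)  # ceiling division
  for pid in range(num_programs):          # one iteration per Triton program id
    block_start = pid * block_size
    offsets = block_start + np.arange(block_size)
    mask = offsets < length                # guards the final, possibly partial block
    out[offsets[mask]] = x[offsets[mask]] + y[offsets[mask]]
  return out

x = np.arange(8)
y = np.arange(8, 16)
print(add_simulated(x, y))  # [ 8 10 12 14 16 18 20 22]
```

Each loop iteration plays the role of one Triton program instance; the mask is what lets a fixed `block_size` safely cover lengths that are not an exact multiple of it.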

Then we can apply it to JAX arrays using `jax_triton.triton_call`:

```python
import jax
import jax.numpy as jnp
import jax_triton as jt

def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
  out_shape = jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype)
  block_size = 8
  return jt.triton_call(
      x,
      y,
      x.size,
      kernel=add_kernel,
      out_shape=out_shape,
      grid=(x.size // block_size,),
      block_size=block_size)

x_val = jnp.arange(8)
y_val = jnp.arange(8, 16)
print(add(x_val, y_val))
print(jax.jit(add)(x_val, y_val))
```
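Note that the grid above, `(x.size // block_size,)`, assumes `x.size` is an exact multiple of `block_size`. Because the kernel's mask already guards out-of-range offsets, the usual convention for arbitrary lengths is a ceiling division when sizing the grid (Triton ships this as `triton.cdiv`); a minimal sketch of the arithmetic:

```python
def cdiv(a: int, b: int) -> int:
  """Ceiling division: the number of blocks needed to cover `a` elements."""
  return (a + b - 1) // b

# 8 elements fit exactly in one block of 8; 10 elements need a second,
# partially-masked block.
print(cdiv(8, 8))   # 1
print(cdiv(10, 8))  # 2
```

With `grid=(cdiv(x.size, block_size),)`, the same kernel handles lengths that do not divide evenly.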

See the examples directory, especially `fused_attention.py` and the fused attention ipynb.

## Installation

```shell
$ pip install jax-triton
```

Make sure you have a CUDA-compatible jaxlib installed. For example, you could run:

```shell
$ pip install "jax[cuda11_cudnn82]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

## Installation at HEAD

JAX-Triton and Pallas are developed at JAX and Jaxlib HEAD, and close to Triton HEAD. To get a bleeding-edge installation of JAX-Triton, run:

```shell
$ pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git'
```

This should install compatible versions of JAX and Triton.
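To confirm which versions actually ended up in your environment, a small standard-library check can help. This is a purely illustrative sketch; the `describe` helper is ours, not part of jax-triton.

```python
from importlib.metadata import version, PackageNotFoundError

def describe(pkg: str) -> str:
  """Return the installed version of `pkg`, or 'not installed'."""
  try:
    return version(pkg)
  except PackageNotFoundError:
    return "not installed"

for pkg in ("jax", "jaxlib", "jax-triton", "triton"):
  print(pkg, describe(pkg))
```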

JAX-Triton does depend on Jaxlib, but Jaxlib is usually a more stable dependency. You might be able to get away with using a recent jaxlib release:

```shell
$ pip install jaxlib[cuda11_pip]
$ # or
$ pip install jaxlib[cuda12_pip]
```

If you find there are issues with the latest Jaxlib release, you can try using a Jaxlib nightly. To install a new jaxlib, find a link to a CUDA 11 nightly or a CUDA 12 nightly, then install it via:

```shell
$ pip install 'jaxlib @ <link to nightly>'
```

Alternatively, to install CUDA via pip automatically, you can do:

```shell
$ pip install 'jaxlib[cuda11_pip] @ <link to nightly>'
$ # or
$ pip install 'jaxlib[cuda12_pip] @ <link to nightly>'
```

## Development

To develop jax-triton, you can clone the repo with:

```shell
$ git clone https://github.com/jax-ml/jax-triton.git
```

and do an editable install with:

```shell
$ cd jax-triton
$ pip install -e .
```

To run the jax-triton tests, you'll need `pytest`:

```shell
$ pip install pytest
$ pytest tests/
```
