Contributing to JAX
Everyone can contribute to JAX, and we value everyone’s contributions. There are several ways to contribute, including:
Answering questions on JAX’s discussions page
Improving or expanding JAX’s documentation
Contributing to JAX’s code-base
Contributing in any of the above ways to the broader ecosystem of libraries built on JAX
The JAX project follows Google’s Open Source Community Guidelines.
Ways to contribute
We welcome pull requests, in particular for those issues marked with contributions welcome or good first issue.
For other proposals, we ask that you first open a GitHub Issue or Discussion to seek feedback on your planned contribution.
Can I contribute AI-generated code?
All submissions to Google Open Source projects need to follow Google’s Contributor License Agreement (CLA), in which contributors agree that their contribution is an original work of authorship. This doesn’t prohibit the use of coding assistance tools, but what’s submitted does need to be a contributor’s original creation.
In the JAX project, a main concern with AI-generated contributions is that low-quality AI-generated code imposes a disproportionate review cost. Since the team’s capacity for code review is limited, we have a higher bar for accepting AI-generated contributions compared to those written by a human.
A loose rule of thumb: if the team needs to spend more time reviewing a contribution than the contributor spends generating it, then the contribution is probably not helpful to the project, and we will likely reject it.
Contributing code using pull requests
We do all of our development using git, so basic knowledge of git is assumed.
Follow these steps to contribute code:
Sign the Google Contributor License Agreement (CLA). For more information, see the JAX pull request checklist below.
Fork the JAX repository by clicking the Fork button on the repository page. This creates a copy of the JAX repository in your own account.
Install Python >= 3.11 locally in order to run tests.
Install your fork from source with pip. This allows you to modify the code and immediately test it out:
git clone https://github.com/YOUR_USERNAME/jax
cd jax
pip install -r build/test-requirements.txt  # Installs all testing requirements.
pip install -e ".[cpu]"  # Installs JAX from the current directory in editable mode.
Add the JAX repo as an upstream remote, so you can use it to sync your changes.
git remote add upstream https://www.github.com/jax-ml/jax
Create a branch where you will develop from:
git checkout -b name-of-change
Implement your changes using your favorite editor (we recommend Visual Studio Code).
Make sure your code passes JAX’s lint and type checks by running the following from the top of the repository:
pip install pre-commit
pre-commit run --all-files
See Linting and type-checking for more details.
Make sure the tests pass by running the following command from the top of the repository:
pytest -n auto tests/
Run them in 64-bit mode as well, by setting the environment variable JAX_ENABLE_X64=True:
JAX_ENABLE_X64=True pytest -n auto tests/
JAX’s test suite is quite large, so if you know the specific test file that covers your changes, you can limit the tests to that; for example:
pytest -n auto tests/lax_scipy_test.py
You can narrow the tests further by using the pytest -k flag to match particular test names:
pytest -n auto tests/lax_scipy_test.py -k testLogSumExp
JAX also offers more fine-grained control over which particular tests are run; see Running the tests for more information.
Once you are satisfied with your change, create a commit as follows (how to write a commit message):
git add file1.py file2.py ...
git commit -m "Your commit message"
Then sync your code with the main repo:
git fetch upstream
git rebase upstream/main
Finally, push your commit on your development branch and create a remote branch in your fork that you can use to create a pull request from:
git push --set-upstream origin name-of-change
Please ensure your contribution is a single commit (see Single-change commits and pull requests).
Create a pull request from the JAX repository and send it for review. Check the JAX pull request checklist for considerations when preparing your PR, and consult GitHub Help if you need more information on using pull requests.
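The steps above run the tests twice: once in JAX’s default mode and once with JAX_ENABLE_X64=True. The following is a minimal sketch of a shell function that wraps both runs around any test command. run_in_both_modes is a hypothetical helper for illustration and is not a script that ships with JAX.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the JAX repository): run the given command
# once in JAX's default mode and once with JAX_ENABLE_X64=True, stopping at
# the first run that fails.
run_in_both_modes() {
  # Default mode: use a subshell so that any inherited JAX_ENABLE_X64 is
  # removed for this run only.
  ( unset JAX_ENABLE_X64; "$@" ) || return 1
  # 64-bit mode, matching the JAX_ENABLE_X64=True step above.
  JAX_ENABLE_X64=True "$@"
}

# Example, from the top of the repository:
#   run_in_both_modes pytest -n auto tests/lax_scipy_test.py -k testLogSumExp
```

Stopping after the first failing run keeps the output short. If the default-mode run already fails, the 64-bit run is unlikely to add much information.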
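You can try the branch, commit, and sync steps above without touching GitHub. The sketch below replays them against throwaway local repositories, assuming that a local bare repository can stand in for jax-ml/jax as the upstream remote. The paths, file names, and commit identities are made up, and the final push to your fork is left out.

```bash
#!/usr/bin/env bash
set -euo pipefail

# Throwaway sandbox: a local bare repository plays the role of the upstream
# jax-ml/jax repository so that the flow runs offline (demo assumption).
work="$(mktemp -d)"
git init -q --bare "$work/upstream.git"

# A "maintainer" clone seeds upstream/main and later moves it forward.
git init -q "$work/maintainer"
cd "$work/maintainer"
git symbolic-ref HEAD refs/heads/main
git config user.name "Demo Maintainer"
git config user.email "maintainer@example.com"
echo "v1" > core.py
git add core.py
git commit -q -m "Initial commit"
git remote add origin "$work/upstream.git"
git push -q origin main

# The contributor: add the upstream remote, branch off upstream/main, commit.
git init -q "$work/contributor"
cd "$work/contributor"
git config user.name "Demo Contributor"
git config user.email "contributor@example.com"
git remote add upstream "$work/upstream.git"
git fetch -q upstream
git checkout -q -b name-of-change upstream/main
echo "fix" > fix.py
git add fix.py
git commit -q -m "Add fix.py"

# Meanwhile, upstream/main gains a new commit.
cd "$work/maintainer"
echo "v2" > core.py
git commit -q -am "Update core.py"
git push -q origin main

# Sync as in the steps above: fetch upstream and replay the branch on top.
cd "$work/contributor"
git fetch -q upstream
git rebase -q upstream/main
git log --oneline --graph -3
```

After the rebase, the branch holds a single commit on top of the latest upstream/main. That is the state the push step expects.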
JAX pull request checklist
As you prepare a JAX pull request, here are a few things to keep in mind:
Google contributor license agreement
Contributions to this project must be accompanied by a Google Contributor License Agreement (CLA). You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to https://cla.developers.google.com/ to see your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you’ve already submitted one (even if it was for a different project), you probably don’t need to do it again. If you’re not certain whether you’ve signed a CLA, you can open your PR and our friendly CI bot will check for you.
Single-change commits and pull requests
A git commit ought to be a self-contained, single change with a descriptive message. This helps with review and with identifying or reverting changes if issues are uncovered later on.
Pull requests typically comprise a single git commit. (In some cases, for instance for large refactors or internal rewrites, they may contain several.) In preparing a pull request for review, you may need to squash together multiple commits. We ask that you do this prior to sending the PR for review if possible. The git rebase -i command might be useful to this end.
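git rebase -i is interactive. A non-interactive alternative for squashing every commit on a branch into one is git reset --soft followed by a single commit. The sketch below runs in a throwaway repository where main plays the role of upstream/main. In a JAX checkout you would use upstream/main after fetching. If the branch was already pushed, updating it afterwards needs a force push, for example git push --force-with-lease.

```bash
#!/usr/bin/env bash
set -euo pipefail

# Throwaway repository for the demo; "main" stands in for upstream/main.
repo="$(mktemp -d)"
cd "$repo"
git init -q
git symbolic-ref HEAD refs/heads/main
git config user.name "Demo Contributor"
git config user.email "demo@example.com"
echo "base" > base.py
git add base.py
git commit -q -m "Base commit"

# A development branch that has accumulated three work-in-progress commits.
git checkout -q -b name-of-change
for i in 1 2 3; do
  echo "step $i" >> change.py
  git add change.py
  git commit -q -m "WIP $i"
done

# Squash: move the branch back to where it diverged from main, keeping every
# change staged, then record the staged changes as one descriptive commit.
# Resetting to the merge base (rather than to main itself) avoids undoing
# commits that landed on main after the branch was created.
git reset -q --soft "$(git merge-base main HEAD)"
git commit -q -m "Add change.py with three steps"

git log --oneline main..HEAD   # exactly one commit remains on the branch
```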
Linting and type-checking
JAX uses mypy and ruff to statically test code quality; the easiest way to run these checks locally is via the pre-commit framework:
pip install pre-commit
pre-commit run --all-files
If your pull request touches documentation notebooks, this will also run some checks on those (see Update notebooks for more details).
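While iterating on a change, you may want to run the hooks only on the files your branch touches. pre-commit run accepts an explicit file list through its --files option. The helper below is a sketch that assumes the upstream remote from the contribution steps is configured. changed_files is a hypothetical name. Running with --all-files as above remains the thorough check before you send a PR.

```bash
#!/usr/bin/env bash
# Hypothetical helper: print the files changed on the current branch since it
# diverged from BASE (default: upstream/main), one per line. Deleted files are
# skipped because there is nothing left to lint.
changed_files() {
  git diff --name-only --diff-filter=d "${1:-upstream/main}...HEAD"
}

# Usage from the top of a JAX checkout with the upstream remote configured;
# the unquoted substitution assumes file names without spaces:
#   pre-commit run --files $(changed_files)
```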
Full GitHub test suite
Your PR will automatically be run through a full test suite on GitHub CI, which covers a range of Python versions, dependency versions, and configuration options. It’s normal for these tests to turn up failures that you didn’t catch locally; to fix the issues you can push new commits to your branch.
Restricted test suite
Once your PR has been reviewed, a JAX maintainer will mark it as pull ready. This will trigger a larger set of tests, including tests on GPU and TPU backends that are not available via standard GitHub CI. Detailed results of these tests are not publicly viewable, but the JAX maintainer assigned to your PR will communicate with you regarding any failures these might uncover; it’s not uncommon, for example, that numerical tests need different tolerances on TPU than on CPU.
Wheel sources update
If a new Python package or a new file is added to the wheel, one of the following Bazel targets should be updated:
A static source addition: add to the static_srcs list.
Example: add //:file.txt to the jax wheel.
wheel_sources(
    name = "jax_sources",
    data_srcs = [...],
    py_srcs = [...],
    static_srcs = [
        ...
        "//:file.txt",
    ],
)
A platform-dependent source addition: add to the data_srcs list.
Example: add a cc_library target //:cc_target to the jax wheel.
wheel_sources(
    name = "jax_sources",
    data_srcs = [
        ...
        "//:cc_target",
    ],
    py_srcs = [...],
    static_srcs = [...],
)
If the existing targets in data_srcs already have a transitive dependency on //:cc_target, you don’t need to add it explicitly.
A new Python package addition: create an __init__.py file and a Bazel Python rule target with __init__.py in its sources, then add that target to the py_srcs list.
Example: add a new package jax.test_package to the jax wheel.
The content of the file jax/test_package/BUILD:
pytype_strict_library(
    name = "init",
    srcs = ["__init__.py"],
    visibility = ["//visibility:public"],
)
Then, in wheel_sources:
wheel_sources(
    name = "jax_sources",
    data_srcs = [...],
    py_srcs = [
        ...
        "//jax/test_package:init",
    ],
    static_srcs = [...],
)
A new Python source addition to an existing package: create or update a Bazel Python rule target with the new file in its sources, then add that target to the py_srcs list.
Example: add a new file jax/test_package/example.py to the jax wheel.
The content of the file jax/test_package/BUILD:
pytype_strict_library(
    name = "example",
    srcs = [
        "__init__.py",
        "example.py",
    ],
    visibility = ["//visibility:public"],
)
Then, in wheel_sources:
wheel_sources(
    name = "jax_sources",
    data_srcs = [...],
    py_srcs = [
        ...
        "//jax/test_package:example",
    ],
    static_srcs = [...],
)
If the existing targets in py_srcs already have a transitive dependency on example.py, you don’t need to add it explicitly.
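Wheels are zip archives, so after you update wheel_sources and build a wheel, you can check that a newly added file was actually packaged. The sketch below assumes you already have a built wheel on disk; the build command is not covered here. wheel_has_file is a hypothetical helper that relies only on Python’s standard zipfile module.

```bash
#!/usr/bin/env bash
# Hypothetical helper: exit 0 if the wheel at $1 contains the archive member
# $2 (for example jax/test_package/example.py), and exit 1 otherwise.
# Wheels are zip archives, so Python's zipfile module can list their members.
wheel_has_file() {
  python3 - "$1" "$2" <<'EOF'
import sys
import zipfile

wheel_path, member = sys.argv[1], sys.argv[2]
with zipfile.ZipFile(wheel_path) as wheel:
    sys.exit(0 if member in wheel.namelist() else 1)
EOF
}

# Example (the wheel path is illustrative):
#   wheel_has_file path/to/jax-wheel.whl jax/test_package/example.py && echo "packaged"
```

The check matches the exact member path inside the archive, so pass the path as it appears in the wheel (package-relative, with forward slashes) rather than a path in your source tree.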
