
Contributing to JAX

Everyone can contribute to JAX, and we value everyone’s contributions. There are several ways to contribute, including:

The JAX project follows Google’s Open Source Community Guidelines.

Ways to contribute

We welcome pull requests, in particular for those issues marked with "contributions welcome" or "good first issue".

For other proposals, we ask that you first open a GitHub Issue or Discussion to seek feedback on your planned contribution.

Can I contribute AI-generated code?

All submissions to Google Open Source projects need to follow Google’s Contributor License Agreement (CLA), in which contributors agree that their contribution is an original work of authorship. This doesn’t prohibit the use of coding assistance tools, but what’s submitted does need to be a contributor’s original creation.

In the JAX project, a main concern with AI-generated contributions is that low-quality AI-generated code imposes a disproportionate review cost. Since the team’s capacity for code review is limited, we have a higher bar for accepting AI-generated contributions than for those written by a human.

A loose rule of thumb: if the team needs to spend more time reviewing a contribution than the contributor spends generating it, then the contribution is probably not helpful to the project, and we will likely reject it.

Contributing code using pull requests

We do all of our development using git, so basic knowledge of git is assumed.

Follow these steps to contribute code:

  1. Sign the Google Contributor License Agreement (CLA). For more information, see the JAX pull request checklist below.

  2. Fork the JAX repository by clicking the Fork button on the repository page. This creates a copy of the JAX repository in your own account.

  3. Install Python >= 3.11 locally in order to run tests.

  4. Install your fork from source with pip. This allows you to modify the code and immediately test it out:

    git clone https://github.com/YOUR_USERNAME/jax
    cd jax
    pip install -r build/test-requirements.txt  # Installs all testing requirements.
    pip install -e ".[cpu]"  # Installs JAX from the current directory in editable mode.

  5. Add the JAX repo as an upstream remote, so you can use it to sync your changes.

    git remote add upstream https://www.github.com/jax-ml/jax

  6. Create a branch to develop on:

    git checkout -b name-of-change

    Then implement your changes using your favorite editor (we recommend Visual Studio Code).

  7. Make sure your code passes JAX’s lint and type checks by running the following from the top of the repository:

    pip install pre-commit
    pre-commit run --all-files

    See Linting and type-checking for more details.

  8. Make sure the tests pass by running the following command from the top of the repository:

    pytest -n auto tests/

    Run them in 64-bit mode as well, by setting the environment variable JAX_ENABLE_X64=True:

    JAX_ENABLE_X64=True pytest -n auto tests/

    JAX’s test suite is quite large, so if you know the specific test file that covers your changes, you can limit the tests to that; for example:

    pytest -n auto tests/lax_scipy_test.py

    You can narrow the tests further by using the pytest -k flag to match particular test names:

    pytest -n auto tests/lax_scipy_test.py -k testLogSumExp

    JAX also offers more fine-grained control over which particular tests are run; see Running the tests for more information.

  9. Once you are satisfied with your change, create a commit as follows (see how to write a commit message):

    git add file1.py file2.py ...
    git commit -m "Your commit message"

    Then sync your code with the main repo:

    git fetch upstream
    git rebase upstream/main

    Finally, push your commit on your development branch and create a remote branch in your fork that you can use to create a pull request from:

    git push --set-upstream origin name-of-change

    Please ensure your contribution is a single commit (see Single-change commits and pull requests).

  10. Create a pull request from the JAX repository and send it for review. Check the JAX pull request checklist for considerations when preparing your PR, and consult GitHub Help if you need more information on using pull requests.
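
If you routinely run the same test selection in both default and 64-bit mode (step 8), a small script can do both runs in one go. The following is a minimal sketch, not part of the JAX repository: the function names pytest_command, mode_env and run_both_modes are hypothetical, and it assumes the pytest-xdist plugin, which provides the -n flag, is installed.

```python
import os
import subprocess
import sys


def pytest_command(path="tests/", keyword=None):
    """Build the pytest invocation from step 8 (hypothetical helper)."""
    # "-n auto" is a pytest-xdist option that distributes tests across
    # several worker processes.
    cmd = [sys.executable, "-m", "pytest", "-n", "auto", path]
    if keyword:
        # Narrow the run to tests whose names match the -k expression.
        cmd += ["-k", keyword]
    return cmd


def mode_env(enable_x64, base=None):
    """Copy an environment, setting or clearing JAX_ENABLE_X64."""
    env = dict(os.environ if base is None else base)
    if enable_x64:
        env["JAX_ENABLE_X64"] = "True"
    else:
        # Keep a stray setting in your shell out of the default-mode run.
        env.pop("JAX_ENABLE_X64", None)
    return env


def run_both_modes(path="tests/", keyword=None):
    """Run the selection in default mode, then in 64-bit mode.

    check=True raises CalledProcessError on the first failing run.
    """
    cmd = pytest_command(path, keyword)
    for enable_x64 in (False, True):
        print(f"Running with JAX_ENABLE_X64={enable_x64}")
        subprocess.run(cmd, env=mode_env(enable_x64), check=True)
```

For example, calling run_both_modes("tests/lax_scipy_test.py", "testLogSumExp") from the top of the repository runs the narrowed selection from step 8 twice. Invoking pytest through sys.executable -m keeps both runs in the same Python environment where you installed JAX in editable mode.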

JAX pull request checklist

As you prepare a JAX pull request, here are a few things to keep in mind:

Google contributor license agreement

Contributions to this project must be accompanied by a Google Contributor License Agreement (CLA). You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to https://cla.developers.google.com/ to see your current agreements on file or to sign a new one.

You generally only need to submit a CLA once, so if you’ve already submitted one (even if it was for a different project), you probably don’t need to do it again. If you’re not certain whether you’ve signed a CLA, you can open your PR and our friendly CI bot will check for you.

Single-change commits and pull requests

A git commit ought to be a self-contained, single change with a descriptive message. This helps with review and with identifying or reverting changes if issues are uncovered later on.

Pull requests typically comprise a single git commit. (In some cases, for instance for large refactors or internal rewrites, they may contain several.) In preparing a pull request for review, you may need to squash together multiple commits. We ask that you do this prior to sending the PR for review if possible. The git rebase -i command might be useful to this end.
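
Before sending a PR, you can check how many commits your branch adds on top of upstream/main. This is a minimal sketch using only the standard library and the git rev-list --count command; the helper names commits_ahead and squash_advice are hypothetical, and it assumes the upstream remote described in the setup steps has been fetched.

```python
import subprocess


def commits_ahead(base="upstream/main", repo="."):
    """Count commits reachable from HEAD but not from `base` (hypothetical helper)."""
    out = subprocess.run(
        ["git", "rev-list", "--count", f"{base}..HEAD"],
        cwd=repo,
        capture_output=True,
        text=True,
        check=True,
    ).stdout
    return int(out.strip())


def squash_advice(count, base="upstream/main"):
    """Turn a commit count into a short message for the contributor."""
    if count == 0:
        return f"No commits ahead of {base}; nothing to send yet."
    if count == 1:
        return "Single commit: ready for review."
    # Several commits can be fine for large refactors, so this only suggests.
    return (
        f"{count} commits ahead of {base}; unless this is a large refactor, "
        f"consider squashing with: git rebase -i {base}"
    )
```

For example, print(squash_advice(commits_ahead())) from inside your checkout. The git call and the message are kept in separate functions so the advice logic can be checked without a repository.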

Linting and type-checking

JAX uses mypy and ruff to statically test code quality; the easiest way to run these checks locally is via the pre-commit framework:

pip install pre-commit
pre-commit run --all-files

If your pull request touches documentation notebooks, this will also run some checks on those (see Update notebooks for more details).

Full GitHub test suite

Your PR will automatically be run through a full test suite on GitHub CI, which covers a range of Python versions, dependency versions, and configuration options. It’s normal for these tests to turn up failures that you didn’t catch locally; to fix the issues you can push new commits to your branch.

Restricted test suite

Once your PR has been reviewed, a JAX maintainer will mark it as "pull ready". This will trigger a larger set of tests, including tests on GPU and TPU backends that are not available via standard GitHub CI. Detailed results of these tests are not publicly viewable, but the JAX maintainer assigned to your PR will communicate with you regarding any failures these might uncover; it’s not uncommon, for example, that numerical tests need different tolerances on TPU than on CPU.

Wheel sources update

If a new Python package or a new file is added to the wheel, one of the following Bazel targets should be updated:

jax wheel sources

jaxlib wheel sources

jax CUDA plugin wheel sources

jax CUDA pjrt wheel sources

  1. A static source addition: add it to the static_srcs list.

    Example: add //:file.txt to the jax wheel.

    wheel_sources(
        name = "jax_sources",
        data_srcs = [...],
        py_srcs = [...],
        static_srcs = [
            ...
            "//:file.txt",
        ],
    )

  2. A platform-dependent source addition: add it to the data_srcs list.

    Example: add a cc_library target //:cc_target to the jax wheel.

    wheel_sources(
        name = "jax_sources",
        data_srcs = [
            ...
            "//:cc_target",
        ],
        py_srcs = [...],
        static_srcs = [...],
    )

    If the existing targets in data_srcs already have a transitive dependency on //:cc_target, you don’t need to add it explicitly.

  3. A new Python package addition: create an __init__.py file and a Bazel Python rule target with __init__.py in its sources, then add the target to the py_srcs list.

    Example: add a new package jax.test_package to the jax wheel:

    The content of the file jax/test_package/BUILD:

    pytype_strict_library(
        name = "init",
        srcs = ["__init__.py"],
        visibility = ["//visibility:public"],
    )

    The corresponding change to the jax wheel sources target:

    wheel_sources(
        name = "jax_sources",
        data_srcs = [...],
        py_srcs = [
            ...
            "//jax/test_package:init",
        ],
        static_srcs = [...],
    )

  4. A new Python source addition to an existing package: create or update a Bazel Python rule target with the new file in its sources, then add the target to the py_srcs list.

    Example: add a new file jax/test_package/example.py to the jax wheel:

    The content of the file jax/test_package/BUILD:

    pytype_strict_library(
        name = "example",
        srcs = [
            "__init__.py",
            "example.py",
        ],
        visibility = ["//visibility:public"],
    )

    The corresponding change to the jax wheel sources target:

    wheel_sources(
        name = "jax_sources",
        data_srcs = [...],
        py_srcs = [
            ...
            "//jax/test_package:example",
        ],
        static_srcs = [...],
    )

    If the existing targets in py_srcs already have a transitive dependency on example.py, you don’t need to add it explicitly.
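
The four wheel-source cases above amount to a small lookup from the kind of addition to the wheel_sources list it belongs in, plus the transitive-dependency exemption for cases 2 and 4. The sketch below encodes that lookup; the function wheel_sources_entry and its kind strings are hypothetical and are not part of JAX's Bazel tooling.

```python
# Hypothetical mapping of the four documented cases to wheel_sources lists.
WHEEL_LIST_FOR_KIND = {
    "static_file": "static_srcs",       # case 1, e.g. //:file.txt
    "cc_library": "data_srcs",          # case 2, platform-dependent target
    "new_python_package": "py_srcs",    # case 3, __init__.py plus a Python rule target
    "new_python_module": "py_srcs",     # case 4, new file in an existing package
}

# Cases where the guide says an existing transitive dependency is enough.
TRANSITIVE_OK = {"cc_library", "new_python_module"}


def wheel_sources_entry(kind, label, already_transitive=False):
    """Return (list_name, label) to add to wheel_sources, or None if not needed."""
    if kind not in WHEEL_LIST_FOR_KIND:
        raise ValueError(f"unknown kind: {kind!r}")
    if already_transitive and kind in TRANSITIVE_OK:
        return None
    return WHEEL_LIST_FOR_KIND[kind], label
```

For example, wheel_sources_entry("cc_library", "//:cc_target") returns ("data_srcs", "//:cc_target"), telling you which list in the relevant wheel sources target to extend.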

