Installation

Using JAX requires installing two packages: jax, which is pure Python and cross-platform, and jaxlib, which contains compiled binaries and requires different builds for different operating systems and accelerators.

Summary: For most users, a typical JAX installation may look something like this:

  • CPU-only (Linux/macOS/Windows)

    pip install -U jax
  • GPU (NVIDIA, CUDA 13)

    pip install -U "jax[cuda13]"
  • TPU (Google Cloud TPU VM)

    pip install -U "jax[tpu]"

Supported platforms

The table below shows all supported platforms and installation options. Check whether your setup is supported; if it says “yes” or “experimental”, refer to the corresponding section below to learn how to install JAX in greater detail.

| Accelerator | Linux, x86_64 | Linux, aarch64 | Mac, aarch64 | Windows, x86_64 | Windows WSL2, x86_64 |
|---|---|---|---|---|---|
| CPU | yes | yes | yes | yes | yes |
| NVIDIA GPU | yes | yes | n/a | no | experimental |
| Google Cloud TPU | yes | n/a | n/a | n/a | n/a |
| AMD GPU | yes | no | n/a | no | experimental |
| Apple GPU | n/a | no | experimental | n/a | n/a |
| Intel GPU | experimental | n/a | n/a | no | no |
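For scripted environment checks, the table can be encoded as a lookup. The sketch below is a minimal Python illustration; the names PLATFORMS, SUPPORT, support_status and installable are hypothetical and not part of JAX, and the values are copied from the table above, so they may drift as support changes.

```python
# Hypothetical encoding of the supported-platforms table (not a JAX API).
PLATFORMS = (
    "Linux, x86_64",
    "Linux, aarch64",
    "Mac, aarch64",
    "Windows, x86_64",
    "Windows WSL2, x86_64",
)

# One row per accelerator; columns follow the order of PLATFORMS.
SUPPORT = {
    "CPU": ("yes", "yes", "yes", "yes", "yes"),
    "NVIDIA GPU": ("yes", "yes", "n/a", "no", "experimental"),
    "Google Cloud TPU": ("yes", "n/a", "n/a", "n/a", "n/a"),
    "AMD GPU": ("yes", "no", "n/a", "no", "experimental"),
    "Apple GPU": ("n/a", "no", "experimental", "n/a", "n/a"),
    "Intel GPU": ("experimental", "n/a", "n/a", "no", "no"),
}


def support_status(accelerator: str, platform: str) -> str:
    """Return the table cell: 'yes', 'experimental', 'no', or 'n/a'."""
    return SUPPORT[accelerator][PLATFORMS.index(platform)]


def installable(accelerator: str, platform: str) -> bool:
    """True when the table lists a documented (possibly experimental) install path."""
    return support_status(accelerator, platform) in ("yes", "experimental")


if __name__ == "__main__":
    print(support_status("NVIDIA GPU", "Windows WSL2, x86_64"))
```

Keeping the platforms as a tuple and each row as a tuple in the same order mirrors the table layout directly, which makes it easy to compare the code against the table by eye.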

CPU

pip installation: CPU

Currently, the JAX team releases jaxlib wheels for the following operating systems and architectures:

  • Linux, x86_64

  • Linux, aarch64

  • macOS, Apple ARM-based

  • Windows, x86_64 (experimental)

To install a CPU-only version of JAX, which might be useful for doing local development on a laptop, you can run:

pip install --upgrade pip
pip install --upgrade jax

On Windows, you may also need to install the Microsoft Visual Studio 2019 Redistributable if it is not already installed on your machine.

Other operating systems and architectures require building from source. Trying to pip install on other operating systems and architectures may lead to jaxlib not being installed alongside jax, although jax may successfully install (but fail at runtime).
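One way to tell whether a machine falls into one of the prebuilt-wheel combinations listed above is to compare platform.system() and platform.machine(). This is a minimal sketch; the mapping from those strings to the listed platforms ("Darwin"/"arm64" for Apple ARM-based macOS, "Windows"/"AMD64" for Windows x86_64) is an assumption based on what CPython typically reports, and has_prebuilt_jaxlib is a hypothetical helper.

```python
import platform
from typing import Optional

# (platform.system(), platform.machine()) pairs assumed to match the list above.
PREBUILT_CPU_WHEELS = {
    ("Linux", "x86_64"),
    ("Linux", "aarch64"),
    ("Darwin", "arm64"),    # macOS, Apple ARM-based
    ("Windows", "AMD64"),   # Windows, x86_64 (experimental)
}


def has_prebuilt_jaxlib(system: Optional[str] = None, machine: Optional[str] = None) -> bool:
    """Return True if (system, machine) is one of the prebuilt jaxlib CPU targets.

    Arguments default to the values reported for the running interpreter.
    """
    if system is None:
        system = platform.system()
    if machine is None:
        machine = platform.machine()
    return (system, machine) in PREBUILT_CPU_WHEELS


if __name__ == "__main__":
    if not has_prebuilt_jaxlib():
        print("No prebuilt jaxlib wheel expected here; consider building from source.")
```

Accepting explicit arguments keeps the check testable on any host, while the defaults describe the machine the script runs on.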

NVIDIA GPU

On CUDA 12, JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer. Note that Kepler-series GPUs are no longer supported by JAX, since NVIDIA has dropped support for Kepler GPUs in its software. On CUDA 13, JAX supports NVIDIA GPUs that have SM version 7.5 or newer, because NVIDIA dropped support for older GPUs in CUDA 13.

You must first install the NVIDIA driver. We recommend installing the newest driver available from NVIDIA, but the driver version must be >= 525 for CUDA 12 on Linux, and >= 580 for CUDA 13 on Linux.

If you need to use a newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be able to use the CUDA forward compatibility packages that NVIDIA provides for this purpose.
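The minimum requirements above can be summarized as a small check. This is a hedged sketch: gpu_requirements_met is a hypothetical helper, the thresholds are the Linux values quoted above, the SM version and driver version are supplied by you (for example, as read from NVIDIA's own tools), and the sketch ignores the forward-compatibility route.

```python
from typing import Tuple

# Minimums quoted in the text, keyed by CUDA major version (Linux).
MIN_SM = {12: (5, 2), 13: (7, 5)}   # compute capability as (major, minor)
MIN_DRIVER = {12: 525, 13: 580}     # leading number of the driver version


def gpu_requirements_met(cuda_major: int, sm: Tuple[int, int], driver_version: str) -> bool:
    """Check an SM version and a dotted driver version string.

    Only the leading number of driver_version is compared.
    """
    driver_branch = int(driver_version.split(".")[0])
    # Tuples compare element by element, so (7, 5) >= (5, 2) behaves like 7.5 >= 5.2.
    return sm >= MIN_SM[cuda_major] and driver_branch >= MIN_DRIVER[cuda_major]
```

For example, an SM 6.1 GPU passes the CUDA 12 check with a 525-series driver, but fails the CUDA 13 check regardless of driver, because 6.1 is below the 7.5 minimum.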

pip installation: NVIDIA GPU (CUDA, installed via pip, easier)

There are two ways to install JAX with NVIDIA GPU support:

  • Using NVIDIA CUDA and cuDNN installed from pip wheels

  • Using a self-installed CUDA/cuDNN

The JAX team strongly recommends installing CUDA and cuDNN using the pip wheels, since it is much easier!

NVIDIA has released CUDA packages only for x86_64 and aarch64.

pip install --upgrade pip

# NVIDIA CUDA 13 installation
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda13]"

# Alternatively, for CUDA 12, use
# pip install --upgrade "jax[cuda12]"

We recommend migrating to the CUDA 13 wheels; at some point in the future we will drop CUDA 12 support.

If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things you need to check:

  • Make sure that LD_LIBRARY_PATH is not set, since LD_LIBRARY_PATH can override the NVIDIA CUDA libraries.

  • Make sure that the NVIDIA CUDA libraries installed are those requested by JAX. Rerunning the installation command above should work.
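The first of these checks can be scripted. A minimal sketch, assuming you want a warning whenever the variable is present at all (even if empty); ld_library_path_warning is a hypothetical helper, not part of JAX.

```python
import os
from typing import Mapping, Optional


def ld_library_path_warning(environ: Optional[Mapping[str, str]] = None) -> Optional[str]:
    """Return a warning if LD_LIBRARY_PATH is set, else None.

    Defaults to the current process environment.
    """
    env = os.environ if environ is None else environ
    if "LD_LIBRARY_PATH" in env:
        return (
            "LD_LIBRARY_PATH is set to %r; it can override the CUDA libraries "
            "installed from pip wheels." % env["LD_LIBRARY_PATH"]
        )
    return None


if __name__ == "__main__":
    message = ld_library_path_warning()
    print(message or "LD_LIBRARY_PATH is not set.")
```

Passing a mapping instead of always reading os.environ lets the same function check an environment you plan to launch JAX with, such as one built for a subprocess.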

pip installation: NVIDIA GPU (CUDA, installed locally, harder)

If you prefer to use a preinstalled copy of NVIDIA CUDA, you must first install NVIDIA CUDA and cuDNN.

JAX provides pre-built CUDA-compatible wheels for Linux x86_64 and Linux aarch64 only. Other combinations of operating system and architecture are possible, but require building from source (refer to Building from source to learn more).

You should use an NVIDIA driver version that is at least as new as your NVIDIA CUDA toolkit’s corresponding driver version. If you need to use a newer CUDA toolkit with an older driver, for example on a cluster where you cannot update the NVIDIA driver easily, you may be able to use the CUDA forward compatibility packages that NVIDIA provides for this purpose.

JAX currently ships two CUDA wheel variants, CUDA 12 and CUDA 13:

The CUDA 12 wheel is:

| Component | Built with | Compatible with |
|---|---|---|
| CUDA | 12.3 | >=12.1 |
| cuDNN | 9.8 | >=9.8, <10.0 |
| NCCL | 2.19 | >=2.18 |

The CUDA 13 wheel is:

| Component | Built with | Compatible with |
|---|---|---|
| CUDA | 13.0 | >=13.0 |
| cuDNN | 9.8 | >=9.8, <10.0 |
| NCCL | 2.19 | >=2.18 |

JAX checks the versions of your libraries, and will report an error if they are not sufficiently new. Setting the JAX_SKIP_CUDA_CONSTRAINTS_CHECK environment variable will disable the check, but using older versions of CUDA may lead to errors or incorrect results.

NCCL is an optional dependency, required only if you are performing multi-GPU computations.
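To show what the "Compatible with" columns mean in practice, here is a minimal sketch that checks locally installed versions against them. It is not JAX's own check (that one runs inside JAX and is the one JAX_SKIP_CUDA_CONSTRAINTS_CHECK disables); check_local_cuda and the version dictionary you pass in are hypothetical, and the sketch treats a missing NCCL as acceptable because NCCL is optional.

```python
from typing import Dict, List, Optional, Tuple

Version = Tuple[int, ...]

# "Compatible with" ranges from the tables above:
# (inclusive lower bound, exclusive upper bound or None).
CONSTRAINTS: Dict[int, Dict[str, Tuple[Version, Optional[Version]]]] = {
    12: {"cuda": ((12, 1), None), "cudnn": ((9, 8), (10, 0)), "nccl": ((2, 18), None)},
    13: {"cuda": ((13, 0), None), "cudnn": ((9, 8), (10, 0)), "nccl": ((2, 18), None)},
}

OPTIONAL = {"nccl"}  # only needed for multi-GPU computations


def parse_version(text: str) -> Version:
    """Turn '9.8.0' into (9, 8, 0)."""
    return tuple(int(part) for part in text.split("."))


def check_local_cuda(wheel: int, installed: Dict[str, str]) -> List[str]:
    """Return a list of problems; an empty list means every constraint holds."""
    problems = []
    for lib, (low, high) in CONSTRAINTS[wheel].items():
        found = installed.get(lib)
        if found is None:
            if lib not in OPTIONAL:
                problems.append(f"{lib}: not found")
            continue
        version = parse_version(found)
        # Tuple comparison: (12, 4) >= (12, 1), and (10, 0, 0) >= (10, 0).
        if version < low or (high is not None and version >= high):
            problems.append(f"{lib} {found}: outside the supported range")
    return problems
```

For example, check_local_cuda(12, {"cuda": "12.0", "cudnn": "9.8.0"}) reports CUDA 12.0 as below the 12.1 minimum, while cuDNN 9.8.0 passes and the absent NCCL is ignored.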

To install, run:

pip install --upgrade pip

# Installs the wheel compatible with NVIDIA CUDA 13 and cuDNN 9.8 or newer.
# Note: wheels only available on linux.
pip install --upgrade "jax[cuda13-local]"

# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.8 or newer.
# Note: wheels only available on linux.
# pip install --upgrade "jax[cuda12-local]"

These pip installations do not work with Windows, and may fail silently; refer to the table above.

You can find your CUDA version with the command:

nvcc --version

JAX uses LD_LIBRARY_PATH to find CUDA libraries and PATH to find binaries (ptxas, nvlink). Please make sure that these paths point to the correct CUDA installation.
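To see which ptxas and nvlink a PATH lookup resolves to, the standard library's shutil.which is enough. A minimal sketch; find_cuda_binaries is a hypothetical helper, and it only resolves PATH (it does not inspect LD_LIBRARY_PATH).

```python
import shutil
from typing import Dict, Optional


def find_cuda_binaries(path: Optional[str] = None) -> Dict[str, Optional[str]]:
    """Resolve ptxas and nvlink the way a PATH lookup would.

    `path` overrides the PATH string to search; None means the current PATH.
    Missing binaries map to None.
    """
    return {name: shutil.which(name, path=path) for name in ("ptxas", "nvlink")}


if __name__ == "__main__":
    for name, location in find_cuda_binaries().items():
        print(f"{name}: {location or 'NOT FOUND on PATH'}")
```

If the printed locations belong to a different CUDA installation than the one you intend JAX to use, adjust PATH before starting Python.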

JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package. Make sure that it is present in your CUDA installation.
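A quick way to confirm the file is present is to search the CUDA installation for it. In CUDA toolkits the file is commonly named libdevice.10.bc and sits under nvvm/libdevice/, so this sketch globs for libdevice*.bc to match either spelling. The default root /usr/local/cuda and the helper find_libdevice are assumptions for illustration.

```python
from pathlib import Path
from typing import List


def find_libdevice(cuda_root: str = "/usr/local/cuda") -> List[Path]:
    """List libdevice bitcode files under cuda_root, searched recursively.

    The default root is an assumption; pass your own CUDA installation path.
    The pattern matches both 'libdevice10.bc' and 'libdevice.10.bc'.
    """
    root = Path(cuda_root)
    if not root.is_dir():
        return []
    return sorted(root.rglob("libdevice*.bc"))


if __name__ == "__main__":
    found = find_libdevice()
    print(found or "libdevice not found; check that cuda-nvvm is installed.")
```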

Please let the JAX team know on the GitHub issue tracker if you run into any errors or problems with the pre-built wheels.

NVIDIA GPU Docker containers

NVIDIA provides the JAX Toolbox containers, which are bleeding-edge containers containing nightly releases of jax and some models/frameworks.

Google Cloud TPU

pip installation: Google Cloud TPU

JAX provides pre-built wheels for Google Cloud TPU. To install JAX along with appropriate versions of jaxlib and libtpu, you can run the following in your cloud TPU VM:

pip install "jax[tpu]"

For users of Colab (https://colab.research.google.com/), be sure you are using TPU v2 and not the older, deprecated TPU runtime.

Mac GPU

JAX is not supported on Mac/OSX GPU; instead, use the standard CPU installation commands.

AMD GPU (Linux)

AMD GPU support is provided by a ROCm JAX plugin supported by AMD.

There are several ways to use JAX on AMDGPU devices. Please see AMD’s instructions for details.

Note: ROCm support on Windows WSL2 is experimental. For WSL installation, you may need to:

  1. Install ROCm for WSL following AMD’s official guide

  2. Follow the standard Linux ROCm JAX installation steps within your WSL environment

  3. Be aware that performance and stability may differ from native Linux installations

Intel GPU

Intel provides an experimental OneAPI plugin, intel-extension-for-openxla, for Intel GPU hardware. For more details and installation instructions, refer to one of the following two methods:

  1. Pip installation: JAX acceleration on Intel GPU.

  2. Using Intel’s XLA Docker container.

Please report any issues related to:

Conda (community-supported)

Conda installation

There is a community-supported Conda build of jax. To install it using conda, simply run:

conda install jax -c conda-forge

If you run this command on a machine with an NVIDIA GPU, this should install a CUDA-enabled package of jaxlib.

To ensure that the jax version you are installing is indeed CUDA-enabled, run:

conda install "jaxlib=*=*cuda*" jax -c conda-forge
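The quoted argument is a conda match specification of the form name=version=build: any version of jaxlib, restricted to builds whose build string contains "cuda". As a hedged illustration of how that wildcard behaves, Python's fnmatchcase applies a similar * glob; matches_cuda_build is a hypothetical helper and the build strings below are invented examples, not real conda-forge builds.

```python
from fnmatch import fnmatchcase


def matches_cuda_build(build_string: str, pattern: str = "*cuda*") -> bool:
    """Glob-match a build string, like the '*cuda*' part of the spec."""
    # fnmatchcase is case-sensitive on every OS, unlike fnmatch.
    return fnmatchcase(build_string, pattern)


# Hypothetical build strings, for illustration only.
for build in ("cuda120py311h0123abc_200", "cpu_py311h0123abc_0"):
    print(build, matches_cuda_build(build))
```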

If you would like to override which release of CUDA is used by JAX, or to install the CUDA build on a machine without GPUs, follow the instructions in the Tips & tricks section of the conda-forge website.

See the conda-forge jaxlib and jax repositories for more details.

JAX nightly installation

Nightly releases reflect the state of the main JAX repository at the time they are built, and may not pass the full test suite.

Unlike the instructions for installing a JAX release, here we name all of JAX’s packages explicitly on the command line, so pip will upgrade them if a newer version is available.

JAX publishes nightlies, release candidates (RCs), and releases to several non-PyPI PEP 503 indexes.

All JAX packages can be reached from the index https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/, which also mirrors packages from PyPI. This additional mirroring enables nightly installation to use --index-url (-i) as the install method with pip.

Note: Immediately after a release, before the newest nightly is rebuilt, the unified index could return an RC or release as the newest version even with --pre. If automation or testing must be done against nightlies, or you cannot use our full index, use the extra index https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/, which only contains nightly artifacts.

The nightly index URLs can also be browsed directly. The URL passed to -i is a PEP 503 simple repository index for pip, and each package has its own sub-directory. For example, you can see the available jax packages at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax, jax-cuda12-pjrt packages at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax-cuda12-pjrt, and jax-cuda13-pjrt packages at https://us-python.pkg.dev/ml-oss-artifacts-published/jax-public-nightly-artifacts-registry/simple/jax-cuda13-pjrt.

  • CPU only:

    pip install -U --pre jax jaxlib -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
  • Google Cloud TPU:

    pip install -U --pre jax jaxlib libtpu requests -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/ -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  • NVIDIA GPU (CUDA 13):

    pip install -U --pre jax jaxlib "jax-cuda13-plugin[with-cuda]" jax-cuda13-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
  • NVIDIA GPU (CUDA 12):

    pip install -U --pre jax jaxlib "jax-cuda12-plugin[with-cuda]" jax-cuda12-pjrt -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/
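If you script nightly installs (for example in CI), the four commands above differ only in the package list and, for TPU, an extra -f link. The sketch below assembles them as argument lists; NIGHTLY_PACKAGES and nightly_install_args are hypothetical names, and the URLs and package names are copied from the commands above.

```python
import shlex
from typing import List

# Index and find-links URLs copied from the commands above.
UNIFIED_INDEX = "https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/"
LIBTPU_LINKS = "https://storage.googleapis.com/jax-releases/libtpu_releases.html"

# Packages named explicitly for each target, as in the commands above.
NIGHTLY_PACKAGES = {
    "cpu": ["jax", "jaxlib"],
    "tpu": ["jax", "jaxlib", "libtpu", "requests"],
    "cuda13": ["jax", "jaxlib", "jax-cuda13-plugin[with-cuda]", "jax-cuda13-pjrt"],
    "cuda12": ["jax", "jaxlib", "jax-cuda12-plugin[with-cuda]", "jax-cuda12-pjrt"],
}


def nightly_install_args(target: str) -> List[str]:
    """Build the pip argument list for a nightly install of `target`."""
    args = ["pip", "install", "-U", "--pre", *NIGHTLY_PACKAGES[target], "-i", UNIFIED_INDEX]
    if target == "tpu":
        # -f adds the libtpu release page as an extra place to find packages.
        args += ["-f", LIBTPU_LINKS]
    return args


if __name__ == "__main__":
    # shlex.join quotes arguments such as "jax-cuda13-plugin[with-cuda]" for the shell.
    print(shlex.join(nightly_install_args("cuda13")))
```

Building a list rather than one string avoids shell-quoting mistakes with the bracketed extras. The list can be passed to subprocess.run, ideally with [sys.executable, "-m", "pip"] in place of "pip" so the install targets the interpreter running the script.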

Building JAX from source

Refer to Building from source.

Installing older jaxlib wheels

Due to storage limitations on the Python package index, the JAX team periodically removes older jaxlib wheels from the releases on http://pypi.org/project/jax. These can still be installed directly via the URLs here. For example:

# Install jaxlib on CPU via the wheel archive
pip install "jax[cpu]==0.3.25" -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -i https://us-python.pkg.dev/ml-oss-artifacts-published/jax/simple/

For specific older GPU wheels, be sure to use the jax_cuda_releases.html URL; for example:

pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
