
cloud_tpu_colabs

The same JAX code that runs on CPU and GPU can also be run on TPU. Cloud TPUs have the advantage of quickly giving you access to multiple TPU accelerators, including in Colab. All of the example notebooks here use jax.pmap to run JAX computation across multiple TPU cores from Colab. You can also run the same code directly on a Cloud TPU VM.
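As a quick orientation (this sketch is not taken from the notebooks themselves), the pattern looks roughly like the following; the array shapes and the mapped function are placeholders:

```python
import jax
import jax.numpy as jnp

n_devices = jax.device_count()  # typically 8 cores on a Cloud TPU v2/v3 slice

# pmap maps the function over the leading axis, one slice per device.
xs = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)

parallel_sum_of_squares = jax.pmap(lambda x: jnp.sum(x ** 2))
print(parallel_sum_of_squares(xs))  # one partial result per core
```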

Example Cloud TPU notebooks

The following notebooks showcase how to use Cloud TPUs on Colab and what you can do with them:

A guide to getting started with pmap, a transform for easily distributing SPMD computations across devices.

Contributed by Alex Alemi (alexalemi@)

Solve and plot parallel ODE solutions with pmap.

Contributed by Stephan Hoyer (shoyer@)

Solve the wave equation with pmap, and make cool movies! The spatial domain is partitioned across the 8 cores of a Cloud TPU.

An overview of JAX presented at the Program Transformations for ML workshop at NeurIPS 2019 and the Compilers for ML workshop at CGO 2020. Covers basic numpy usage, grad, jit, vmap, and pmap.

Performance notes

The guidance on running TensorFlow on TPUs applies to JAX as well, with the exception of TensorFlow-specific details. Here we highlight a few important details that are particularly relevant to using TPUs in JAX.

Padding

One of the most common culprits for surprisingly slow code on TPUs is inadvertent padding:

  • Arrays in the Cloud TPU are tiled. This entails padding one of the dimensions to a multiple of 8, and a different dimension to a multiple of 128.
  • The matrix multiplication unit performs best with pairs of large matrices that minimize the need for padding.
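As a rough illustration of the tiling rule above (the shapes here are placeholders, and which dimension gets which multiple can depend on dtype and layout), dimensions that are already multiples of 8 and 128 avoid padded tiles entirely:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)

# (1000, 1000) operands would be padded up to the next tile boundary
# (e.g. 1000 -> 1024 along the 128-tiled dimension), spending MXU cycles on zeros.
# (1024, 1024) operands are already tile-aligned, so no padding is needed.
a = jax.random.normal(k1, (1024, 1024))
b = jax.random.normal(k2, (1024, 1024))

c = jnp.dot(a, b)  # tile-aligned shapes keep the matrix unit fully utilized
```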

bfloat16 dtype

By default*, matrix multiplication in JAX on TPUs uses bfloat16 with float32 accumulation. This can be controlled with the precision keyword argument on relevant jax.numpy functions (matmul, dot, einsum, etc.). In particular:

  • precision=jax.lax.Precision.DEFAULT: uses mixed bfloat16 precision (fastest)
  • precision=jax.lax.Precision.HIGH: uses multiple MXU passes to achieve higher precision
  • precision=jax.lax.Precision.HIGHEST: uses even more MXU passes to achieve full float32 precision
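A small sketch of what this looks like in practice (the arrays here are placeholders; precision is the keyword argument described above):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((512, 512))
y = jnp.ones((512, 512))

fast = jnp.dot(x, y)  # default: mixed bfloat16 precision on TPU
exact = jnp.dot(x, y, precision=jax.lax.Precision.HIGHEST)  # full float32
```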

JAX also adds the bfloat16 dtype, which you can use to explicitly cast arrays to bfloat16, e.g., jax.numpy.array(x, dtype=jax.numpy.bfloat16).
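For example (a minimal sketch; note that operations on bfloat16 inputs also produce bfloat16 outputs):

```python
import jax.numpy as jnp

x = jnp.ones((128, 128))
x_bf16 = jnp.array(x, dtype=jnp.bfloat16)  # explicit cast to bfloat16

print(x_bf16.dtype)                    # bfloat16
print(jnp.dot(x_bf16, x_bf16).dtype)   # bfloat16
```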

* We might change the default precision in the future, since it is arguably surprising. Please comment/vote on this issue if it affects you!

Running JAX on a Cloud TPU VM

Refer to the Cloud TPU VM documentation.

Reporting issues and getting help

If you run into Cloud TPU-specific issues (e.g. trouble creating a Cloud TPU VM), please email cloud-tpu-support@google.com, or trc-support@google.com if you are a TRC member. You can also file a JAX issue or ask a discussion question for any issues with these notebooks or using JAX in general.

If you have any other questions or comments regarding JAX on Cloud TPUs, please email jax-cloud-tpu-team@google.com. We'd like to hear from you!

