
jax.typing module

The JAX typing module is where JAX-specific static type annotations live. This submodule is a work in progress; to see the proposal behind the types exported here, see https://docs.jax.dev/en/latest/jep/12049-type-annotations.html.

The currently-available types are:

  • jax.Array: annotation for any JAX array or tracer (i.e. representations of arrays within JAX transforms).

  • jax.typing.ArrayLike: annotation for any value that is safe to implicitly cast to a JAX array; this includes jax.Array and numpy.ndarray, as well as Python builtin numeric values (e.g. int, float) and NumPy scalar values (e.g. numpy.int32, numpy.float64).

  • jax.typing.DTypeLike: annotation for any value that can be cast to a JAX-compatible dtype; this includes strings (e.g. 'float32', 'int32'), scalar types (e.g. float, np.float32), dtypes (e.g. np.dtype('float32')), and objects with a dtype attribute (e.g. jnp.float32, jnp.int32).

We may add additional types here in future releases.
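As an illustration of the three types above, the following minimal sketch annotates a small function with ArrayLike, DTypeLike, and Array, then calls it with several kinds of input. The helper names scale and traced_double, and the list seen_types, are hypothetical and exist only for this example; the sketch assumes a recent JAX release with default (32-bit) settings.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike, DTypeLike


def scale(x: ArrayLike, factor: ArrayLike, dtype: DTypeLike = "float32") -> Array:
    """Hypothetical helper: cast an array-like input to `dtype`, then multiply."""
    # jnp.asarray performs the implicit cast that ArrayLike describes.
    return jnp.asarray(x, dtype=dtype) * factor


# Python scalars, NumPy scalars, NumPy arrays and JAX arrays all qualify as ArrayLike.
inputs = [2, 2.5, np.float32(2.0), np.arange(3), jnp.ones(3)]
outputs = [scale(value, 2) for value in inputs]

# The same dtype can be spelled as a string, a scalar type, an np.dtype,
# or an object with a dtype attribute such as jnp.int32.
int_spellings = ["int32", np.int32, np.dtype("int32"), jnp.int32]
int_dtypes = [scale(np.arange(3), 2, dtype=dt).dtype for dt in int_spellings]

# Inside a transform the argument is a tracer, which is still an instance of jax.Array.
seen_types = []


@jax.jit
def traced_double(x: Array) -> Array:
    seen_types.append(isinstance(x, Array))  # runs once, while tracing
    return x * 2


traced_double(jnp.ones(3))
```

Every element of outputs is a jax.Array of dtype float32, every entry of int_dtypes is int32, and the isinstance check performed during tracing succeeds, which is why a single Array annotation covers both concrete arrays and tracers.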

JAX Typing Best Practices

When annotating JAX arrays in public API functions, we recommend using ArrayLike for array inputs and Array for array outputs.

For example, your function might look like this:

```python
import numpy as np
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def my_function(x: ArrayLike) -> Array:
  # Runtime type validation, Python 3.10 or newer:
  if not isinstance(x, ArrayLike):
    raise TypeError(f"Expected arraylike input; got {x}")

  # Runtime type validation, any Python version:
  if not (isinstance(x, (np.ndarray, Array)) or np.isscalar(x)):
    raise TypeError(f"Expected arraylike input; got {x}")

  # Convert input to jax.Array:
  x_arr = jnp.asarray(x)

  # ... do some computation; JAX functions will return Array types.
  # (This example computes a mean over axis 0, so it assumes x has at least one axis.)
  result = x_arr.sum(0) / x_arr.shape[0]

  # return an Array
  return result
```

Most of JAX’s public APIs follow this pattern. Note in particular that we recommend that JAX functions not accept sequences such as list or tuple in place of arrays, as this can cause extra overhead in JAX transforms like jit() and can behave in unexpected ways with batch-wise transforms like vmap() or jax.pmap(). For more information on this, see Non-array inputs: NumPy vs. JAX.
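The overhead and batching pitfalls of list inputs can be observed directly. The sketch below, assuming a recent JAX release, uses jax.make_jaxpr to count how many traced inputs a list produces compared with an array, then shows that jnp.sum rejects a list and that jax.vmap does not treat a list as one array with a leading axis. The helper total is hypothetical; because the exact exception raised by vmap may differ between versions, the sketch catches both TypeError and ValueError there.

```python
import jax
import jax.numpy as jnp

values = [1.0, 2.0, 3.0]


def total(x):
    # Hypothetical helper: converts whatever it receives, then reduces it.
    return jnp.sum(jnp.asarray(x))


# Under tracing, a list is a pytree: each element becomes a separate scalar input.
n_inputs_list = len(jax.make_jaxpr(total)(values).jaxpr.invars)
# An array is traced as a single input of shape (3,).
n_inputs_array = len(jax.make_jaxpr(total)(jnp.asarray(values)).jaxpr.invars)
print(n_inputs_list, n_inputs_array)  # 3 1

# jax.numpy functions reject lists instead of converting them implicitly.
try:
    jnp.sum(values)
    list_rejected = False
except TypeError:
    list_rejected = True

# Under vmap, the list's leaves are rank-0 scalars, so there is no axis 0 to map over.
try:
    jax.vmap(jnp.sin)(values)
    vmap_failed = False
except (TypeError, ValueError):
    vmap_failed = True

# Converting explicitly gives the intended batched computation.
batched = jax.vmap(jnp.sin)(jnp.asarray(values))
print(batched.shape)  # (3,)
```

Requiring callers to write jnp.asarray(values) themselves keeps the conversion, and its cost, visible at the call site.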

List of Members

ArrayLike

Type annotation for JAX array-like objects.

DTypeLike

alias of str | type[Any] | dtype | SupportsDType
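DTypeLike is convenient for annotating helpers that inspect a dtype argument. A minimal sketch, using jax.numpy.issubdtype, which accepts each member of the alias; the helper is_floating is hypothetical, and jnp.bfloat16 stands in for an object with a dtype attribute, following the jnp.float32 example given above:

```python
import numpy as np
import jax.numpy as jnp
from jax.typing import DTypeLike


def is_floating(dtype: DTypeLike) -> bool:
    """Hypothetical helper: True if `dtype` names a floating-point type."""
    return bool(jnp.issubdtype(dtype, jnp.floating))


# One spelling per member of str | type[Any] | dtype | SupportsDType:
checks = {
    "str": is_floating("float32"),
    "type": is_floating(np.int16),
    "dtype": is_floating(np.dtype("float16")),
    "SupportsDType": is_floating(jnp.bfloat16),
}
print(checks)
# {'str': True, 'type': False, 'dtype': True, 'SupportsDType': True}
```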

