
Device-local array layout control#

The jax.experimental.layout package provides ways to control how JAX arrays are laid out in device-local memory.

Terminology#

Array layout is tightly coupled with array sharding. Together, a layout and a sharding fully describe how an array’s values are laid out across (distributed) memories. Along these lines, we use the following terminology:

  • Layout: how an array’s values are laid out within each memory in which they reside (e.g., in the memory of a single device). A typical layout specification is a minor-to-major order listing of array dimensions.

  • Sharding: how an array’s values are distributed across different memory spaces, such as multiple device memories (e.g., described by sharding some dimensions and replicating others).

  • Format: the pairing of layout and sharding, providing a complete picture of an array’s memory placement.
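To make the distinction concrete, here is a minimal pure-Python sketch of the two ideas working together. It is a toy model, not JAX's implementation: the helper names shard_rows and linearize are hypothetical, the "memories" are plain Python lists, and only 2-dimensional arrays split by rows are handled.

```python
def shard_rows(values, num_memories):
    """Toy 'sharding': split a 2-D list into equal, contiguous row blocks."""
    rows_per_shard = len(values) // num_memories
    return [values[i * rows_per_shard:(i + 1) * rows_per_shard]
            for i in range(num_memories)]


def linearize(block, major_to_minor):
    """Toy 'layout': flatten one 2-D block into a 1-D buffer."""
    if major_to_minor == (0, 1):   # row-major: each row is contiguous
        return [v for row in block for v in row]
    if major_to_minor == (1, 0):   # column-major: each column is contiguous
        return [row[j] for j in range(len(block[0])) for row in block]
    raise ValueError("this toy only handles 2-D orderings")


values = [[0, 1], [2, 3], [4, 5], [6, 7]]   # a 4x2 array

# Sharding decides which rows go to which memory; layout decides the order
# of the values inside each memory. The pair plays the role of a "format".
row_major_buffers = [linearize(b, (0, 1)) for b in shard_rows(values, 2)]
col_major_buffers = [linearize(b, (1, 0)) for b in shard_rows(values, 2)]

print(row_major_buffers)
print(col_major_buffers)
```

The same sharding with two different layouts places the same values in the same memories, but in a different order within each one.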

Types#

There are two Python types that come up when controlling array layouts: Layout and Format.

  • The Layout class is used to define the in-memory layout of an array. It has the following key attributes:

    • major_to_minor: A tuple of integers specifying the dimension ordering in memory. For example, for a 2-dimensional array, (0, 1) indicates row-major layout and (1, 0) indicates column-major.

    • _tiling: An intentionally hidden, highly experimental, optional attribute to specify a tiled layout.

    • AUTO: A special, static sentinel object that can be used with jax.jit to request that the compiler automatically determine a good layout for a compiled function’s input or output arrays.

  • The Format class carries both a Layout and a Sharding, with either one taking on a default value when it is not specified. When the layout is explicitly specified, the sharding must be as well.
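One way to read a major_to_minor tuple is as a recipe for element strides: the last dimension listed varies fastest in memory. The sketch below illustrates that reading in plain Python. The helper strides_for is hypothetical (it is not part of jax.experimental.layout), and it assumes a dense, untiled layout, which real device layouts need not be.

```python
def strides_for(shape, major_to_minor):
    """Element strides per dimension for a dense, untiled layout.

    Assumption: `major_to_minor` lists dimensions from slowest-varying
    (major) to fastest-varying (minor), as described for Layout above.
    """
    strides = [0] * len(shape)
    step = 1
    # Walk from the most minor dimension to the most major one.
    for dim in reversed(major_to_minor):
        strides[dim] = step
        step *= shape[dim]
    return tuple(strides)


shape = (2, 3)
row_major = strides_for(shape, (0, 1))
col_major = strides_for(shape, (1, 0))

print("row-major strides:", row_major)
print("column-major strides:", col_major)
```

With (0, 1), moving along dimension 1 steps by one element, so rows are contiguous. With (1, 0), moving along dimension 0 steps by one element, so columns are contiguous.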

JAX API functions, such as jax.jit and jax.device_put, accept Shardings for sharding control or Formats for additional layout control. They typically do not accept Layout instances directly.

Specifying and reading layouts#

By passing Format objects to jax.jit in place of shardings (in the in_shardings and out_shardings arguments), you can guide the compiler’s layout decisions. Similarly, you can pass Formats instead of Shardings to jax.device_put to control the layout of the resulting array.

Let’s see an example that uses both explicit and automatic layouts (as in Layout.AUTO). Imagine we have two compiled functions, init_fn and apply_fn. Say we expect init_fn to be called roughly once, but apply_fn to be called on the output of init_fn many times, so that we care much more about the performance of apply_fn. We may want to have the compiler choose a good layout for apply_fn and constrain init_fn to produce arrays of such layout. We can do this as follows:

    import jax, jax.numpy as jnp
    from jax.experimental.layout import Layout, Format
    from jax.sharding import SingleDeviceSharding
    import numpy as np

    def init_fn(x, y):
      return x * 2, y * 3

    def apply_fn(x, y):
      return x[0, :], y[:, 0]

Since apply_fn reads a contiguous column of its second argument y, it makes sense to lay it out in column-major order (where columns are stored contiguously). Using Layout.AUTO, we can ask the compiler to infer good input layouts and see that it indeed chooses to request the second argument in column-major layout.
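The contiguity argument can be checked with a little arithmetic. The following pure-Python sketch computes the linear memory offsets touched when reading column 0 of a small 2-D array under each ordering. The helpers column_offsets and is_contiguous are hypothetical, and the model assumes a dense, untiled layout, so it only approximates what a compiler weighs when it picks a layout.

```python
def column_offsets(shape, column, major_to_minor):
    """Linear offsets of the elements y[:, column] in a dense 2-D buffer."""
    rows, cols = shape
    if major_to_minor == (0, 1):   # row-major: offset = i * cols + j
        return [i * cols + column for i in range(rows)]
    if major_to_minor == (1, 0):   # column-major: offset = j * rows + i
        return [column * rows + i for i in range(rows)]
    raise ValueError("this toy only handles 2-D orderings")


def is_contiguous(offsets):
    """True when consecutive offsets differ by exactly one element."""
    return all(b - a == 1 for a, b in zip(offsets, offsets[1:]))


shape = (4, 8)
row_major = column_offsets(shape, 0, (0, 1))
col_major = column_offsets(shape, 0, (1, 0))

print("row-major offsets:", row_major, "contiguous:", is_contiguous(row_major))
print("column-major offsets:", col_major, "contiguous:", is_contiguous(col_major))
```

Under row-major ordering the column's elements are a full row apart, while under column-major ordering they sit next to each other, which is why a column read favors the (1, 0) ordering.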

    shape = (4 * 128, 8 * 128)
    duck = jax.ShapeDtypeStruct(shape, jnp.float32)

    # Compile the `apply` function with layouts inferred automatically
    apply_exe = jax.jit(
        apply_fn,
        in_shardings=Format(Layout.AUTO),
        out_shardings=Format(Layout.AUTO),
    ).trace(duck, duck).lower().compile()

    # Read back the inferred input layout
    arg_formats, kwarg_formats = apply_exe.input_formats
    assert len(kwarg_formats) == 0
    assert arg_formats[0].layout.major_to_minor == (0, 1)
    assert arg_formats[1].layout.major_to_minor == (1, 0)

We can then compile init_fn to explicitly match this layout in its outputs.

    init_exe = jax.jit(init_fn, out_shardings=arg_formats).trace(
        duck, duck).lower().compile()

    assert init_exe.output_formats == arg_formats

Finally, we can see how the compiled apply_fn behaves when called with differently laid out input arrays. The behavior varies with whether inputs are committed. As the following test demonstrates, if the argument arrays are committed, then the pre-compiled apply_fn requires they match the layout determined by the compiler above. Meanwhile, it accepts uncommitted arrays of any layout (including, of course, the inferred layout). In this case, the arrays may be relaid out prior to invoking the compiled computation.

    def test(x, y, msg):
      print(f'-- {msg}:')
      print('x major_to_minor =', x.format.layout.major_to_minor)
      print('y major_to_minor =', y.format.layout.major_to_minor)
      try:
        apply_exe(x, y)
        print('-> `apply` called successfully')
      except ValueError as e:
        assert 'does not match' in str(e)
        print('-> error: mismatched input layouts')
      print()

    dev = jax.devices()[0]

    x1 = y1 = jnp.ones(shape)
    test(x1, y1, 'uncommitted with mismatched layout')

    x2, y2 = init_exe(x1, y1)
    test(x2, y2, 'uncommitted with matching layout')

    x3 = jnp.ones(shape)
    y3 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(1, 0)),
                                               SingleDeviceSharding(dev)))
    test(x3, y3, 'committed with matching layout')

    x4 = jnp.ones(shape)
    y4 = jax.device_put(np.ones(shape), Format(Layout(major_to_minor=(0, 1)),
                                               SingleDeviceSharding(dev)))
    test(x4, y4, 'committed with mismatched layout')

    -- uncommitted with mismatched layout:
    x major_to_minor = (0, 1)
    y major_to_minor = (0, 1)
    -> `apply` called successfully

    -- uncommitted with matching layout:
    x major_to_minor = (0, 1)
    y major_to_minor = (1, 0)
    -> `apply` called successfully

    -- committed with matching layout:
    x major_to_minor = (0, 1)
    y major_to_minor = (1, 0)
    -> `apply` called successfully

    -- committed with mismatched layout:
    x major_to_minor = (0, 1)
    y major_to_minor = (0, 1)
    -> error: mismatched input layouts

Constraining intermediate layouts#

We can also enforce a specific layout on an intermediate value within a JIT-compiled function using with_layout_constraint:

    from jax.experimental.layout import with_layout_constraint

    @jax.jit
    def f(x):
      y = x.T
      # Enforce a specific layout on `y`
      y = with_layout_constraint(y, Layout(major_to_minor=(0, 1)))
      return y * 2

This is analogous to jax.lax.with_sharding_constraint, for constraining layouts rather than shardings.

