jax.experimental.custom_partitioning module

API
- jax.experimental.custom_partitioning.custom_partitioning(fun, static_argnums=())
Inserts a CustomCallOp into the XLA graph with custom SPMD lowering rules.
```python
@custom_partitioning
def f(*args):
    return ...

def propagate_user_sharding(mesh, user_shape):
    '''Update the sharding of the op from a user's shape.sharding.'''
    user_sharding = jax.tree.map(lambda x: x.sharding, user_shape)

def partition(mesh, arg_shapes, result_shape):
    def lower_fn(*args):
        ...  # builds computation on per-device shapes
    result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    # result_shardings and arg_shardings may optionally be modified and the
    # partitioner will insert collectives to reshape.
    return mesh, lower_fn, result_shardings, arg_shardings

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    '''Compute the result sharding from the sharding of the operands.'''
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)

f.def_partition(partition,
                propagate_user_sharding,
                infer_sharding_from_operands=infer_sharding_from_operands,
                sharding_rule='i j -> i j')
```
The arguments to `def_partition` are as follows:

- `propagate_user_sharding`: Callable which takes the sharding of a user (in the dag) and returns a suggestion for a new `NamedSharding`. The default value is None. A trivial implementation is just to return the input sharding.
- `partition`: Callable which takes the SPMD suggested partition shapes and partition specs and returns the mesh, a per-shard lowering function, and the final input and output sharding specs (the SPMD partitioner will repartition the inputs to match). The mesh is returned to allow configuring axis_names for collectives when no mesh is provided.
- `infer_sharding_from_operands`: Callable which computes an output `NamedSharding` from the `NamedSharding` chosen for each argument.
- `decode_shardings`: When set to True, convert input `GSPMDSharding`s to `NamedSharding` if possible. This may not be possible if the user does not provide a contextual mesh.
- `sharding_rule`: an `SdyShardingRule` object, an Einsum-like notation string that describes the sharding rule, or a Callable that produces either of these. We call the index labels of the Einsum notation "factors" in our sharding rule. We borrow from the `einops.rearrange` string notation the use of a space separator between factors, which allows multi-letter factor names. By default, a factor corresponds to a passthrough/elementwise dimension. Factors corresponding to other dimensions can be specified via the keyword arguments described below. See the jax-shardy-guide for more details and examples.
- `reduction_factors`: A tuple of strings, specifying the reduction factors for a string `sharding_rule`. A reduction factor corresponds to a dimension that appears in the operands but not in the result, such as the contracting dimensions in a matmul operation. If a reduction factor is sharded, the result needs to be all-reduced along the same axes.
- `need_replication_factors`: A tuple of strings, specifying the need_replication factors for a string `sharding_rule`. A need_replication factor corresponds to a dimension that shouldn't be sharded to support the implementation.
- `permutation_factors`: A tuple of strings, specifying the permutation factors for a string `sharding_rule`. A permutation factor corresponds to a dimension that would trigger a collective permute if it is sharded.
- `factor_sizes`: A dictionary of variable keyword arguments, specifying the sizes of the factors that are only used in compound factors in a string `sharding_rule`.
When `config.use_shardy_partitioner.value` is True, `sharding_rule` is used; otherwise, `propagate_user_sharding` and `infer_sharding_from_operands` are used.
Positional arguments can be specified as static using `static_argnums`. JAX uses `inspect.signature(fun)` to resolve these positional arguments.

Examples
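As a small preliminary example, the way `static_argnums` indices map onto positional parameters via `inspect.signature` can be sketched with a hypothetical function (`f` below is illustrative only):

```python
import inspect

# Hypothetical decorated function: static_argnums=(1,) would mark the
# second positional parameter, 'n', as static.
def f(x, n, axis=-1):
    return None

sig = inspect.signature(f)
params = list(sig.parameters)
assert params == ['x', 'n', 'axis']
assert params[1] == 'n'  # index 1 in static_argnums refers to 'n'

# Keyword-passed values still bind to their positional parameter:
bound = sig.bind(1.0, n=4)
assert bound.arguments['n'] == 4
```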
As an example, assume we want to enhance the existing `jax.numpy.fft.fft`. This function computes the discrete Fourier transform of an N-dimensional input along the last dimension, and is batched along the first N-1 dimensions. By default, however, it will ignore the sharding of the input and gather the input on all devices. Since `jax.numpy.fft.fft` is batched along the first N-1 dimensions, this is unnecessary. We will create a new `my_fft` op that, instead, does not alter the sharding along the first N-1 dimensions, and only gathers the input along the last dimension if needed.

```python
import jax
from jax.sharding import NamedSharding
from jax.experimental.custom_partitioning import (
    custom_partitioning, SdyShardingRule, BATCHING)
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.numpy.fft import fft
import regex as re
import numpy as np

# Pattern to detect all-gather or dynamic-slice in the generated HLO
_PATTERN = '(dynamic-slice|all-gather)'

# For an N-D input, keeps sharding along the first N-1 dimensions
# but replicates along the last dimension
def supported_sharding(sharding, shape):
    rank = len(shape.shape)
    max_shared_dims = min(len(sharding.spec), rank - 1)
    names = tuple(sharding.spec[:max_shared_dims]) + tuple(None for _ in range(rank - max_shared_dims))
    return NamedSharding(sharding.mesh, P(*names))

def partition(mesh, arg_shapes, result_shape):
    result_shardings = jax.tree.map(lambda x: x.sharding, result_shape)
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return mesh, fft, \
        supported_sharding(arg_shardings[0], arg_shapes[0]), \
        (supported_sharding(arg_shardings[0], arg_shapes[0]),)

def infer_sharding_from_operands(mesh, arg_shapes, result_shape):
    arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
    return supported_sharding(arg_shardings[0], arg_shapes[0])

@custom_partitioning
def my_fft(x):
    return fft(x)

# Use Einsum-like notation to specify the sharding rule.
my_fft.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition,
    sharding_rule='...i -> ...i')

# Alternatively, use an SdyShardingRule object to specify the sharding rule.
my_fft.def_partition(
    infer_sharding_from_operands=infer_sharding_from_operands,
    partition=partition,
    sharding_rule=SdyShardingRule(operand_mappings=((BATCHING, 'i'),),
                                  result_mappings=((BATCHING, 'i'),)))
```
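The spec arithmetic inside `supported_sharding` can be traced in isolation; a minimal sketch using a plain tuple in place of a `PartitionSpec` (the helper `keep_batch_axes` is hypothetical):

```python
# Hypothetical re-statement of supported_sharding's spec logic: keep the
# sharding of the first rank-1 (batch) dimensions, replicate the last one.
def keep_batch_axes(spec, rank):
    max_shared_dims = min(len(spec), rank - 1)
    return tuple(spec[:max_shared_dims]) + (None,) * (rank - max_shared_dims)

# 2D input sharded over mesh axis 'x' on dim 0: the batch dim stays sharded.
assert keep_batch_axes(('x',), rank=2) == ('x', None)
# 1D input: the only dim is the FFT dim, so it must be replicated.
assert keep_batch_axes(('x',), rank=1) == (None,)
```

This is why, in the 1D example further below, a dynamic-slice appears for `my_fft` as well: the sole dimension is the FFT dimension and must be replicated.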
Now create a 2D array sharded along the first axis, pass it through `my_fft`, and notice how it is still sharded as expected, and identical to the output of `fft`. However, inspecting the HLO (using `lower(x).compile().runtime_executable().hlo_modules()`) reveals that `my_fft` does not create any all-gather or dynamic-slice, while `fft` does.

```python
with Mesh(np.array(jax.devices()), ('x',)):
    x = np.asarray(np.random.randn(32*1024, 1024), dtype=np.complex64)
    y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
    pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
    pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x'))
    print(pjit_my_fft(y))
    print(pjit_fft(y))
    # dynamic-slice or all-gather are not present in the HLO for my_fft, because x is a 2D array
    assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is None)
    # dynamic-slice or all-gather are present in the HLO for fft
    assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
```
```
# my_fft
[[-38.840824  +0.j        -40.649452 +11.845365j
  ...
   -1.6937828 +0.8402481j  15.999859  -4.0156755j]]

# jax.numpy.fft.fft
[[-38.840824  +0.j        -40.649452 +11.845365j
  ...
   -1.6937828 +0.8402481j  15.999859  -4.0156755j]]
```
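The assertions above boil down to a plain regex search over the HLO text. A standalone sketch with the stdlib `re` module and two hypothetical HLO fragments standing in for real compiler output:

```python
import re

# Pattern to detect all-gather or dynamic-slice in HLO text
_PATTERN = '(dynamic-slice|all-gather)'

# Hypothetical HLO fragments (illustrative, not actual compiler output):
hlo_with_gather = '%ag = c64[8,16] all-gather(%p0), dimensions={1}'
hlo_fft_only = '%fft.1 = c64[8,16] fft(%p0), fft_type=FFT, fft_length={16}'

assert re.search(_PATTERN, hlo_with_gather) is not None
assert re.search(_PATTERN, hlo_fft_only) is None
```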
Because of the logic in `supported_sharding`, `my_fft` also works on 1-dimensional arrays. However, in this case, the HLO of `my_fft` does show a dynamic-slice, since the last dimension is the dimension along which FFTs are calculated and it needs to be replicated on all devices before the computation can be done.

```python
with Mesh(np.array(jax.devices()), ('x',)):
    x = np.asarray(np.random.randn(32*1024*1024), dtype=np.complex64)
    y = pjit(lambda x: x, in_shardings=None, out_shardings=P('x'))(x)
    pjit_my_fft = pjit(my_fft, in_shardings=P('x'), out_shardings=P('x'))
    pjit_fft = pjit(fft, in_shardings=P('x'), out_shardings=P('x'))
    print(pjit_my_fft(y))
    print(pjit_fft(y))
    # dynamic-slice or all-gather are present in the HLO for my_fft, because x is a 1D array
    assert(re.search(_PATTERN, pjit_my_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
    # dynamic-slice or all-gather are present in the HLO for fft
    assert(re.search(_PATTERN, pjit_fft.lower(x).compile().runtime_executable().hlo_modules()[0].to_string()) is not None)
```
```
# my_fft
[    7.217285   +0.j       -3012.4937 +4287.635j    -405.83594+3042.984j
  ...
  1422.4502 +7271.4297j     -405.84033-3042.983j   -3012.4963-4287.6343j]

# jax.numpy.fft.fft
[    7.217285   +0.j       -3012.4937 +4287.635j    -405.83594+3042.984j
  ...
  1422.4502 +7271.4297j     -405.84033-3042.983j   -3012.4963-4287.6343j]
```
