Shape polymorphism#

When JAX is used in JIT mode, a function will be traced, lowered to StableHLO, and compiled for each combination of input types and shapes. After exporting a function and deserializing it on another system we don't have the Python sources available anymore, so we cannot re-trace and re-lower it. Shape polymorphism is a feature of JAX export to allow some exported functions to be used for a whole family of input shapes. These functions are traced and lowered once, during exporting, and the Exported object contains the information needed to be able to compile and execute the function on many concrete input shapes. We do this by specifying shapes that contain dimension variables (symbolic shapes) when exporting, as in the following example:

>>> import jax
>>> import numpy as np
>>> from jax import export
>>> from jax import numpy as jnp

>>> def f(x):  # f: f32[a, b]
...   return jnp.concatenate([x, x], axis=1)

>>> # We construct symbolic dimension variables.
>>> a, b = export.symbolic_shape("a, b")

>>> # We can use the symbolic dimensions to construct shapes.
>>> x_shape = (a, b)
>>> x_shape
(a, b)

>>> # Then we export with symbolic shapes:
>>> exp: export.Exported = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct(x_shape, jnp.int32))
>>> exp.in_avals
(ShapedArray(int32[a,b]),)
>>> exp.out_avals
(ShapedArray(int32[a,2*b]),)

>>> # We can later call with concrete shapes (with a=3 and b=4), without re-tracing `f`.
>>> res = exp.call(np.ones((3, 4), dtype=np.int32))
>>> res.shape
(3, 8)

Note that such functions are still re-compiled on demand for each concrete input shape they are invoked on. Only the tracing and the lowering are saved.
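
For instance, continuing the example above (re-using exp and np from the previous snippet), the same exported artifact handles another concrete shape; a fresh compilation happens, but f is not traced again:

>>> # Call the exported function on another concrete shape (a=2, b=5).
>>> exp.call(np.ones((2, 5), dtype=np.int32)).shape
(2, 10)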

The jax.export.symbolic_shape() is used in the above example to parse a string representation of a symbolic shape into dimension expression objects (of type _DimExpr) that are usable in place of integer constants to construct shapes. The dimension expression objects overload most integer operators, so you can use them as you'd use integer constants in most cases. See Computing with dimension variables for more details.

Additionally, we provide the jax.export.symbolic_args_specs() that can be used to construct pytrees of jax.ShapeDtypeStruct objects based on a polymorphic shape specification:

>>> def f1(x, y):  # x: f32[a, 1], y: f32[a, 4]
...   return x + y

>>> # Assuming you have some actual args with concrete shapes
>>> x = np.ones((3, 1), dtype=np.int32)
>>> y = np.ones((3, 4), dtype=np.int32)
>>> args_specs = export.symbolic_args_specs((x, y), "a, ...")
>>> exp = export.export(jax.jit(f1))(*args_specs)
>>> exp.in_avals
(ShapedArray(int32[a,1]), ShapedArray(int32[a,4]))

Note how the polymorphic shape specification "a, ..." contains the placeholder ... to be filled from the concrete shapes of the arguments (x, y). The placeholder ... stands for 0 or more dimensions, while the placeholder _ stands for one dimension. The jax.export.symbolic_args_specs() supports pytrees of arguments, which are used to fill in the dtypes and any placeholders. The function will construct a pytree of argument specifications (jax.ShapeDtypeStruct) matching the structure of the arguments passed to it. The polymorphic shapes specification can be a pytree prefix in cases where one specification should apply to multiple arguments, as in the above example. See how optional parameters are matched to arguments.

A few examples of shape specifications:

  • ("(b,_,_)",None) can be used for a function with two arguments, the firstbeing a 3D array with a batch leading dimension that should be symbolic.The other dimensions for thefirst argument and the shape of the second argument are specialized based on the actualarguments. Note that the same specification would work if the firstargument is a pytree of 3D arrays, all with the same leading dimensionbut possibly with different trailing dimensions.The valueNone for the second argument means that the argumentis not symbolic. Equivalently, one can use....

  • ("(batch,...)","(batch,)") specifies that the two argumentshave matching leading dimensions, the first argument has rank atleast 1, and the second has rank 1.

Correctness of shape polymorphism#

We want to trust that the exported program produces the same results as the original JAX program when compiled and executed for any applicable concrete shapes. More precisely:

For any JAX function f and any argument specification arg_spec containing a symbolic shape, and any concrete argument arg whose shape matches arg_spec:

  • If the JAX native execution succeeds on the concrete argument: res = f(arg),

  • and if the exporting succeeds with symbolic shapes: exp = export.export(f)(arg_spec),

  • then compiling and running the export will succeed with the same result: res == exp.call(arg)

It is crucial to understand that f(arg) has the freedom to re-invoke the JAX tracing machinery, and in fact it does so for each distinct concrete arg shape, while the execution of exp.call(arg) cannot use JAX tracing anymore (this execution may happen in an environment where the source code of f is not available).
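
As an illustrative, hedged check of this property, re-using the f defined in the first example on this page:

>>> arg = np.arange(6, dtype=np.int32).reshape((2, 3))
>>> res = f(arg)  # native JAX execution; may re-trace `f`
>>> exp_f = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), np.int32))
>>> bool(jnp.all(res == exp_f.call(arg)))  # no tracing of `f` happens here
True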

Ensuring this form of correctness is hard, and in the hardest cases exporting fails. The rest of this chapter describes how to handle these failures.

Computing with dimension variables#

JAX keeps track of the shapes of all intermediate results. When those shapes depend on dimension variables JAX computes them as symbolic dimension expressions involving dimension variables. Dimension variables stand for integer values greater or equal to 1. The symbolic expressions can represent the result of applying arithmetic operators (add, sub, mul, floordiv, mod, including the NumPy variants np.sum, np.prod, etc.) on dimension expressions and integers (int, np.int, or anything convertible by operator.index). These symbolic dimensions can then be used in shape parameters of JAX primitives and APIs, e.g., in jnp.reshape, jnp.arange, slicing indices, etc.

For example, in the following code to flatten a 2D array, the computation x.shape[0] * x.shape[1] computes the symbolic dimension 4*b as the new shape:

>>> f = lambda x: jnp.reshape(x, (x.shape[0] * x.shape[1],))
>>> arg_spec = jax.ShapeDtypeStruct(export.symbolic_shape("b, 4"), jnp.int32)
>>> exp = export.export(jax.jit(f))(arg_spec)
>>> exp.out_avals
(ShapedArray(int32[4*b]),)

It is possible to convert dimension expressions explicitly to JAX arrays, with jnp.array(x.shape[0]) or even jnp.array(x.shape). The result of these operations can be used as regular JAX arrays, but cannot be used anymore as dimensions in shapes, e.g., in reshape:

>>> exp = export.export(jax.jit(lambda x: jnp.array(x.shape[0]) + x))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))
>>> exp.call(jnp.arange(3, dtype=np.int32))
Array([3, 4, 5], dtype=int32)

>>> exp = export.export(jax.jit(lambda x: x.reshape(jnp.array(x.shape[0]) + 2)))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), np.int32))
Traceback (most recent call last):
TypeError: Shapes must be 1D sequences of concrete values of integer type, got [Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>].

When a symbolic dimension is used in arithmetic operations with non-integers, e.g., float, np.float, np.ndarray, or JAX arrays, it is automatically converted to a JAX array using jnp.array. For example, in the function below all occurrences of x.shape[0] are converted implicitly to jnp.array(x.shape[0]) because they are involved in operations with non-integer scalars or with JAX arrays:

>>> exp = export.export(jax.jit(
...     lambda x: (5. + x.shape[0],
...                x.shape[0] - np.arange(5, dtype=jnp.int32),
...                x + x.shape[0] + jnp.sin(x.shape[0]))))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b"), jnp.int32))
>>> exp.out_avals
(ShapedArray(float32[], weak_type=True), ShapedArray(int32[5]), ShapedArray(float32[b], weak_type=True))

>>> exp.call(jnp.ones((3,), jnp.int32))
(Array(8., dtype=float32, weak_type=True),
 Array([ 3, 2, 1, 0, -1], dtype=int32),
 Array([4.14112, 4.14112, 4.14112], dtype=float32, weak_type=True))

Another typical example is when computing averages (observe how x.shape[0] is automatically turned into a JAX array):

>>> exp = export.export(jax.jit(
...     lambda x: jnp.sum(x, axis=0) / x.shape[0]))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b, c"), jnp.int32))
>>> exp.call(jnp.arange(12, dtype=jnp.int32).reshape((3, 4)))
Array([4., 5., 6., 7.], dtype=float32)

Errors in presence of shape polymorphism#

Most JAX code assumes that the shapes of JAX arrays are tuples of integers, but with shape polymorphism some dimensions may be symbolic expressions. This can lead to a number of errors. For example, we can have the usual JAX shape check errors:

>>> v, = export.symbolic_shape("v,")
>>> export.export(jax.jit(lambda x, y: x + y))(
...     jax.ShapeDtypeStruct((v,), dtype=np.int32),
...     jax.ShapeDtypeStruct((4,), dtype=np.int32))
Traceback (most recent call last):
TypeError: add got incompatible shapes for broadcasting: (v,), (4,).

>>> export.export(jax.jit(lambda x: jnp.matmul(x, x)))(
...     jax.ShapeDtypeStruct((v, 4), dtype=np.int32))
Traceback (most recent call last):
TypeError: dot_general requires contracting dimensions to have the same shape, got (4,) and (v,).

We can fix the above matmul example by specifying that the argument has shape (v, v).
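
A minimal sketch of that fix, re-using the v defined above:

>>> # With a square (v, v) argument the contracting dimensions match.
>>> _ = export.export(jax.jit(lambda x: jnp.matmul(x, x)))(
...     jax.ShapeDtypeStruct((v, v), dtype=np.int32))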

Comparison of symbolic dimensions is partially supported#

Inside JAX there are a number of equality and inequality comparisons involving shapes, e.g., for doing shape checking or even for choosing the implementation for some primitives. Comparisons are supported as follows:

  • equality is supported with a caveat: if the two symbolic dimensions denote the same value under all valuations for dimension variables, then equality evaluates to True, e.g., for b + b == 2*b; otherwise the equality evaluates to False, as illustrated in the snippet after this list. See below for a discussion of important consequences of this behavior.

  • disequality is always the negation of equality.

  • inequality is partially supported, in a similar way as partial equality. However, in this case we take into consideration that dimension variables range over strictly positive integers. E.g., b >= 1, b >= 0, 2*a + b >= 3 are True, while b >= 2, a >= b, a - b >= 0 are inconclusive and result in an exception.
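
For example, a small sketch of the equality rules above:

>>> b, = export.symbolic_shape("b")
>>> b + b == 2 * b   # equal under all valuations of `b`
True
>>> b == 1           # not provably equal, so equality evaluates to False
False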

In cases where a comparison operation cannot be resolved to a boolean, we raise InconclusiveDimensionOperation. E.g.,

>>> import jax
>>> export.export(jax.jit(lambda x: 0 if x.shape[0] + 1 >= x.shape[1] else 1))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("a, b"), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'a + 1' >= 'b' is inconclusive.
This error arises for comparison operations with shapes that are non-constant, and the result of the operation
cannot be represented as a boolean value for all values of the symbolic dimensions involved.

If you do get an InconclusiveDimensionOperation, you can try several strategies:

  • If your code uses the built-in max or min, or the np.max or np.min, then you can replace those with core.max_dim and core.min_dim, which have the effect of delaying the inequality comparison to compilation time, when shapes become known, as sketched after this list.

  • Try to rewrite conditionals using core.max_dim and core.min_dim, e.g., instead of d if d > 0 else 0 you can write core.max_dim(d, 0).

  • Try to rewrite the code to be less dependent on the fact that dimensions should be integers, and rely on the fact that symbolic dimensions duck-type as integers for most arithmetic operations. E.g., instead of int(d) + 5 write d + 5.

  • Specify symbolic constraints, as explained below.
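
As a hedged sketch of the first two strategies (the function and slice bound below are made up for illustration; whether a particular use exports cleanly depends on what JAX's symbolic reasoning can prove):

>>> from jax import core, lax
>>> b, = export.symbolic_shape("b")
>>> # Take at most the first 4 elements; `min(x.shape[0], 4)` would force an
>>> # inconclusive comparison, while `core.min_dim` defers it to compile time.
>>> f = lambda x: lax.slice_in_dim(x, 0, core.min_dim(x.shape[0], 4))
>>> _ = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((b,), dtype=np.int32))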

User-specified symbolic constraints#

By default, JAX assumes that all dimension variables range over values greater-or-equal to 1, and it tries to derive other simple inequalities from that, e.g.:

  • a + 2 >= 3,

  • a * 2 >= 1,

  • a + b + c >= 3,

  • a // 4 >= 0, a**2 >= 1, and so on.
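
For instance, a small check of the derived inequalities above (the asserts pass because the comparisons resolve to True):

>>> a, b, c = export.symbolic_shape("a, b, c")
>>> assert a + 2 >= 3
>>> assert a * 2 >= 1
>>> assert a + b + c >= 3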

You can avoid some inequality comparison failures if you change the symbolic shape specifications to add implicit constraints for dimension sizes. E.g.,

  • You can use 2*b for a dimension to constrain it to be even and greater or equal to 2, as sketched below.

  • You can use b + 15 for a dimension to constrain it to be at least 16. E.g., the following code would fail without the + 15 part, because JAX will want to verify that slice sizes are at most as large as the axis size.

>>> _ = export.export(jax.jit(lambda x: x[0:16]))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b + 15"), dtype=np.int32))
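
And a hedged sketch of the 2*b trick from the first bullet: declaring the axis as 2*b tells JAX the size is even, so reshaping into two halves goes through (the reshape itself is a made-up example):

>>> _ = export.export(jax.jit(lambda x: x.reshape((2, x.shape[0] // 2))))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("2*b"), dtype=np.int32))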

Such implicit symbolic constraints are used for deciding comparisons and are checked at compile time, as explained below.

You can also specify explicit symbolic constraints:

>>> # Introduce dimension variable with constraints.
>>> a, b = export.symbolic_shape("a, b",
...                              constraints=("a >= b", "b >= 16"))
>>> _ = export.export(jax.jit(lambda x: x[:x.shape[1], :16]))(
...     jax.ShapeDtypeStruct((a, b), dtype=np.int32))

The constraints form a conjunction together with the implicit constraints. You can specify >=, <=, and == constraints. At the moment, JAX has limited support for reasoning with symbolic constraints:

  • You get the most from constraints of the form of a variable being greater-or-equal or less-or-equal to a constant. For example, from the constraints that a >= 16 and b >= 8 we can infer that a + 2*b >= 32, as checked in a snippet below.

  • You get limited power when the constraint involves more complex expressions, e.g., from a >= b + 8 we can infer that a - b >= 8 but not that a >= 9. We may improve this area somewhat in the future.

  • Equality constraints are treated as rewrite rules: whenever the symbolic expression on the left of == is encountered, it is rewritten to the expression on the right. E.g., floordiv(a, b) == c works by replacing all occurrences of floordiv(a, b) with c. Equality constraints must not contain addition or subtraction at the top-level on the left-hand-side. Examples of valid left-hand-sides are a * b, or 4 * a, or floordiv(a + c, b).

>>> # Introduce dimension variable with equality constraints.
>>> a, b, c, d = export.symbolic_shape("a, b, c, d",
...                                    constraints=("a * b == c + d",))
>>> 2 * b * a
2*d + 2*c
>>> a * b * b
b*d + b*c
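
A small check of the first bullet above (the inferred inequality resolves to True, so the assert passes):

>>> a, b = export.symbolic_shape("a, b",
...                              constraints=("a >= 16", "b >= 8"))
>>> assert a + 2 * b >= 32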

The symbolic constraints can also help to work around the limitations in the JAX reasoning mechanisms. For example, in the code below JAX will attempt to prove that the slice size x.shape[0] % 3, which is the symbolic expression mod(b, 3), is less or equal to the axis size, which is b. This happens to be true for all strictly positive values of b, but it is not something JAX's symbolic comparison rules can prove. Hence, the following code raises an error:

>>> from jax import lax
>>> b, = export.symbolic_shape("b")
>>> f = lambda x: lax.slice_in_dim(x, 0, x.shape[0] % 3)
>>> export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct((b,), dtype=np.int32))  # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
jax._src.export.shape_poly.InconclusiveDimensionOperation: Symbolic dimension comparison 'b' >= 'mod(b, 3)' is inconclusive.
This error arises for comparison operations with shapes that are non-constant, and the result of the operation
cannot be represented as a boolean value for all values of the symbolic dimensions involved.

One option here would be to restrict the code to work only on axis sizes that are multiples of 3 (by replacing b with 3*b in the shape). Then, JAX would be able to simplify the modulo operation mod(3*b, 3) to 0. Another option is to add a symbolic constraint with the exact inconclusive inequality that JAX is attempting to prove:

>>>b,=export.symbolic_shape("b",...constraints=["b >= mod(b, 3)"])>>>f=lambdax:lax.slice_in_dim(x,0,x.shape[0]%3)>>>_=export.export(jax.jit(f))(...jax.ShapeDtypeStruct((b,),dtype=np.int32))

Just like the implicit constraints, the explicit symbolic constraints are checked at compile time, using the same mechanism as explained below.

Symbolic dimension scopes#

The symbolic constraints are stored in a jax.export.SymbolicScope object, which is created implicitly for each call to jax.export.symbolic_shape(). You must be careful to not mix symbolic expressions that use different scopes. For example, the following code will fail because a1 and a2 use different scopes (created by different invocations of jax.export.symbolic_shape()):

>>> a1, = export.symbolic_shape("a,")
>>> a2, = export.symbolic_shape("a,", constraints=("a >= 8",))

>>> a1 + a2
Traceback (most recent call last):
ValueError: Invalid mixing of symbolic scopes for linear combination.
Expected  scope 4776451856 created at <doctest shape_poly.md[31]>:1:6 (<module>)
and found for 'a' (unknown) scope 4776979920 created at <doctest shape_poly.md[32]>:1:6 (<module>) with constraints:
  a >= 8

The symbolic expressions that originate from a single call to jax.export.symbolic_shape() share a scope and can be mixed up in arithmetic operations. The result would also share the same scope.

You can reuse scopes:

>>>a,=export.symbolic_shape("a,",constraints=("a >= 8",))>>>b,=export.symbolic_shape("b,",scope=a.scope)# Reuse the scope of `a`>>>a+b# Allowedb + a

You can also create scopes explicitly:

>>> my_scope = export.SymbolicScope()
>>> c, = export.symbolic_shape("c", scope=my_scope)
>>> d, = export.symbolic_shape("d", scope=my_scope)
>>> c + d  # Allowed
d + c

JAX tracing uses caches keyed partially by shapes, and symbolic shapes that are printed identically will be considered distinct if they use different scopes.

Caveat for equality comparisons#

The equality comparison returns False for b + 1 == b or b == 0 (in which case it is certain that the dimensions are different for all values of the dimension variables), but also for b == 1 and for a == b. This is unsound, and we ought to raise core.InconclusiveDimensionOperation because under some valuations the result should be True and under other valuations it should be False. We choose to make equality total thus allowing unsoundness because otherwise we may get spurious errors in presence of hash collisions when hashing dimension expressions or objects that include them (shapes, core.AbstractValue, core.Jaxpr). Besides the hashing errors, a partial semantics of equality leads to errors for the following expressions b == a or b == b or b in [a, b] even though the error is avoided if we change the order of the comparisons.

Code of the form if x.shape[0] != 1: raise NiceErrorMessage is sound even with this treatment of equality, but code of the form if x.shape[0] != 1: return 1 is unsound.
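
A hedged sketch of why the second form is unsound (the function g is made up for illustration): during tracing b == 1 evaluates to False, so the != 1 branch is baked into the export and is used even when the function is later called with a leading dimension of 1:

>>> b, = export.symbolic_shape("b")
>>> g = lambda x: 1 if x.shape[0] != 1 else 2  # unsound under shape polymorphism
>>> exp = export.export(jax.jit(g))(jax.ShapeDtypeStruct((b,), np.int32))
>>> # The exported function takes the `!= 1` branch, unlike native execution.
>>> int(exp.call(np.zeros((1,), np.int32))), g(np.zeros((1,), np.int32))
(1, 2)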

Dimension variables must be solvable from the input shapes#

Currently, the only way to pass the values of dimension variables when an exported object is invoked is indirectly through the shapes of the array arguments. E.g., the value of b can be inferred at the call site from the shape of the first argument of type f32[b]. This works well for most use cases, and it mirrors the calling convention of JIT functions.

Sometimes you may want to export a function parameterized by an integer value that determines some shapes in the program. For example, we may want to export the function my_top_k defined below, parameterized by the value of k, which determines the shape of the result. The following attempt will lead to an error since the dimension variable k cannot be derived from the shape of the input x: i32[4, 10]:

>>> def my_top_k(k, x):  # x: i32[4, 10], k <= 10
...   return lax.top_k(x, k)[0]  # : i32[4, 3]
>>> x = np.arange(40, dtype=np.int32).reshape((4, 10))

>>> # Export with static `k=3`. Since `k` appears in shapes it must be in `static_argnums`.
>>> exp_static_k = export.export(jax.jit(my_top_k, static_argnums=0))(3, x)
>>> exp_static_k.in_avals[0]
ShapedArray(int32[4,10])
>>> exp_static_k.out_avals[0]
ShapedArray(int32[4,3])

>>> # When calling the exported function we pass only the non-static arguments
>>> exp_static_k.call(x)
Array([[ 9,  8,  7],
       [19, 18, 17],
       [29, 28, 27],
       [39, 38, 37]], dtype=int32)

>>> # Now attempt to export with symbolic `k` so that we choose `k` after export.
>>> k, = export.symbolic_shape("k", constraints=["k <= 10"])
>>> export.export(jax.jit(my_top_k, static_argnums=0))(k, x)
Traceback (most recent call last):
UnexpectedDimVar: "Encountered dimension variable 'k' that is not appearing in the shapes of the function arguments

In the future, we may add an additional mechanism to pass the values of dimension variables, besides implicitly through the input shapes. Meanwhile, the workaround for the above use case is to replace the function parameter k with an array of shape (0, k), so that k can be derived from the input shape of an array. The first dimension is 0 to ensure that the whole array is empty and there is no performance penalty when we call the exported function.

>>> def my_top_k_with_dimensions(dimensions, x):  # dimensions: i32[0, k], x: i32[4, 10]
...   return my_top_k(dimensions.shape[1], x)
>>> exp = export.export(jax.jit(my_top_k_with_dimensions))(
...     jax.ShapeDtypeStruct((0, k), dtype=np.int32),
...     x)
>>> exp.in_avals
(ShapedArray(int32[0,k]), ShapedArray(int32[4,10]))
>>> exp.out_avals[0]
ShapedArray(int32[4,k])

>>> # When we invoke `exp` we must construct and pass an array of shape (0, k)
>>> exp.call(np.zeros((0, 3), dtype=np.int32), x)
Array([[ 9,  8,  7],
       [19, 18, 17],
       [29, 28, 27],
       [39, 38, 37]], dtype=int32)

Another situation when you may get an error is when some dimension variables do appear in the input shapes, but in a non-linear expression that JAX cannot currently solve:

>>>a,=export.symbolic_shape("a")>>>export.export(jax.jit(lambdax:x.shape[0]))(...jax.ShapeDtypeStruct((a*a,),dtype=np.int32))Traceback (most recent call last):ValueError:Cannot solve for values of dimension variables {'a'}.We can only solve linear uni-variate constraints.Using the following polymorphic shapes specifications: args[0].shape = (a^2,).Unprocessed specifications: 'a^2' for dimension size args[0].shape[0].

Shape assertion errors#

JAX assumes that dimension variables range over strictly positive integers, and this assumption is checked when the code is compiled for concrete input shapes.

For example, given the symbolic input shape (b, b, 2*d), JAX will generate code to check the following assertions when invoked with actual argument arg:

  • arg.shape[0] >= 1

  • arg.shape[1] == arg.shape[0]

  • arg.shape[2] % 2 == 0

  • arg.shape[2] // 2 >= 1

For example, here is the error we get when we call the exported function on an argument of shape (3, 3, 5):

>>> def f(x):  # x: f32[b, b, 2*d]
...   return x
>>> exp = export.export(jax.jit(f))(
...     jax.ShapeDtypeStruct(export.symbolic_shape("b, b, 2*d"), dtype=np.int32))
>>> exp.call(np.ones((3, 3, 5), dtype=np.int32))
Traceback (most recent call last):
ValueError: Input shapes do not match the polymorphic shapes specification.
Division had remainder 1 when computing the value of 'd'.
Using the following polymorphic shapes specifications:
  args[0].shape = (b, b, 2*d).
Obtained dimension variables: 'b' = 3 from specification 'b' for dimension args[0].shape[0] (= 3), .
Please see https://docs.jax.dev/en/latest/export/shape_poly.html#shape-assertion-errors for more details.

These errors arise in a pre-processing step before the compilation.

Debugging#

First, see the Debugging documentation. Additionally, you can debug the shape refinement, which is invoked at compilation time for modules that have dimension variables or multi-platform support.

If there is an error during shape refinement, you can set the JAX_DUMP_IR_TO environment variable to see a dump of the HLO module before shape refinement (named ..._before_refine_polymorphic_shapes.mlir). This module should already have static input shapes.

To enable the logging of all stages of shape refinement you can set the environment variable TF_CPP_VMODULE=refine_polymorphic_shapes=3 in OSS (inside Google, you pass --vmodule=refine_polymorphic_shapes=3):

# Log from python
JAX_DUMP_IR_TO=/tmp/export.dumps/ TF_CPP_VMODULE=refine_polymorphic_shapes=3 python tests/shape_poly_test.py ShapePolyTest.test_simple_unary -v=3
