
Colocated Python#

NOTE: Colocated Python is currently an experimental API. Its functionality and interface are subject to change without following the standard JAX compatibility policy.

Colocated Python provides a uniform way to run Python code on the hosts associated with a set of JAX devices. If the JAX devices represent local devices, the Python code will run on the local host. If the JAX devices represent remote devices, the Python code will be shipped to run on the host of these remote devices. This is useful when building a multi-host ML system on top of JAX that is portable across multi-controller JAX environments (running JAX code on each host with accelerators) as well as single-controller JAX environments (running JAX code on a single host orchestrating other hosts with accelerators).

Colocated CPU devices#

To use colocated Python, the first step is to obtain CPU devices colocated with target accelerator devices. jax.experimental.colocated_python.colocated_cpu_devices provides a standard way to do so.

import jax
import jax.experimental.colocated_python as colocated_python

devices = jax.devices()
cpu_devices = colocated_python.colocated_cpu_devices(devices)
print(cpu_devices)
[CpuDevice(id=0)]

As usual, the CPU devices can be used with JAX APIs.

cpu_mesh = jax.sharding.Mesh(cpu_devices, ["x"])
cpu_sharding = jax.sharding.NamedSharding(cpu_mesh, jax.P())
x = jax.device_put(1, cpu_sharding)
y = jax.jit(lambda x: x + 1)(x)
print(y)
2

Colocated Python function#

CPU devices can also be used to run Python code with colocated Python.

def f(x):
    return x + 1

f = colocated_python.colocated_python(f)
y = f(x)
assert y.sharding == x.sharding
print(y)
2

Since colocated Python runs normal Python code, you can also perform I/O:

def f(x):
    with open('/tmp/foo', 'w') as f:
        f.write(str(x))
    return x

f = colocated_python.colocated_python(f)
jax.block_until_ready(f(x))
Array(1, dtype=int32, weak_type=True)

Note the use of jax.block_until_ready to ensure the Python code has completed. In principle, colocated Python calls may run asynchronously, similar to jitted function calls: a call returns JAX arrays without blocking until its output is produced. Thus, you should block on an output from a colocated Python call whenever completion of the execution matters.
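The comparison with jitted calls can be sketched with plain jax.jit, which needs no colocated Python setup. This is only an illustration of the blocking pattern; the function name slow_sum and the input size are made up for the example.

```python
import jax
import jax.numpy as jnp


@jax.jit
def slow_sum(x):
    # Stand-in for any computation whose completion the caller cares about.
    return jnp.sum(x @ x)


x = jnp.ones((8, 8))
y = slow_sum(x)               # May return before the computation has finished.
y = jax.block_until_ready(y)  # Waits until the output is actually produced.
result = float(y)
```

The same jax.block_until_ready call applies to the output of a colocated Python call, as in the I/O example above.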

There exist cases where a colocated Python call runs synchronously.

  • If the colocated Python function is called without “specialization” (see below), the very first call will run synchronously. This is because the shape and sharding of the output must be known for asynchronous execution, and colocated Python has to run the Python code once to discover this information.

  • Some JAX backends do not yet fully support asynchronous execution, and will fall back to synchronous execution.

The wrapped Python code must use exactly the same set of devices in the input and the output. This is a requirement similar to jitted functions that represent an SPMD execution.

Specialization#

Specialization in colocated Python is a mechanism to supply extra information about the input, output, and execution of a colocated Python function, either when the information cannot be inferred in advance or when you would like to ensure that colocated Python executions happen precisely as specified.

First, functions wrapped in colocated Python have a specialize method. This method is used to create another colocated Python wrapped function specialized with the supplied information.

out_specs_fn is a function that takes a pytree of jax.ShapeDtypeStruct of the call inputs and returns a pytree of jax.ShapeDtypeStruct expected for the output. Calling this function is analogous to jitted function tracing, but this function is separate from the original Python code. This function runs on the caller side and is not executed on the devices.

def f(x):
    return x + 1

f = colocated_python.colocated_python(f)
f = f.specialize(out_specs_fn=lambda x: x)
y = f(x)
assert y.sharding == x.sharding
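When the output does not simply mirror the input, out_specs_fn has to build the output specs itself. The following is a minimal sketch, assuming a hypothetical wrapped function that returns a dict with an integer result and a float32 result of the same shape; the key names are invented for illustration. It only exercises the spec function on the caller side and does not run any colocated Python code.

```python
import jax
import jax.numpy as jnp


def out_specs_fn(x_spec):
    """Maps the input spec to output specs without running the wrapped code."""
    return {
        # Same shape, dtype, and sharding as the input.
        "doubled": jax.ShapeDtypeStruct(
            shape=x_spec.shape, dtype=x_spec.dtype, sharding=x_spec.sharding
        ),
        # Same shape and sharding, but a different dtype.
        "as_float": jax.ShapeDtypeStruct(
            shape=x_spec.shape, dtype=jnp.float32, sharding=x_spec.sharding
        ),
    }


x_spec = jax.ShapeDtypeStruct(shape=(4,), dtype=jnp.int32)
out_specs = out_specs_fn(x_spec)
```

Carrying the input sharding over to every output keeps the specs consistent with the requirement that inputs and outputs use the same set of devices.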

in_specs takes a concrete pytree of jax.ShapeDtypeStruct expected for the input to the colocated Python function call. The top level of the pytree is a tuple holding the positional argument specs and the keyword argument specs. This is used if a certain input spec must be used, or if the output specs function can be computed only for a concrete input spec.

import jax.numpy as jnp

def f(x):
    return x + 1

f = colocated_python.colocated_python(f)
f = f.specialize(
    in_specs=(
        # args
        (
            jax.ShapeDtypeStruct(
                shape=(), dtype=jnp.int32, sharding=cpu_sharding
            ),
        ),
        # kwargs
        {},
    ),
    out_specs_fn=lambda x: jax.ShapeDtypeStruct(
        shape=(), dtype=jnp.int32, sharding=cpu_sharding
    ),
)
f(x)  # `x` must match the input spec.
Array(2, dtype=int32, weak_type=True)
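Writing the (args, kwargs) structure by hand can get verbose. As a sketch, the structure could be derived from example arguments instead. The helpers spec_of and make_in_specs below are hypothetical and not part of the colocated Python API; they only use jax.ShapeDtypeStruct and jax.tree_util.tree_map.

```python
import jax
import jax.numpy as jnp


def spec_of(array):
    """Describes a concrete array as a jax.ShapeDtypeStruct (hypothetical helper)."""
    return jax.ShapeDtypeStruct(
        shape=array.shape, dtype=array.dtype, sharding=array.sharding
    )


def make_in_specs(*args, **kwargs):
    """Builds an (args, kwargs) spec tuple from example arguments (hypothetical helper)."""
    return jax.tree_util.tree_map(spec_of, (args, kwargs))


example_x = jnp.zeros((), dtype=jnp.int32)
in_specs = make_in_specs(example_x)
```

The resulting in_specs has the same layout as the hand-written one above: a tuple of positional specs followed by a dict of keyword specs.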

devices specifies a list of devices that the colocated Python function should run on. Specializing devices lets a colocated Python function that takes no input arguments run, since there are no inputs from which the devices could be inferred.

def f():
    with open('/tmp/foo', 'w') as f:
        f.write('foo')
    return

f = colocated_python.colocated_python(f)
f = f.specialize(devices=cpu_devices)
f()  # Would be an error if `f` is not specialized with ``devices``.

Colocated Python class#

Colocated Python also supports wrapping Python classes. A real instance is created on the hosts associated with the devices, and the caller side will get a wrapper class that forwards all method calls to the real instance using colocated Python.

class Adder:

    def __init__(self, increment):
        print('Adder created')
        self.increment = increment

    def __del__(self):
        print('Adder destroyed')

    def add(self, x):
        return x + self.increment

Adder = colocated_python.colocated_python_class(Adder)
adder = Adder(1)
x = jax.device_put(1, cpu_sharding)
y = adder.add(x)
print(y)
Adder created
2

When the wrapper class instance is destroyed, the real instance is destroyed as well. Note that this destruction will be asynchronous.

del adder
Adder destroyed

There are a few important semantic differences between colocated Python and normal Python.

  • A colocated Python class instance is created on the hosts associated with the devices only when a non-constructor method is called for the first time. In the above example, Adder(1) captures the constructor argument 1, but the actual constructor call Adder(1) on the hosts happens only when the first adder.add(x) call is made. This is because it is unknown what hosts the Adder instance should be created on until there is a call to one of its methods.

  • If methods of the same wrapper instance are called with inputs on different devices, the real instance may be created at different times on different hosts. If the first method call used CPU devices on host A, and the second method call used CPU devices on host B, the real instance will be created on host A during the first method call, and then on host B during the second method call.

  • The methods of colocated Python classes are not yet specializable. Support will be added in the future.
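The lazy, per-host construction described above can be modeled in plain Python. The sketch below is a toy model only, not how colocated Python is implemented: the class LazyPerHostProxy, its call method, and the host names are all hypothetical, and hosts are represented by simple strings.

```python
class LazyPerHostProxy:
    """Toy model of a wrapper that defers construction until a method call."""

    def __init__(self, cls, *args, **kwargs):
        # Only capture the constructor arguments; nothing is constructed yet.
        self._cls = cls
        self._args = args
        self._kwargs = kwargs
        self._instances = {}  # host name -> real instance

    def call(self, host, method_name, *method_args):
        # The real instance is created on a host the first time it is needed there.
        if host not in self._instances:
            self._instances[host] = self._cls(*self._args, **self._kwargs)
        return getattr(self._instances[host], method_name)(*method_args)


constructed = []  # Records every real constructor call.


class Adder:
    def __init__(self, increment):
        constructed.append(increment)
        self.increment = increment

    def add(self, x):
        return x + self.increment


adder = LazyPerHostProxy(Adder, 1)
count_after_wrap = len(constructed)      # No real instance exists yet.
y_a = adder.call("host_a", "add", 1)     # Creates the instance for host_a.
count_after_first_call = len(constructed)
y_b = adder.call("host_b", "add", 10)    # Creates a second instance for host_b.
y_a2 = adder.call("host_a", "add", 5)    # Reuses the host_a instance.
```

In this model each host ends up with its own independent instance, so any state mutated by a method on one host is not visible on another.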

Execution order and concurrency#

Colocated Python provides “program order” execution. Even though colocated Python calls may be asynchronous (returning output JAX arrays without blocking), the calls are executed in the same order in which they are made in the user program. Thus, by default, colocated Python calls are executed sequentially.

Several use cases of colocated Python will benefit from concurrent execution. For example, one colocated Python call may take a long time to return because it is doing expensive file reads, while another colocated Python call may need to do file writes that are independent of the first one. In this situation, the two calls could be expected to run concurrently without blocking each other.

Colocated Python provides concurrent execution if colocated Python calls are made from different threads. For example, the example below makes two colocated Python calls run concurrently.

import concurrent.futures
import time

def f(x):
    time.sleep(1)
    return x + 1

f = colocated_python.colocated_python(f)
f = f.specialize(out_specs_fn=lambda x: x)  # Calls will be asynchronous.

with concurrent.futures.ThreadPoolExecutor(2) as executor:
    fut1 = executor.submit(f, x)
    fut2 = executor.submit(f, x)
    # Will finish in approximately 1 second instead of 2 seconds.
    jax.block_until_ready([fut1.result(), fut2.result()])

While calls from different threads run concurrently, on each thread, program ordering will continue to apply.
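One way to picture these semantics is a dispatcher that keeps a separate first-in, first-out queue per calling thread. The sketch below is a toy model in plain Python under that assumption; the class PerThreadOrderedDispatcher is hypothetical and says nothing about how colocated Python is actually implemented.

```python
import concurrent.futures
import threading


class PerThreadOrderedDispatcher:
    """Toy model: one FIFO worker per calling thread (hypothetical, not a JAX API)."""

    def __init__(self):
        self._local = threading.local()
        self._executors = []
        self._lock = threading.Lock()

    def submit(self, fn, *args):
        # Each calling thread lazily gets its own single-worker executor, so its
        # calls run in submission order while other threads' calls can overlap.
        executor = getattr(self._local, "executor", None)
        if executor is None:
            executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
            self._local.executor = executor
            with self._lock:
                self._executors.append(executor)
        return executor.submit(fn, *args)

    def shutdown(self):
        for executor in self._executors:
            executor.shutdown(wait=True)


dispatcher = PerThreadOrderedDispatcher()
order = []


def record(tag):
    order.append(tag)
    return tag


# Both calls come from the same (main) thread, so they run in program order,
# even though submit() returns a future immediately without blocking.
futures = [dispatcher.submit(record, "first"), dispatcher.submit(record, "second")]
results = [future.result() for future in futures]
dispatcher.shutdown()
```

Calls submitted from a second thread would land on a different worker in this model, which is why they could overlap with the main thread's calls.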

