jax.devices
- jax.devices(backend=None)
Returns a list of all devices for a given backend.

Each device is represented by a subclass of Device (e.g. CpuDevice, GpuDevice). The length of the returned list is equal to device_count(backend). Local devices can be identified by comparing Device.process_index to the value returned by jax.process_index().

If backend is None, returns all the devices from the default backend. The default backend is generally 'gpu' or 'tpu' if available, otherwise 'cpu'.
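The behaviour described above can be illustrated with a short sketch. It lists the devices of the default backend, checks that the list length matches `jax.device_count()`, picks out the local devices by comparing `Device.process_index` with `jax.process_index()`, and then requests the `'cpu'` backend explicitly. This assumes the CPU backend is enabled, which is the usual case unless the `JAX_PLATFORMS` environment variable excludes it. The variable names are illustrative only.

```python
import jax

# With backend=None, devices come from the default backend
# (generally 'gpu' or 'tpu' if available, otherwise 'cpu').
default_devices = jax.devices()
print("default backend:", jax.default_backend())

# The list length equals device_count() for the same backend.
assert len(default_devices) == jax.device_count()

# Local devices are those attached to this process: compare each
# Device.process_index with jax.process_index().
me = jax.process_index()
local_devices = [d for d in default_devices if d.process_index == me]
print(f"process {me}: {len(local_devices)} of {len(default_devices)} devices are local")

# Request a specific backend by name. Assumes the CPU backend is enabled
# (it is unless JAX_PLATFORMS excludes 'cpu').
cpu_devices = jax.devices("cpu")
assert all(d.platform == "cpu" for d in cpu_devices)
assert len(cpu_devices) == jax.device_count("cpu")
```

In a single-process run every device is local, so the two counts printed are equal. In a multi-process setup, each process sees the full device list but only a subset is local to it.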
