DataParallel#

class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[source]#

Implements data parallelism at the module level.

This container parallelizes the application of the given module by splitting the input across the specified devices, chunking in the batch dimension (other objects are copied once per device). In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backward pass, gradients from each replica are summed into the original module.

The batch size should be larger than the number of GPUs used.

Warning

It is recommended to use DistributedDataParallel, instead of this class, to do multi-GPU training, even if there is only a single node. See: Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel and Distributed Data Parallel.

Arbitrary positional and keyword inputs are allowed to be passed into DataParallel, but some types are specially handled. Tensors will be scattered on the dim specified (default 0). tuple, list and dict types will be shallow copied. Other types will be shared among different threads and can be corrupted if written to in the model's forward pass.
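The scatter semantics above can be sketched on CPU. This is an illustrative sketch, not the internal implementation: tensors are chunked along the scatter dim with torch.chunk, while containers such as dicts are shallow-copied per replica.

```python
import torch

# Hedged sketch of DataParallel's input handling, run on CPU for illustration:
# tensors are chunked along the scatter dim (default 0), one chunk per device;
# dicts/lists/tuples are shallow-copied, so each replica gets a new container
# holding the same values.

batch = torch.arange(8).reshape(4, 2)            # a batch of 4 samples
num_devices = 2
chunks = torch.chunk(batch, num_devices, dim=0)  # one chunk per device
# Each replica would receive one chunk of shape (2, 2).

meta = {"lr": 0.1}           # a non-tensor keyword argument
replica_meta = dict(meta)    # shallow copy: same values, distinct container
```

With real devices, each chunk would additionally be moved to its replica's GPU before the replicated forward calls run.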

The parallelized module must have its parameters and buffers on device_ids[0] before running this DataParallel module.

Warning

In each forward, module is replicated on each device, so any updates to the running module in forward will be lost. For example, if module has a counter attribute that is incremented in each forward, it will always stay at the initial value because the update is done on the replicas, which are destroyed after forward. However, DataParallel guarantees that the replica on device[0] will have its parameters and buffers sharing storage with the base parallelized module. So in-place updates to the parameters or buffers on device[0] will be recorded. E.g., BatchNorm2d and spectral_norm() rely on this behavior to update the buffers.
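The lost-update behavior can be illustrated in plain Python, with no GPUs. This is a minimal sketch: the Counter class and deepcopy-based replication stand in for the real module and DataParallel's replication step.

```python
import copy

# Sketch of why per-forward attribute updates are lost under DataParallel:
# a fresh replica is created on every forward, the replica's attributes are
# mutated, and the replica is then discarded.

class Counter:
    def __init__(self):
        self.calls = 0

    def forward(self, x):
        self.calls += 1   # incremented on the replica, not the base module
        return x

base = Counter()
for _ in range(3):                 # three simulated forward passes
    replica = copy.deepcopy(base)  # fresh replica each forward, then discarded
    replica.forward(1)
# base.calls is still 0: every increment happened on a throwaway replica.
```

In-place updates to parameters or buffers on device[0] behave differently only because that replica shares storage with the base module, as the warning above notes.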

Warning

Forward and backward hooks defined on module and its submodules will be invoked len(device_ids) times, each with inputs located on a particular device. In particular, the hooks are only guaranteed to be executed in the correct order with respect to operations on the corresponding device. For example, it is not guaranteed that hooks set via register_forward_pre_hook() be executed before all len(device_ids) forward() calls, but that each such hook be executed before the corresponding forward() call of that device.

Warning

When module returns a scalar (i.e., 0-dimensional tensor) in forward(), this wrapper will return a vector of length equal to the number of devices used in data parallelism, containing the result from each device.
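The gather step for scalar outputs can be sketched on CPU. This is an illustration of the shape behavior only; the tensor values and two-device setup are assumptions.

```python
import torch

# Sketch of gathering when each replica's forward() returns a 0-dim tensor:
# the per-device scalars are combined into a 1-D tensor of length
# len(device_ids), rather than a single scalar.
per_device_losses = [torch.tensor(0.5), torch.tensor(0.7)]  # one per replica
gathered = torch.stack(per_device_losses)                   # shape: (2,)
```

Code that expects a scalar loss (e.g. `loss.backward()` after `loss.item()`) typically reduces this vector first, for instance with `gathered.mean()`.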

Note

There is a subtlety in using the pack sequence -> recurrent network -> unpack sequence pattern in a Module wrapped in DataParallel. See the My recurrent network doesn't work with data parallelism section in the FAQ for details.

Parameters:
  • module (Module) – module to be parallelized

  • device_ids (list of int or torch.device) – CUDA devices (default: all devices)

  • output_device (int or torch.device) – device location of output (default: device_ids[0])

Variables:

module (Module) – the module to be parallelized

Example:

>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU