DataParallel#
- class torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)[source]#
Implements data parallelism at the module level.
This container parallelizes the application of the given module by splitting the input across the specified devices by chunking in the batch dimension (other objects will be copied once per device). In the forward pass, the module is replicated on each device, and each replica handles a portion of the input. During the backwards pass, gradients from each replica are summed into the original module.
The batch size should be larger than the number of GPUs used.
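The scatter/replicate/gather flow described above can be sketched in plain Python. This is a CPU-only illustration, not the real CUDA implementation; `TinyModule` and `data_parallel_forward` are hypothetical stand-ins, and the replicas run sequentially here rather than in parallel threads:

```python
import copy

class TinyModule:
    """Stand-in for an nn.Module: doubles every element of its input."""
    def forward(self, batch):
        return [x * 2 for x in batch]

def data_parallel_forward(module, batch, num_devices):
    # Scatter: chunk the input along the batch dimension (dim 0).
    chunk_size = (len(batch) + num_devices - 1) // num_devices
    chunks = [batch[i:i + chunk_size] for i in range(0, len(batch), chunk_size)]
    # Replicate: one copy of the module per device.
    replicas = [copy.deepcopy(module) for _ in chunks]
    # Apply each replica to its chunk, then gather outputs in order.
    outputs = [r.forward(c) for r, c in zip(replicas, chunks)]
    return [y for out in outputs for y in out]

result = data_parallel_forward(TinyModule(), [1, 2, 3, 4, 5, 6], num_devices=3)
print(result)  # [2, 4, 6, 8, 10, 12] -- each replica handled a 2-element chunk
```

The point of the sketch is the ordering guarantee: outputs are gathered back in the same batch order the input was scattered in.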
Warning
It is recommended to use DistributedDataParallel, instead of this class, to do multi-GPU training, even if there is only a single node. See: Use nn.parallel.DistributedDataParallel instead of multiprocessing or nn.DataParallel and Distributed Data Parallel.
Arbitrary positional and keyword inputs are allowed to be passed into DataParallel, but some types are specially handled. Tensors will be scattered on the dim specified (default 0). tuple, list and dict types will be shallow copied. Other types will be shared among different threads and can be corrupted if written to in the model's forward pass.
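This per-type handling can be sketched without torch; `FakeTensor` and `scatter` below are hypothetical stand-ins for the real internals, only meant to show the three cases (split on dim 0, shallow copy, shared reference):

```python
import copy

class FakeTensor:
    """Hypothetical stand-in for a torch.Tensor with a batch dimension."""
    def __init__(self, rows):
        self.rows = rows
    def chunk(self, n):
        # Split along dim 0 into n roughly equal chunks.
        size = (len(self.rows) + n - 1) // n
        return [FakeTensor(self.rows[i:i + size])
                for i in range(0, len(self.rows), size)]

def scatter(obj, num_devices):
    """Sketch of the per-type argument handling (not the real torch scatter)."""
    if isinstance(obj, FakeTensor):
        return obj.chunk(num_devices)                        # tensors: split on dim 0
    if isinstance(obj, (tuple, list, dict)):
        return [copy.copy(obj) for _ in range(num_devices)]  # shallow copy per device
    return [obj] * num_devices                               # everything else: shared

parts = scatter(FakeTensor([1, 2, 3, 4]), 2)
print(parts[0].rows, parts[1].rows)  # [1, 2] [3, 4]

state = object()
shared = scatter(state, 2)
print(shared[0] is shared[1])  # True -- writing to it in forward is unsafe
```

The shared-reference case is the one the warning is about: every thread sees the same object, so mutating it in `forward` races.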
Warning
The parallelized module must have its parameters and buffers on device_ids[0] before running this DataParallel module.
Warning
In each forward, module is replicated on each device, so any updates to the running module in forward will be lost. For example, if module has a counter attribute that is incremented in each forward, it will always stay at the initial value because the update is done on the replicas, which are destroyed after forward. However, DataParallel guarantees that the replica on device[0] will have its parameters and buffers sharing storage with the base parallelized module. So in-place updates to the parameters or buffers on device[0] will be recorded. E.g., BatchNorm2d and spectral_norm() rely on this behavior to update the buffers.
Warning
Forward and backward hooks defined on module and its submodules will be invoked len(device_ids) times, each with inputs located on a particular device. Particularly, the hooks are only guaranteed to be executed in correct order with respect to operations on corresponding devices. For example, it is not guaranteed that hooks set via register_forward_pre_hook() be executed before all len(device_ids) forward() calls, but that each such hook be executed before the corresponding forward() call of that device.
Warning
When module returns a scalar (i.e., 0-dimensional tensor) in forward(), this wrapper will return a vector of length equal to the number of devices used in data parallelism, containing the result from each device.
Note
There is a subtlety in using the pack sequence -> recurrent network -> unpack sequence pattern in a Module wrapped in DataParallel. See the My recurrent network doesn't work with data parallelism section in the FAQ for details.
- Parameters
module (Module) – module to be parallelized
device_ids (list of int or torch.device) – CUDA devices (default: all devices)
output_device (int or torch.device) – device location of output (default: device_ids[0])
- Variables
module (Module) – the module to be parallelized
Example:
>>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
>>> output = net(input_var)  # input_var can be on any device, including CPU
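The per-forward replication caveat from the warnings (counter-style attribute updates landing on short-lived replicas) can also be demonstrated without any GPU. `CountingModule` is a hypothetical stand-in module, and `copy.deepcopy` plays the role of DataParallel's replication step:

```python
import copy

class CountingModule:
    """Hypothetical stand-in for an nn.Module with a counter in forward."""
    def __init__(self):
        self.calls = 0
    def forward(self, batch):
        self.calls += 1  # the update happens on whichever copy runs forward
        return batch

base = CountingModule()
for _ in range(3):
    # DataParallel replicates the module on every forward; the replica is
    # discarded afterwards, so non-tensor attribute updates are lost.
    replica = copy.deepcopy(base)
    replica.forward("batch")

print(base.calls)  # 0 -- the base module never saw the increments
```

Parameters and buffers on device[0] are the exception: they share storage with the base module, which is why in-place buffer updates (e.g. BatchNorm2d running stats) do survive.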