
Frequently Asked Questions

Created On: Feb 15, 2018 | Last Updated On: Aug 05, 2021

My model reports “cuda runtime error(2): out of memory”

As the error message suggests, you have run out of memory on your GPU. Since we often deal with large amounts of data in PyTorch, small mistakes can rapidly cause your program to use up all of your GPU memory; fortunately, the fixes in these cases are often simple. Here are a few common things to check:

Don’t accumulate history across your training loop. By default, computations involving variables that require gradients will keep history. This means that you should avoid using such variables in computations which will live beyond your training loops, e.g., when tracking statistics. Instead, you should detach the variable or access its underlying data.

Sometimes, it can be non-obvious when differentiable variables can occur. Consider the following training loop (abridged from source):

total_loss = 0
for i in range(10000):
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output)
    loss.backward()
    optimizer.step()
    total_loss += loss

Here, total_loss is accumulating history across your training loop, since loss is a differentiable variable with autograd history. You can fix this by writing total_loss += float(loss) instead.

Other instances of this problem: 1.

Don’t hold onto tensors and variables you don’t need. If you assign a Tensor or Variable to a local, Python will not deallocate it until the local goes out of scope. You can free this reference by using del x. Similarly, if you assign a Tensor or Variable to a member variable of an object, it will not deallocate until the object goes out of scope. You will get the best memory usage if you don’t hold onto temporaries you don’t need.

The scopes of locals can be larger than you expect. For example:

for i in range(5):
    intermediate = f(input[i])
    result += g(intermediate)
output = h(result)
return output

Here, intermediate remains live even while h is executing, because its scope extends past the end of the loop. To free it earlier, you should del intermediate when you are done with it.
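As a sketch of that fix (reusing the same placeholder functions f, g, and h from the loop above), deleting the reference inside the loop lets the memory be reclaimed before h runs:

for i in range(5):
    intermediate = f(input[i])
    result += g(intermediate)
    del intermediate  # drop the reference as soon as it is no longer needed
output = h(result)
return output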

Avoid running RNNs on sequences that are too large. The amount of memory required to backpropagate through an RNN scales linearly with the length of the RNN input; thus, you will run out of memory if you try to feed an RNN a sequence that is too long.

The technical term for this phenomenon is backpropagation through time, and there are plenty of references for how to implement truncated BPTT, including in the word language model example; truncation is handled by the repackage function as described in this forum post.
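For illustration, a minimal sketch in the spirit of that repackage function: it detaches the hidden state at each truncation boundary so the autograd graph does not keep growing across segments (the tuple case handles LSTM hidden/cell state pairs):

import torch

def repackage_hidden(h):
    """Wrap hidden states in new tensors detached from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    return tuple(repackage_hidden(v) for v in h)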

Don’t use linear layers that are too large. A linear layer nn.Linear(m, n) uses O(nm) memory: that is to say, the memory requirements of the weights scale quadratically with the number of features. It is very easy to blow through your memory this way (and remember that you will need at least twice the size of the weights, since you also need to store the gradients).
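To make this concrete, here is a quick back-of-the-envelope sketch (the 10,000 × 10,000 layer is just an illustration):

import torch.nn as nn

layer = nn.Linear(10000, 10000)
num_params = sum(p.numel() for p in layer.parameters())  # 10000*10000 weights + 10000 biases
mib = num_params * 4 / 1024**2  # 4 bytes per float32 parameter
print(f"{num_params} parameters, roughly {mib:.0f} MiB for the weights alone")
# That is roughly 380 MiB, and the gradients need at least the same amount again.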

Consider checkpointing. You can trade off memory for compute by using checkpoint.
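A rough sketch of how torch.utils.checkpoint.checkpoint can be used (heavy_block is a made-up stand-in for any memory-hungry sub-network; recent PyTorch versions may also ask you to pass use_reentrant explicitly):

import torch
from torch.utils.checkpoint import checkpoint

def heavy_block(x):
    # stands in for an expensive sub-network whose activations we don't want to store
    return torch.relu(x @ x.t())

x = torch.randn(1024, 1024, requires_grad=True)
y = checkpoint(heavy_block, x)  # activations are recomputed during backward instead of saved
y.sum().backward()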

My GPU memory isn’t freed properly

PyTorch uses a caching memory allocator to speed up memory allocations. As a result, the values shown in nvidia-smi usually don’t reflect the true memory usage. See Memory management for more details about GPU memory management.
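If you want to see what PyTorch itself is holding, the caching allocator exposes counters; a minimal sketch (the numbers depend entirely on your program):

import torch

print(torch.cuda.memory_allocated())  # bytes currently occupied by live tensors
print(torch.cuda.memory_reserved())   # bytes held by the caching allocator
torch.cuda.empty_cache()              # release unused cached blocks back to the driver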

If your GPU memory isn’t freed even after Python quits, it is very likely that some Python subprocesses are still alive. You may find them via ps -elf | grep python and manually kill them with kill -9 [pid].

My out of memory exception handler can’t allocate memory

You may have some code that tries to recover from out of memory errors.

try:
    run_model(batch_size)
except RuntimeError:  # Out of memory
    for _ in range(batch_size):
        run_model(1)

But you may find that when you do run out of memory, your recovery code can’t allocate either. That’s because the Python exception object holds a reference to the stack frame where the error was raised, which prevents the original tensor objects from being freed. The solution is to move your OOM recovery code outside of the except clause.

oom = False
try:
    run_model(batch_size)
except RuntimeError:  # Out of memory
    oom = True

if oom:
    for _ in range(batch_size):
        run_model(1)

My data loader workers return identical random numbers

You are likely using other libraries to generate random numbers in the dataset and worker subprocesses are started via fork. See torch.utils.data.DataLoader’s documentation for how to properly set up random seeds in workers with its worker_init_fn option.
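A minimal sketch of that option, assuming NumPy is the other library involved (the TensorDataset here is only a placeholder so the snippet runs on its own):

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def worker_init_fn(worker_id):
    # torch.initial_seed() differs per worker, so reuse it to re-seed NumPy;
    # otherwise forked workers all inherit the parent's NumPy RNG state.
    np.random.seed(torch.initial_seed() % 2**32)

dataset = TensorDataset(torch.randn(100, 3))  # placeholder dataset for the sketch
loader = DataLoader(dataset, num_workers=4, worker_init_fn=worker_init_fn)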

My recurrent network doesn’t work with data parallelism

There is a subtlety in using the pack sequence -> recurrent network -> unpack sequence pattern in a Module with DataParallel or data_parallel(). The input to forward() on each device will only be part of the entire input. Because the unpack operation torch.nn.utils.rnn.pad_packed_sequence() by default only pads up to the longest input it sees, i.e., the longest on that particular device, size mismatches will happen when results are gathered together. Therefore, you can instead take advantage of the total_length argument of pad_packed_sequence() to make sure that the forward() calls return sequences of the same length. For example, you can write:

from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class MyModule(nn.Module):
    # ... __init__, other methods, etc.

    # padded_input is of shape [B x T x *] (batch_first mode) and contains
    # the sequences sorted by lengths
    #   B is the batch size
    #   T is max sequence length
    def forward(self, padded_input, input_lengths):
        total_length = padded_input.size(1)  # get the max sequence length
        packed_input = pack_padded_sequence(padded_input, input_lengths,
                                            batch_first=True)
        packed_output, _ = self.my_lstm(packed_input)
        output, _ = pad_packed_sequence(packed_output, batch_first=True,
                                        total_length=total_length)
        return output

m = MyModule().cuda()
dp_m = nn.DataParallel(m)

Additionally, extra care needs to be taken when the batch dimension is dim 1 (i.e., batch_first=False) with data parallelism. In this case, the first argument of pack_padded_sequence, padding_input, will be of shape [T x B x *] and should be scattered along dim 1, but the second argument input_lengths will be of shape [B] and should be scattered along dim 0. Extra code to manipulate the tensor shapes will be needed.
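One possible way to handle that shape manipulation, shown only as a hedged sketch (it reuses dp_m and input_lengths from the example above, and time_first_input is a hypothetical [T x B x *] batch): transpose at the boundary of the data-parallel call so that everything is scattered along dim 0.

# time_first_input has shape [T x B x *]; make it batch-first so DataParallel
# scatters both arguments along dim 0 (the batch dimension)
batch_first_input = time_first_input.transpose(0, 1).contiguous()
output = dp_m(batch_first_input, input_lengths)  # gathered back as [B x T x *]
output = output.transpose(0, 1)                  # restore [T x B x *] for the caller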