CPU threading and TorchScript inference

PyTorch allows using multiple CPU threads during TorchScript model inference. The following figure shows the different levels of parallelism one would find in a typical application:

[Figure: cpu_threading_torchscript_inference.svg — levels of parallelism in a typical application]

One or more inference threads execute a model's forward pass on the given inputs. Each inference thread invokes a JIT interpreter that executes the ops of a model inline, one by one. A model can utilize a fork TorchScript primitive to launch an asynchronous task. Forking several operations at once results in a task that is executed in parallel. The fork operator returns a future object which can be used to synchronize on later, for example:

    import torch

    @torch.jit.script
    def compute_z(x, w_z):
        return torch.mm(x, w_z)

    @torch.jit.script
    def forward(x, w_y, w_z):
        # launch compute_z asynchronously:
        fut = torch.jit.fork(compute_z, x, w_z)
        # execute the next operation in parallel to compute_z:
        y = torch.mm(x, w_y)
        # wait for the result of compute_z:
        z = torch.jit.wait(fut)
        return y + z
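For completeness, a minimal usage sketch of the scripted functions above (the tensor shapes are illustrative and not part of the original example):

    x = torch.randn(16, 32)
    w_y = torch.randn(32, 32)
    w_z = torch.randn(32, 32)
    # compute_z runs as an inter-op task while torch.mm(x, w_y) executes
    # on the calling thread:
    out = forward(x, w_y, w_z)  # shape: (16, 32)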

PyTorch uses a single thread pool for inter-op parallelism; this thread pool is shared by all inference tasks that are forked within the application process.

In addition to the inter-op parallelism, PyTorch can also utilize multiple threads within the ops (intra-op parallelism). This can be useful in many cases, including element-wise ops on large tensors, convolutions, GEMMs, embedding lookups, and others.
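As an illustration, a single large op can be parallelized internally by the intra-op thread pool; a minimal sketch using the standard torch API (the thread count and tensor sizes are arbitrary):

    import torch

    torch.set_num_threads(4)   # use 4 threads for intra-op parallelism
    a = torch.randn(4096, 4096)
    b = torch.randn(4096, 4096)
    c = torch.mm(a, b)         # one GEMM, parallelized across the intra-op pool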

Build options

PyTorch uses an internal ATen library to implement ops. In addition to that, PyTorch can also be built with support for external libraries, such as MKL and MKL-DNN, to speed up computations on CPU.

ATen, MKL and MKL-DNN support intra-op parallelism and depend on the following parallelization libraries to implement it:

  • OpenMP - a standard (and a library, usually shipped with a compiler), widely used in external libraries;

  • TBB - a newer parallelization library optimized for task-based parallelism and concurrent environments.

OpenMP has historically been used by a large number of libraries. It is known for its relative ease of use and support for loop-based parallelism and other primitives. At the same time, OpenMP is not known for good interoperability with other threading libraries used by the application. In particular, OpenMP does not guarantee that a single per-process intra-op thread pool is going to be used in the application. On the contrary, two different inter-op threads will likely use different OpenMP thread pools for intra-op work. This might result in a large number of threads used by the application.

TBB is used to a lesser extent in external libraries, but, at the same time, is optimized for concurrent environments. PyTorch's TBB backend guarantees that there is a single, separate, per-process intra-op thread pool used by all of the ops running in the application.

Depending on the use case, one might find one or the other parallelization library a better choice for their application.

PyTorch allows selecting the parallelization backend used by ATen and other libraries at build time with the following build options:

Library  | Build Option      | Values             | Notes
---------|-------------------|--------------------|-------------------------------------
ATen     | ATEN_THREADING    | OMP (default), TBB |
MKL      | MKL_THREADING     | (same)             | To enable MKL, use BLAS=MKL
MKL-DNN  | MKLDNN_THREADING  | (same)             | To enable MKL-DNN, use USE_MKLDNN=1

It is strongly recommended not to mix OpenMP and TBB within one build.

Any of the TBB values above requires the USE_TBB=1 build setting (default: OFF). A separate setting, USE_OPENMP=1 (default: ON), is required for OpenMP parallelism.
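For example, a build that uses TBB consistently across ATen, MKL and MKL-DNN might be configured as follows (a sketch assuming a from-source build driven by setup.py; the exact invocation depends on your environment):

    USE_TBB=1 USE_OPENMP=0 BLAS=MKL USE_MKLDNN=1 \
    ATEN_THREADING=TBB MKL_THREADING=TBB MKLDNN_THREADING=TBB \
    python setup.py install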

Runtime API

The following API is used to control thread settings:

Inter-op parallelism

  • Settings: at::set_num_interop_threads, at::get_num_interop_threads (C++); set_num_interop_threads, get_num_interop_threads (Python, torch module)

  • Notes: the set* functions can only be called once, during startup, before any actual operators run; the default number of threads is the number of CPU cores.

Intra-op parallelism

  • Settings: at::set_num_threads, at::get_num_threads (C++); set_num_threads, get_num_threads (Python, torch module); environment variables OMP_NUM_THREADS and MKL_NUM_THREADS

  • Notes: at::set_num_threads and torch.set_num_threads always take precedence over the environment variables, and the MKL_NUM_THREADS variable takes precedence over OMP_NUM_THREADS.
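Putting the settings above together, a minimal Python sketch (the thread counts are arbitrary):

    import torch

    # set* functions must be called once, at startup, before any parallel work:
    torch.set_num_interop_threads(4)  # size of the inter-op thread pool
    torch.set_num_threads(8)          # size of the intra-op thread pool

    print(torch.get_num_interop_threads())  # 4
    print(torch.get_num_threads())          # 8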

Note

The parallel_info utility prints information about thread settings and can be used for debugging. Similar output can also be obtained in Python with a torch.__config__.parallel_info() call.
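For example:

    import torch

    # prints thread settings and the parallelization backends in use:
    print(torch.__config__.parallel_info())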