Loading a TorchScript Model in C++
Created On: Sep 14, 2018 | Last Updated: Dec 02, 2024 | Last Verified: Nov 05, 2024
Warning
TorchScript is no longer in active development.
As its name suggests, the primary interface to PyTorch is the Python programming language. While Python is a suitable and preferred language for many scenarios requiring dynamism and ease of iteration, there are equally many situations where precisely these properties of Python are unfavorable. One environment in which the latter often applies is production – the land of low latencies and strict deployment requirements. For production scenarios, C++ is very often the language of choice, even if only to bind it into another language like Java, Rust or Go. The following paragraphs will outline the path PyTorch provides to go from an existing Python model to a serialized representation that can be loaded and executed purely from C++, with no dependency on Python.
Step 1: Converting Your PyTorch Model to Torch Script
A PyTorch model’s journey from Python to C++ is enabled by Torch Script, a representation of a PyTorch model that can be understood, compiled and serialized by the Torch Script compiler. If you are starting out from an existing PyTorch model written in the vanilla “eager” API, you must first convert your model to Torch Script. In the most common cases, discussed below, this requires only little effort. If you already have a Torch Script module, you can skip to the next section of this tutorial.
There exist two ways of converting a PyTorch model to Torch Script. The first is known as tracing, a mechanism in which the structure of the model is captured by evaluating it once using example inputs, and recording the flow of those inputs through the model. This is suitable for models that make limited use of control flow. The second approach is to add explicit annotations to your model that inform the Torch Script compiler that it may directly parse and compile your model code, subject to the constraints imposed by the Torch Script language.
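The difference matters in practice. The following minimal sketch (the `Gate` module here is hypothetical, made up purely for illustration) shows that tracing records only the branch taken for the example input, while scripting compiles both branches:

```python
import torch

class Gate(torch.nn.Module):
    def forward(self, x):
        # Data-dependent control flow: the branch depends on the input values.
        if x.sum() > 0:
            return x * 2
        return x + 10

m = Gate()
pos = torch.ones(3)
neg = -torch.ones(3)

# Tracing evaluates the model once with `pos` and records only that path.
# (PyTorch emits a TracerWarning here, since the trace cannot generalize.)
traced = torch.jit.trace(m, pos)

# Scripting parses the source, so both branches survive.
scripted = torch.jit.script(m)

print(traced(neg))    # follows the recorded positive branch: x * 2
print(scripted(neg))  # respects the control flow as written: x + 10
```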
Tip
You can find the complete documentation for both of these methods, as well as further guidance on which to use, in the official Torch Script reference.
Converting to Torch Script via Tracing
To convert a PyTorch model to Torch Script via tracing, you must pass an instance of your model along with an example input to the torch.jit.trace function. This will produce a torch.jit.ScriptModule object with the trace of your model evaluation embedded in the module’s forward method:
import torch
import torchvision

# An instance of your model.
model = torchvision.models.resnet18()

# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)

# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
The traced ScriptModule can now be evaluated identically to a regular PyTorch module:
In[1]: output = traced_script_module(torch.ones(1, 3, 224, 224))
In[2]: output[0, :5]
Out[2]: tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
Converting to Torch Script via Annotation
Under certain circumstances, such as if your model employs particular forms of control flow, you may want to write your model in Torch Script directly and annotate your model accordingly. For example, say you have the following vanilla PyTorch model:
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output
Because the forward method of this module uses control flow that is dependent on the input, it is not suitable for tracing. Instead, we can convert it to a ScriptModule by compiling the module with torch.jit.script as follows:
class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

my_module = MyModule(10, 20)
sm = torch.jit.script(my_module)
If you need to exclude some methods in your nn.Module because they use Python features that TorchScript doesn’t support yet, you could annotate those with @torch.jit.ignore.
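A short sketch of this (the module and its debug helper are hypothetical, invented for illustration): the ignored method is left as plain Python, so the rest of the module still compiles.

```python
import torch

class ModuleWithDebug(torch.nn.Module):
    @torch.jit.ignore
    def debug_log(self, x: torch.Tensor) -> None:
        # The compiler skips this body entirely; at runtime the call
        # dispatches back to the Python interpreter.
        print("input sum:", x.sum().item())

    def forward(self, x):
        self.debug_log(x)
        return x * 2

sm = torch.jit.script(ModuleWithDebug())
out = sm(torch.ones(4))  # the ignored method runs in Python

# Caveat: a ScriptModule that still calls an ignored method cannot be
# serialized with save(), since the Python body is not serializable.
```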
sm is an instance of ScriptModule that is ready for serialization.
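Before serializing, you can verify what the compiler produced by printing the module’s .code attribute, which shows the generated TorchScript for forward, control flow included:

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

sm = torch.jit.script(MyModule(10, 20))

# The compiled TorchScript preserves the if/else from the Python source:
print(sm.code)
```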
Step 2: Serializing Your Script Module to a File
Once you have a ScriptModule in your hands, either from tracing or annotating a PyTorch model, you are ready to serialize it to a file. Later on, you’ll be able to load the module from this file in C++ and execute it without any dependency on Python. Say we want to serialize the ResNet18 model shown earlier in the tracing example. To perform this serialization, simply call save on the module and pass it a filename:
traced_script_module.save("traced_resnet_model.pt")
This will produce a traced_resnet_model.pt file in your working directory. If you would also like to serialize sm, call sm.save("my_module_model.pt").
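As a quick sanity check before leaving Python, you can load the file back with torch.jit.load and confirm it produces the same output as the original module (a sketch reusing the MyModule definition from earlier):

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

my_module = MyModule(10, 20)
sm = torch.jit.script(my_module)
sm.save("my_module_model.pt")

# Round-trip: the deserialized module should match the original.
loaded = torch.jit.load("my_module_model.pt")
x = torch.ones(20)
assert torch.allclose(my_module(x), loaded(x))
```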
We have now officially left the realm of Python and are ready to cross over to the sphere of C++.
Step 3: Loading Your Script Module in C++
To load your serialized PyTorch model in C++, your application must depend on the PyTorch C++ API – also known as LibTorch. The LibTorch distribution encompasses a collection of shared libraries, header files and CMake build configuration files. While CMake is not a requirement for depending on LibTorch, it is the recommended approach and will be well supported into the future. For this tutorial, we will be building a minimal C++ application using CMake and LibTorch that simply loads and executes a serialized PyTorch model.
A Minimal C++ Application
Let’s begin by discussing the code to load a module. The following will already do:
#include <torch/script.h> // One-stop header.

#include <iostream>
#include <memory>

int main(int argc, const char* argv[]) {
  if (argc != 2) {
    std::cerr << "usage: example-app <path-to-exported-script-module>\n";
    return -1;
  }

  torch::jit::script::Module module;
  try {
    // Deserialize the ScriptModule from a file using torch::jit::load().
    module = torch::jit::load(argv[1]);
  }
  catch (const c10::Error& e) {
    std::cerr << "error loading the model\n";
    return -1;
  }

  std::cout << "ok\n";
}
The <torch/script.h> header encompasses all relevant includes from the LibTorch library necessary to run the example. Our application accepts the file path to a serialized PyTorch ScriptModule as its only command line argument and then proceeds to deserialize the module using the torch::jit::load() function, which takes this file path as input. In return we receive a torch::jit::script::Module object. We will examine how to execute it in a moment.
Depending on LibTorch and Building the Application
Assume we stored the above code into a file called example-app.cpp. A minimal CMakeLists.txt to build it could look as simple as:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(custom_ops)

find_package(Torch REQUIRED)

add_executable(example-app example-app.cpp)
target_link_libraries(example-app "${TORCH_LIBRARIES}")
set_property(TARGET example-app PROPERTY CXX_STANDARD 17)
The last thing we need to build the example application is the LibTorch distribution. You can always grab the latest stable release from the download page on the PyTorch website. If you download and unzip the latest archive, you should receive a folder with the following directory structure:
libtorch/
  bin/
  include/
  lib/
  share/
The lib/ folder contains the shared libraries you must link against.
The include/ folder contains header files your program will need to include.
The share/ folder contains the necessary CMake configuration to enable the simple find_package(Torch) command above.
Tip
On Windows, debug and release builds are not ABI-compatible. If you plan to build your project in debug mode, please try the debug version of LibTorch. Also, make sure you specify the correct configuration in the cmake --build . line below.
The last step is building the application. For this, assume our example directory is laid out like this:
example-app/
  CMakeLists.txt
  example-app.cpp
We can now run the following commands to build the application from within the example-app/ folder:
mkdir build
cd build
cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
cmake --build . --config Release
where /path/to/libtorch should be the full path to the unzipped LibTorch distribution. If all goes well, it will look something like this:
root@4b5a67132e81:/example-app# mkdir build
root@4b5a67132e81:/example-app# cd build
root@4b5a67132e81:/example-app/build# cmake -DCMAKE_PREFIX_PATH=/path/to/libtorch ..
-- The C compiler identification is GNU 5.4.0
-- The CXX compiler identification is GNU 5.4.0
-- Check for working C compiler: /usr/bin/cc
-- Check for working C compiler: /usr/bin/cc -- works
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Detecting C compile features
-- Detecting C compile features - done
-- Check for working CXX compiler: /usr/bin/c++
-- Check for working CXX compiler: /usr/bin/c++ -- works
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Looking for pthread.h
-- Looking for pthread.h - found
-- Looking for pthread_create
-- Looking for pthread_create - not found
-- Looking for pthread_create in pthreads
-- Looking for pthread_create in pthreads - not found
-- Looking for pthread_create in pthread
-- Looking for pthread_create in pthread - found
-- Found Threads: TRUE
-- Configuring done
-- Generating done
-- Build files have been written to: /example-app/build
root@4b5a67132e81:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
If we supply the path to the traced ResNet18 model traced_resnet_model.pt we created earlier to the resulting example-app binary, we should be rewarded with a friendly “ok”. Please note, if you try to run this example with my_module_model.pt, you will get an error saying that your input is of an incompatible shape: my_module_model.pt expects a 1D input instead of a 4D one.
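The mismatch is easy to reproduce on the Python side: MyModule(10, 20) holds a 10x20 weight and applies weight.mv(input), so it expects a 1D vector of length 20, while the ResNet18 example fed a 4D image batch.

```python
import torch

class MyModule(torch.nn.Module):
    def __init__(self, N, M):
        super(MyModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.rand(N, M))

    def forward(self, input):
        if input.sum() > 0:
            output = self.weight.mv(input)
        else:
            output = self.weight + input
        return output

my_module = MyModule(10, 20)

# A 1D vector of length 20 matches weight.mv():
out = my_module(torch.ones(20))
print(out.shape)  # torch.Size([10])

# The 4D image-shaped input from the ResNet example does not:
try:
    my_module(torch.ones(1, 3, 224, 224))
except RuntimeError:
    print("shape error, as expected")
```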
root@4b5a67132e81:/example-app/build# ./example-app <path_to_model>/traced_resnet_model.pt
ok
Step 4: Executing the Script Module in C++
Having successfully loaded our serialized ResNet18 in C++, we are now just a couple lines of code away from executing it! Let’s add those lines to our C++ application’s main() function:
// Create a vector of inputs.
std::vector<torch::jit::IValue> inputs;
inputs.push_back(torch::ones({1, 3, 224, 224}));

// Execute the model and turn its output into a tensor.
at::Tensor output = module.forward(inputs).toTensor();
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
The first two lines set up the inputs to our model. We create a vector of torch::jit::IValue (a type-erased value type that script::Module methods accept and return) and add a single input. To create the input tensor, we use torch::ones(), the equivalent to torch.ones in the C++ API. We then run the script::Module’s forward method, passing it the input vector we created. In return we get a new IValue, which we convert to a tensor by calling toTensor().
Tip
To learn more about functions like torch::ones and the PyTorch C++ API in general, refer to its documentation at https://pytorch.org/cppdocs. The PyTorch C++ API provides near feature parity with the Python API, allowing you to further manipulate and process tensors just like in Python.
In the last line, we print the first five entries of the output. Since we supplied the same input to our model in Python earlier in this tutorial, we should ideally see the same output. Let’s try it out by re-compiling our application and running it with the same serialized model:
root@4b5a67132e81:/example-app/build# make
Scanning dependencies of target example-app
[ 50%] Building CXX object CMakeFiles/example-app.dir/example-app.cpp.o
[100%] Linking CXX executable example-app
[100%] Built target example-app
root@4b5a67132e81:/example-app/build# ./example-app traced_resnet_model.pt
-0.2698 -0.0381  0.4023 -0.3010 -0.0448
[ Variable[CPUFloatType]{1,5} ]
For reference, the output in Python previously was:
tensor([-0.2698, -0.0381,  0.4023, -0.3010, -0.0448], grad_fn=<SliceBackward>)
Looks like a good match!
Tip
To move your model to GPU memory, you can write model.to(at::kCUDA);. Make sure the inputs to a model are also living in CUDA memory by calling tensor.to(at::kCUDA), which will return a new tensor in CUDA memory.
Step 5: Getting Help and Exploring the API
This tutorial has hopefully equipped you with a general understanding of a PyTorch model’s path from Python to C++. With the concepts described in this tutorial, you should be able to go from a vanilla, “eager” PyTorch model, to a compiled ScriptModule in Python, to a serialized file on disk and – to close the loop – to an executable script::Module in C++.
Of course, there are many concepts we did not cover. For example, you may find yourself wanting to extend your ScriptModule with a custom operator implemented in C++ or CUDA, and to execute this custom operator inside your ScriptModule loaded in your pure C++ production environment. The good news is: this is possible, and well supported! For now, you can explore this folder for examples, and we will follow up with a tutorial shortly. In the meantime, the following links may be generally helpful:
The Torch Script reference: https://pytorch.org/docs/master/jit.html
The PyTorch C++ API documentation: https://pytorch.org/cppdocs/
The PyTorch Python API documentation: https://pytorch.org/docs/
As always, if you run into any problems or have questions, you can use our forum or GitHub issues to get in touch.