Part 2: GPU-Accelerated Inference

Neural Demapper Overview

We now discuss how to integrate the TensorRT engine for inference into the OAI stack. To keep the inference latency as low as possible [Gadiyar2023], [Kundu2023], we use CUDA graphs [Gray2019] to launch the TensorRT inference engine.

You will learn about:

  • How to accelerate the neural demapper using TensorRT

  • How to pre- and post-process input and output data using CUDA

  • How to use CUDA graphs for latency reductions

For details on efficient memory management when offloading compute-intensive functions to the GPU, we refer to the GPU-Accelerated LDPC Decoding tutorial.

Demapper Implementation Overview

The neural demapper is implemented in TensorFlow and exported to TensorRT; the source code of the inference logic can be found in plugins/neural_demapper/src/runtime/trt_demapper.cpp. The implementation is explained in the following sections.

The TRT demapper receives noisy input symbols from the OpenAirInterface stack via the function trt_demapper_decode(), which chunks a given array of symbols into batches of maximum size MAX_BLOCK_LEN and then calls trt_demapper_decode_block() to carry out the actual inference on each batch. To leverage data-parallel execution on the GPU, inference is performed on batches of symbols and across multiple threads in parallel. The output of the neural demapper is passed back in the form of num_bits_per_symbol LLRs per input symbol.

To run the TensorRT inference engine on the given int16_t-quantized data, we dequantize input symbols to half-precision floating-point format on the GPU using a data-parallel CUDA kernel (see norm_int16_symbols_to_float16()), and re-quantize output LLRs using another CUDA kernel (see float16_llrs_to_int16()).
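To illustrate the batching step, the following is a minimal, hypothetical sketch of such an entry point. The signature is simplified for illustration; the actual trt_demapper_decode() in the plugin takes additional OAI-specific parameters and uses the per-thread buffers and helper functions introduced in the sections below.

#include <algorithm>
#include <cstdint>

// Hypothetical, simplified batching entry point (not the actual plugin signature):
// symbols and magnitudes hold two int16_t components per complex symbol; llrs receives
// num_bits_per_symbol int16_t LLRs per symbol. TRTContext, MAX_BLOCK_LEN,
// trt_demapper_init() and trt_demapper_decode_block() are introduced below.
extern "C" void trt_demapper_decode(int16_t const *symbols, int16_t const *mags, size_t num_symbols,
                                    int16_t *llrs, uint32_t num_bits_per_symbol) {
    TRTContext *context = trt_demapper_init();  // lazy global and per-thread initialization

    for (size_t offset = 0; offset < num_symbols; offset += MAX_BLOCK_LEN) {
        size_t batch = std::min((size_t)MAX_BLOCK_LEN, num_symbols - offset);
        // pre-process, run TensorRT inference, and post-process one batch of symbols
        trt_demapper_decode_block(context, /*stream=*/0,
                                  symbols + 2 * offset, mags + 2 * offset, batch,
                                  context->symbol_buffer, context->magnitude_buffer, batch,
                                  llrs + (size_t)num_bits_per_symbol * offset, num_bits_per_symbol,
                                  context->llr_buffer);
    }
}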

Setting up the TensorRT Inference Engine

To be compatible with the multi-threaded OpenAirInterface implementation, we load the neural demapper network into a TensorRT ICudaEngine once and share the inference engine between multiple IExecutionContext objects, which are created per worker thread. We store the global and per-thread state, respectively, as follows:

#include "NvInfer.h"
#include <cuda_fp16.h>

// global
static IRuntime *runtime = nullptr;
static ICudaEngine *engine = nullptr;

// per thread
struct TRTContext {
    cudaStream_t default_stream = 0;      // asynchronous CUDA command stream
    IExecutionContext *trt = nullptr;     // TensorRT execution context
    void *prealloc_memory = nullptr;      // memory block for temporary per-inference data
    __half *input_buffer = nullptr;       // device-side network inputs after CUDA pre-processing
    __half *output_buffer = nullptr;      // device-side network outputs before CUDA post-processing

    int16_t *symbol_buffer = nullptr;     // write-through buffer for symbols written by CPU and read by GPU
    int16_t *magnitude_buffer = nullptr;  // write-through buffer for magnitude estimates written by CPU and read by GPU
    int16_t *llr_buffer = nullptr;        // host-cached buffer for LLR estimates written by GPU and read by CPU

    // list of thread contexts for shutdown
    TRTContext *next_initialized_context = nullptr;
};
static __thread TRTContext thread_context = {};

We call the following global initialization routine on program startup:

static char const *trt_weight_file = "models/neural_demapper_qam16_2.plan";  // Training result / trtexec output
static bool trt_normalized_inputs = true;

extern "C" TRTContext *trt_demapper_init() {
    if (runtime)  // lazy, global
        return &trt_demapper_init_context();

    printf("Initializing TRT runtime\n");
    runtime = createInferRuntime(logger);
    printf("Loading TRT engine %s (normalized inputs: %d)\n", trt_weight_file, trt_normalized_inputs);
    std::vector<char> modelData = readModelFromFile(trt_weight_file);
    engine = runtime->deserializeCudaEngine(modelData.data(), modelData.size());

    return &trt_demapper_init_context();
}

// Utilities

std::vector<char> readModelFromFile(char const *filepath) {
    std::vector<char> bytes;
    FILE *f = fopen(filepath, "rb");
    if (!f) {
        logger.log(Logger::Severity::kERROR, filepath);
        return bytes;
    }
    fseek(f, 0, SEEK_END);
    bytes.resize((size_t)ftell(f));
    fseek(f, 0, SEEK_SET);
    if (bytes.size() != fread(bytes.data(), 1, bytes.size(), f))
        logger.log(Logger::Severity::kWARNING, filepath);
    fclose(f);
    return bytes;
}

struct Logger : public ILogger
{
    void log(Severity severity, const char *msg) noexcept override
    {
        // suppress info-level messages
        if (severity <= Severity::kWARNING)
            printf("TRT %s: %s\n", severity == Severity::kWARNING ? "WARNING" : "ERROR", msg);
    }
};
static Logger logger;

On startup of each worker thread, we initialize the per-thread contexts as follows:

TRTContext &trt_demapper_init_context() {
    auto &context = thread_context;
    if (context.trt)  // lazy
        return context;

    printf("Initializing TRT context (TID %d)\n", (int)gettid());

    // create execution context with its own pre-allocated temporary memory attached
    context.trt = engine->createExecutionContextWithoutDeviceMemory();
    size_t preallocSize = engine->getDeviceMemorySize();
    CHECK_CUDA(cudaMalloc(&context.prealloc_memory, preallocSize));
    context.trt->setDeviceMemory(context.prealloc_memory);

    // create own asynchronous CUDA command stream for this thread
    CHECK_CUDA(cudaStreamCreateWithFlags(&context.default_stream, cudaStreamNonBlocking));

    // allocate neural network input and output buffers (device access memory)
    cudaMalloc((void**)&context.input_buffer, sizeof(*context.input_buffer) * 4 * MAX_BLOCK_LEN);
    cudaMalloc((void**)&context.output_buffer, sizeof(*context.output_buffer) * MAX_BITS_PER_SYMBOL * MAX_BLOCK_LEN);

    // OAI decoder input buffers that can be written and read with unified addressing from CPU and GPU, respectively
    // note: GPU reads are uncached, but read-once coalesced
    cudaHostAlloc((void**)&context.symbol_buffer, sizeof(*context.symbol_buffer) * 2 * MAX_BLOCK_LEN, cudaHostAllocMapped | cudaHostAllocWriteCombined);
    cudaHostAlloc((void**)&context.magnitude_buffer, sizeof(*context.magnitude_buffer) * 2 * MAX_BLOCK_LEN, cudaHostAllocMapped | cudaHostAllocWriteCombined);
    // OAI decoder output buffers that can be written and read with unified addressing from GPU and CPU, respectively
    // note: GPU writes are uncached, but write-once coalesced
    cudaHostAlloc((void**)&context.llr_buffer, sizeof(*context.llr_buffer) * MAX_BITS_PER_SYMBOL * MAX_BLOCK_LEN, cudaHostAllocMapped);

    // keep track of active thread contexts for shutdown
    TRTContext *self = &context;
    __atomic_exchange(&initialized_thread_contexts, &self, &self->next_initialized_context, __ATOMIC_ACQ_REL);

    return context;
}

Running Batched Inference

If decoder symbols are already available in half-precision floating-point format, running the TensorRT inference engine is as simple as performing one call to enqueue the corresponding inference commands on the asynchronous CUDA command stream of the calling thread’s context:

void trt_demapper_run(TRTContext *context, cudaStream_t stream, __half const *inputs, size_t numInputs, size_t numInputComponents, __half *outputs) {
    if (stream == 0)
        stream = context->default_stream;

    context->trt->setTensorAddress("y", (void*)inputs);
    context->trt->setInputShape("y", Dims2(numInputs, numInputComponents));
    context->trt->setTensorAddress("output_1", outputs);
    context->trt->enqueueV3(stream);
}

Converting Data Types between Host and Device

In the OAI 5G stack, received symbols come in from the host side in quantized int16_t format, together with a channel magnitude estimate. To convert the inputs to half-precision floating-point format, we first copy the symbols to a pinned memory buffer mapped_symbols that resides in unified addressable memory, and then run a CUDA kernel for dequantization and normalization on the GPU. After inference, the conversion back to quantized LLRs follows the same pattern: first a CUDA kernel quantizes the half-precision floating-point inference outputs, then the quantized data written by the GPU is read by the CPU through the unified addressable memory buffer mapped_outputs. Note that the CUDA command stream runs asynchronously and therefore needs to be synchronized with the calling thread before the output data is accessed.

extern "C" void trt_demapper_decode_block(TRTContext *context_, cudaStream_t stream, int16_t const *in_symbols, int16_t const *in_mags, size_t num_symbols,
                                          int16_t const *mapped_symbols, int16_t const *mapped_mags, size_t num_batch_symbols,
                                          int16_t *outputs, uint32_t num_bits_per_symbol, int16_t *mapped_outputs) {
    auto &context = *context_;

    memcpy((void*)mapped_symbols, in_symbols, sizeof(*in_symbols) * 2 * num_symbols);
    memcpy((void*)mapped_mags, in_mags, sizeof(*in_mags) * 2 * num_symbols);

    size_t num_in_components;
    if (trt_normalized_inputs) {
        norm_int16_symbols_to_float16(stream, mapped_symbols, mapped_mags, num_batch_symbols,
                                      (uint16_t*)context.input_buffer, 1);
        num_in_components = 2;
    }
    else {
        [...]
        num_in_components = 4;
    }

    trt_demapper_run(&context, stream, context.input_buffer, num_batch_symbols, num_in_components, context.output_buffer);

    float16_llrs_to_int16(stream, (uint16_t const*)context.output_buffer, num_batch_symbols,
                          mapped_outputs, num_bits_per_symbol);

    CHECK_CUDA(cudaStreamSynchronize(stream));
    memcpy(outputs, mapped_outputs, sizeof(*outputs) * num_bits_per_symbol * num_symbols);
}

The CUDA kernel for normalization runs in a straightforward 1D CUDA grid, reading the tuples of int16_t-quantized components that make up each complex value in a coalesced (consecutive) way, as one int32_t value each. The symbol values are then normalized with respect to the magnitude values and written back in a coalesced way, fusing each complex symbol into one __half2 value:

inline __host__ __device__ int blocks_for(uint32_t elements, int block_size) {
    return int(uint32_t(elements + (block_size - 1)) / uint32_t(block_size));
}
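The full kernel source ships with the plugin; as an illustration, a minimal sketch of such a kernel could look as follows. Two assumptions are made here: normalization is taken to mean dividing each symbol component by the corresponding magnitude estimate, and the last launcher argument is taken to be the output stride in __half2 elements. The shipped kernel may differ in both respects.

#include <cuda_fp16.h>

// Minimal sketch, not the shipped implementation (see comments above on assumptions).
__global__ void norm_int16_symbols_to_float16_kernel(int32_t const *symbols, int32_t const *mags,
                                                     uint32_t num_symbols, __half2 *out, int out_stride) {
    uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= num_symbols)
        return;

    // coalesced 32-bit loads of one (re, im) int16_t pair each
    int32_t packed_sym = symbols[i];
    int32_t packed_mag = mags[i];
    short2 sym = *reinterpret_cast<short2*>(&packed_sym);
    short2 mag = *reinterpret_cast<short2*>(&packed_mag);

    // normalize the symbol components with respect to the magnitude estimates (assumed semantics)
    float re = mag.x ? float(sym.x) / float(mag.x) : 0.0f;
    float im = mag.y ? float(sym.y) / float(mag.y) : 0.0f;

    // coalesced write, fusing each complex symbol into one __half2 value
    out[i * out_stride] = __floats2half2_rn(re, im);
}

void norm_int16_symbols_to_float16(cudaStream_t stream, int16_t const *symbols, int16_t const *mags,
                                   uint32_t num_symbols, uint16_t *out, int out_stride) {
    int const block = 256;
    norm_int16_symbols_to_float16_kernel<<<blocks_for(num_symbols, block), block, 0, stream>>>(
        reinterpret_cast<int32_t const*>(symbols), reinterpret_cast<int32_t const*>(mags),
        num_symbols, reinterpret_cast<__half2*>(out), out_stride);
}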

The CUDA kernel for re-quantization of the output LLRs works analogously, converting half-precision floating-point LLR tuples to quantized int16_t values by fixed-point scaling and rounding.
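A minimal sketch of such a re-quantization kernel is shown below. The fixed-point scale of 256 and the saturation behavior are hypothetical choices for illustration; the shipped kernel may use a different scaling.

#include <cuda_fp16.h>

// Minimal sketch, not the shipped implementation (scale and saturation are assumptions).
__global__ void float16_llrs_to_int16_kernel(__half const *llrs, uint32_t num_llrs, int16_t *out, float scale) {
    uint32_t i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= num_llrs)
        return;

    // fixed-point scaling with rounding and saturation to the int16_t range
    float v = __half2float(llrs[i]) * scale;
    v = fminf(fmaxf(v, -32768.0f), 32767.0f);
    out[i] = (int16_t)__float2int_rn(v);
}

void float16_llrs_to_int16(cudaStream_t stream, uint16_t const *llrs, uint32_t num_symbols,
                           int16_t *out, uint32_t num_bits_per_symbol) {
    uint32_t num_llrs = num_symbols * num_bits_per_symbol;
    int const block = 256;
    float16_llrs_to_int16_kernel<<<blocks_for(num_llrs, block), block, 0, stream>>>(
        reinterpret_cast<__half const*>(llrs), num_llrs, out, /*scale=*/256.0f);
}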

Demapper Integration in OAI

Note

Ensure that you have built the TRT engine in the first part of the tutorial.

In order to mount the TensorRT models and config files, you need to extend the config/common/docker-compose.override.yaml file:

services:
  oai-gnb:
    volumes:
      - ../../plugins/neural_demapper/models/:/opt/oai-gnb/models
      - ../../plugins/neural_demapper/config/demapper_trt.config:/opt/oai-gnb/demapper_trt.config

Pre-trained models are available in plugins/neural_demapper/models.

The TRT config file format has the following schema:

<trt_engine_file:string>         # file name of the TensorRT engine
<trt_normalized_inputs:int>      # flag to indicate if the inputs are normalized

For example, the following config file will use the TensorRT engine models/neural_demapper.2xfloat16.plan and normalize the inputs:

models/neural_demapper.2xfloat16.plan
1

Running the Demapper

The neural demapper is implemented as a shared library (see Plugins & Data Acquisition) that can be loaded using the OAI shared library loader and acts as a drop-in replacement for the default QAM-16 implementation. It is loaded when running the gNB via the following GNB_EXTRA_OPTIONS in the config/<config_name>/.env file of the config folder.

GNB_EXTRA_OPTIONS=--loader.demapper.shlibversion _trt --MACRLCs.[0].dl_max_mcs 10 --MACRLCs.[0].ul_max_mcs 10

We limit the MCS indices to 10 in order to stay within the QAM-16 modulation order.

Congratulations! You have now successfully implemented demapping using a neural network.

You can track the GPU load via

# on DGX Spark
$ nvidia-smi

# on Jetson
$ jtop

Implementation Aspects

In the following, we focus on various technical aspects of the CUDA implementation and the performance implications of different memory transfer patterns and command scheduling optimizations.

Memory Management

Similar to the GPU-Accelerated LDPC Decoding tutorial, we use the shared system memory architecture to avoid the bottleneck of costly memory transfers on traditional split-memory platforms.

As previously covered in the GPU-Accelerated LDPC Decoding tutorial, optimizing memory operations is essential for real-time performance. For the neural demapper implementation, we use the same efficient approach of page-locked memory (via cudaHostAlloc()) to enable direct GPU-CPU memory sharing. This allows for simple memcpy() operations instead of complex memory management calls, with host caching enabled for CPU access while device caching is disabled for direct memory access. This approach is particularly well-suited for the small buffer sizes used in neural demapping, avoiding the overhead of traditional GPU memory management methods like cudaMemcpyAsync() or cudaMallocManaged().

For comparison, we show both variants side-by-side in the following inference code, where the latency-optimized code path is the one with USE_UNIFIED_MEMORY defined:

extern "C" void trt_demapper_decode_block(TRTContext *context_, cudaStream_t stream, int16_t const *in_symbols, int16_t const *in_mags, size_t num_symbols,
                                          int16_t const *mapped_symbols, int16_t const *mapped_mags, size_t num_batch_symbols,
                                          int16_t *outputs, uint32_t num_bits_per_symbol, int16_t *mapped_outputs) {
    auto &context = *context_;

#if defined(USE_UNIFIED_MEMORY)
    memcpy((void*)mapped_symbols, in_symbols, sizeof(*in_symbols) * 2 * num_symbols);
    memcpy((void*)mapped_mags, in_mags, sizeof(*in_mags) * 2 * num_symbols);
#else
    cudaMemcpyAsync((void*)mapped_symbols, in_symbols, sizeof(*in_symbols) * 2 * num_symbols, cudaMemcpyHostToDevice, stream);
    cudaMemcpyAsync((void*)mapped_mags, in_mags, sizeof(*in_mags) * 2 * num_symbols, cudaMemcpyHostToDevice, stream);
#endif

    size_t num_in_components;
    if (trt_normalized_inputs) {
        norm_int16_symbols_to_float16(stream, mapped_symbols, mapped_mags, num_batch_symbols,
                                      (uint16_t*)context.input_buffer, 1);
        num_in_components = 2;
    }
    else {
        [...]
        num_in_components = 4;
    }

    trt_demapper_run(&context, stream, context.input_buffer, num_batch_symbols, num_in_components, context.output_buffer);

    float16_llrs_to_int16(stream, (uint16_t const*)context.output_buffer, num_batch_symbols,
                          mapped_outputs, num_bits_per_symbol);

#if defined(USE_UNIFIED_MEMORY)
    // note: synchronize the asynchronous command queue before accessing from the host
    CHECK_CUDA(cudaStreamSynchronize(stream));
    memcpy(outputs, mapped_outputs, sizeof(*outputs) * num_bits_per_symbol * num_symbols);
#else
    cudaMemcpyAsync(outputs, mapped_outputs, sizeof(*outputs) * num_bits_per_symbol * num_symbols, cudaMemcpyDeviceToHost, stream);
    // note: synchronize after the asynchronous command queue has executed the copy to host
    CHECK_CUDA(cudaStreamSynchronize(stream));
#endif
}

CUDA Graph Optimization

CUDA command graph APIs [Gray2019] were introduced to front-load the overhead of scheduling repetitive sequences of compute kernels on the GPU, allowing pre-recorded, pre-optimized command sequences, such as neural network inference in our case, to be scheduled with a single API call. This reduces latency further, spending runtime on the actual computations rather than on dynamic command scheduling. We pre-record CUDA graphs covering demapper inference, data pre-processing, and post-processing for two different batch sizes: one for common small batches and one for the maximum expected parallel batch size.

Command graphs are pre-recorded per thread because each thread uses its own intermediate storage buffers. We run the recording at the end of thread context initialization as introduced above; for each batch size, we run one size-0 inference to trigger any lazy runtime allocations, followed by another inference on dummy inputs for the actual recording.
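A possible sketch of this recording step, placed at the end of trt_demapper_init_context(), is shown below. The dummy_* host buffers (std::vector, requiring <vector>) are hypothetical placeholders, and the exact call sequence in the shipped plugin may differ.

#ifdef USE_GRAPHS
    // hypothetical dummy host buffers for the recording passes
    std::vector<int16_t> dummy_symbols(2 * MAX_BLOCK_LEN, 0), dummy_mags(2 * MAX_BLOCK_LEN, 1);
    std::vector<int16_t> dummy_outputs(MAX_BITS_PER_SYMBOL * MAX_BLOCK_LEN);

    for (size_t block_size : {(size_t)OPT_BLOCK_LEN, (size_t)MAX_BLOCK_LEN}) {
        // size-0 call: runs the pipeline once without capturing to trigger lazy runtime allocations
        trt_demapper_decode_block(&context, context.default_stream, nullptr, nullptr, /*num_symbols=*/0,
                                  context.symbol_buffer, context.magnitude_buffer, block_size,
                                  dummy_outputs.data(), MAX_BITS_PER_SYMBOL, context.llr_buffer);
        // dummy call: captured into the CUDA graph for this batch size and then instantiated
        trt_demapper_decode_block(&context, context.default_stream, dummy_symbols.data(), dummy_mags.data(), block_size,
                                  context.symbol_buffer, context.magnitude_buffer, block_size,
                                  dummy_outputs.data(), MAX_BITS_PER_SYMBOL, context.llr_buffer);
    }
#endif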

To extend the function trt_demapper_decode_block() with CUDA graph recording and execution, we introduce the following code paths when USE_GRAPHS is defined:

extern "C" void trt_demapper_decode_block(TRTContext *context_, cudaStream_t stream, int16_t const *in_symbols, int16_t const *in_mags, size_t num_symbols,
                                          int16_t const *mapped_symbols, int16_t const *mapped_mags, size_t num_batch_symbols,
                                          int16_t *outputs, uint32_t num_bits_per_symbol, int16_t *mapped_outputs) {
    auto &context = *context_;

    uint32_t block_size = num_batch_symbols > OPT_BLOCK_LEN ? MAX_BLOCK_LEN : OPT_BLOCK_LEN;
    cudaGraph_t &graph = block_size == OPT_BLOCK_LEN ? context.graph_opt : context.graph_max;
    cudaGraphExec_t &graphCtx = block_size == OPT_BLOCK_LEN ? context.record_opt : context.record_max;

    if (num_symbols > 0) {
        memcpy((void*)mapped_symbols, in_symbols, sizeof(*in_symbols) * 2 * num_symbols);
        memcpy((void*)mapped_mags, in_mags, sizeof(*in_mags) * 2 * num_symbols);
    }

    // graph capture
    if (!graph) {
        bool recording = false;
#ifdef USE_GRAPHS
        // allow pre-allocation before recording
        if (num_symbols > 0) {
            // in pre-recording phase
            CHECK_CUDA(cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed));
            num_batch_symbols = block_size;
            recording = true;
        }
        // else: pre-allocation phase
#endif

        size_t num_in_components;
        if (trt_normalized_inputs) {
            norm_int16_symbols_to_float16(stream, mapped_symbols, mapped_mags, num_batch_symbols,
                                          (uint16_t*)context.input_buffer, 1);
            num_in_components = 2;
        }
        else {
            int16_symbols_to_float16(stream, mapped_symbols, num_batch_symbols,
                                     (uint16_t*)context.input_buffer, 2);
            int16_symbols_to_float16(stream, mapped_mags, num_batch_symbols,
                                     (uint16_t*)context.input_buffer + 2, 2);
            num_in_components = 4;
        }

        trt_demapper_run(&context, stream, recording ? nullptr : context.input_buffer, block_size, num_in_components, recording ? nullptr : context.output_buffer);

        float16_llrs_to_int16(stream, (uint16_t const*)context.output_buffer, num_batch_symbols,
                              mapped_outputs, num_bits_per_symbol);

#ifdef USE_GRAPHS
        if (num_symbols > 0) {
            // in pre-recording phase
            CHECK_CUDA(cudaStreamEndCapture(stream, &graph));
            printf("Recorded CUDA graph (TID %d), stream %llX\n", (int)gettid(), (unsigned long long)stream);
        }
#endif
    }

#ifdef USE_GRAPHS
    if (graph && !graphCtx) {
        // in pre-recording phase
        CHECK_CUDA(cudaGraphInstantiate(&graphCtx, graph, 0));
    }
    else if (num_symbols > 0) {
        // in runtime inference, run pre-recorded graph
        cudaGraphLaunch(graphCtx, stream);
    }
#endif

    CHECK_CUDA(cudaStreamSynchronize(stream));
    memcpy(outputs, mapped_outputs, sizeof(*outputs) * num_bits_per_symbol * num_symbols);
}

Note

Note that CUDA graphs are currently only enabled on DGX Spark. Jetson platform support is disabled pending investigation of stability issues.

Unit tests

We have implemented unit tests using pytest to allow testing individual parts of the implementation outside of the full 5G stack. The unit tests can be found in plugins/neural_demapper/tests/. They use nanobind to call the TensorRT and CUDA modules from Python and to test them against Python-based reference implementations. For more details on how to use nanobind, please refer to the nanobind documentation.

Outlook

This was a first tutorial on accelerating neural network inference using TensorRT and CUDA graphs. The neural demapper itself is a simple network and the focus was on the integration rather than the actual error rate performance.

You are now able to deploy your own neural networks using this tutorial as a blueprint. An interesting starting point could be the 5G NR PUSCH Neural Receiver tutorial, which provides a 5G-compliant implementation of a neural receiver and already includes a TensorRT export of the trained network.