Building a Convolution/Batch Norm fuser with torch.compile

Author: Horace He, Will Feng

What you will learn
  • How to register custom fusion patterns with torch.compile’s pattern matcher

Prerequisites
  • PyTorch v2.7.0

Note

This optimization only works for models in inference mode (i.e. model.eval()). However, torch.compile’s pattern matching system works for both training and inference.
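To see why the fold is only valid in inference mode, the standalone sketch below (an illustration, not part of the fusion pass) runs one BatchNorm2d layer in both modes. In training mode it normalizes with the statistics of the current batch, so its output depends on the other samples in the batch and cannot be baked into fixed convolution weights. In eval mode it uses the stored running_mean and running_var, which are constants. The input shape and scaling are arbitrary choices.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
# Shift and scale the data so batch statistics differ from the default running stats.
x = torch.randn(8, 3, 4, 4) * 5 + 2

bn.train()
with torch.no_grad():
    # Normalizes with this batch's per-channel mean/var (and updates the running stats).
    train_out = bn(x)

bn.eval()
with torch.no_grad():
    # Normalizes with the stored running_mean / running_var: a fixed affine map.
    eval_out = bn(x)

# In training mode every output channel has (near) zero mean over the batch.
print(train_out.mean(dim=(0, 2, 3)))
print(torch.allclose(train_out, eval_out))  # False: the two modes disagree
```

Only the eval-mode behavior is a fixed per-channel scale and shift, which is what makes folding it into a convolution possible.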

First, let’s get some imports out of the way.

```python
from typing import Type, Dict, Any, Tuple, Iterable
import copy
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```

For this tutorial, we are going to create a model consisting of convolutions and batch norms. Note that this model has some tricky components: some of the conv/batch norm patterns are hidden within Sequentials and one of the BatchNorms is wrapped in another Module.

```python
class WrappedBatchNorm(nn.Module):
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)

    def forward(self, x):
        return self.mod(x)


class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 1, 1)
        self.bn1 = nn.BatchNorm2d(1)
        self.conv2 = nn.Conv2d(1, 1, 1)
        self.nested = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 1, 1),
        )
        self.wrapped = WrappedBatchNorm()

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.conv2(x)
        x = self.nested(x)
        x = self.wrapped(x)
        return x


model = M().to(device)
model.eval()
```

Fusing Convolution with Batch Norm

One of the primary challenges with trying to automatically fuse convolution and batch norm in PyTorch is that eager-mode PyTorch does not provide an easy way of accessing the computational graph. torch.compile resolves this problem by capturing the computational graph during compilation, allowing us to apply pattern-based optimizations across the entire model, including operations nested within Sequential modules or wrapped in custom modules.

```python
import torch._inductor.pattern_matcher as pm
from torch._inductor.pattern_matcher import register_replacement
```

torch.compile will capture a graph representation of our model. During compilation, modules hidden within Sequential containers and wrapped modules are all inlined into the graph, making them available for pattern matching and optimization.
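One way to look at this capture is to hand torch.compile a custom backend, which receives the FX graph that Dynamo traced. The sketch below uses a hypothetical `inspect_backend` that prints the graph and then runs it unchanged; the small model, including the hypothetical `Wrapped` module, is a stand-in for `M` above. Depending on the PyTorch version the printed graph may show module calls or already-inlined functional ops, but in either case the wrapped batch norm lands in the same graph as the convolution.

```python
import torch
import torch.fx
import torch.nn as nn


class Wrapped(nn.Module):
    # Hypothetical wrapper around a batch norm, like WrappedBatchNorm above.
    def __init__(self):
        super().__init__()
        self.mod = nn.BatchNorm2d(1)

    def forward(self, x):
        return self.mod(x)


captured_graphs = []


def inspect_backend(gm: torch.fx.GraphModule, example_inputs):
    # Dynamo calls the backend once per captured graph.
    captured_graphs.append(gm)
    print(gm.graph)
    # Returning gm.forward runs the captured graph without further compilation.
    return gm.forward


model = nn.Sequential(nn.Conv2d(1, 1, 1), Wrapped()).eval()
compiled = torch.compile(model, backend=inspect_backend)

x = torch.randn(2, 1, 4, 4)
with torch.no_grad():
    out = compiled(x)
    ref = model(x)
```

Inductor’s pattern matcher works on a later, lower-level form of this graph, but the point is the same: the module boundaries that hide conv/batch norm pairs in eager mode are gone.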

Folding Batch Norm into the Convolution

Unlike some other fusions, fusion of convolution with batch norm does not require any new operators. Instead, as batch norm during inference consists of a pointwise add and multiply, these operations can be “baked” into the preceding convolution’s weights. This allows us to remove the batch norm entirely from our model! Read https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ for further details. The code here is copied from pytorch/pytorch for clarity purposes.
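The folding rule can be checked numerically on a single channel. Treating a 1x1 convolution on one channel as the scalar map z = w·x + b, inference-mode batch norm computes γ(z − μ)/√(σ² + ε) + β, which is again affine in x with w′ = w·γ/√(σ² + ε) and b′ = (b − μ)·γ/√(σ² + ε) + β. These are the same expressions that fuse_conv_bn_weights applies per output channel. The parameter values in this pure-Python sketch are arbitrary.

```python
import math


def fold_bn(w, b, mean, var, eps, gamma, beta):
    """Fold inference-mode batch norm into a scalar affine map w * x + b."""
    scale = gamma / math.sqrt(var + eps)
    return w * scale, (b - mean) * scale + beta


# Arbitrary illustrative parameters for one channel.
w, b = 0.5, 0.1
mean, var, eps, gamma, beta = 0.2, 4.0, 1e-5, 1.5, -0.3


def conv_then_bn(x):
    z = w * x + b  # 1x1 convolution on a single channel
    return gamma * (z - mean) / math.sqrt(var + eps) + beta


fused_w, fused_b = fold_bn(w, b, mean, var, eps, gamma, beta)


def fused_conv(x):
    return fused_w * x + fused_b


for x in (-2.0, 0.0, 3.0):
    print(x, conv_then_bn(x), fused_conv(x))
```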

```python
def fuse_conv_bn_eval(conv, bn):
    """
    Given a conv Module `A` and a batch_norm module `B`, returns a conv
    module `C` such that C(x) == B(A(x)) in inference mode.
    """
    assert (not (conv.training or bn.training)), "Fusion only for eval!"
    fused_conv = copy.deepcopy(conv)

    fused_conv.weight, fused_conv.bias = \
        fuse_conv_bn_weights(fused_conv.weight, fused_conv.bias,
                             bn.running_mean, bn.running_var, bn.eps, bn.weight, bn.bias)

    return fused_conv


def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b):
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    conv_w = conv_w * (bn_w * bn_var_rsqrt).reshape([-1] + [1] * (len(conv_w.shape) - 1))
    conv_b = (conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b

    return torch.nn.Parameter(conv_w), torch.nn.Parameter(conv_b)
```
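PyTorch ships the same helper as `torch.nn.utils.fusion.fuse_conv_bn_eval`. As a sanity check, the sketch below uses arbitrary layer sizes and randomizes the batch norm statistics (freshly initialized statistics make batch norm nearly an identity, which would make the check weak), then confirms that the fused convolution reproduces `bn(conv(x))` up to floating-point error.

```python
import torch
import torch.nn as nn
from torch.nn.utils.fusion import fuse_conv_bn_eval

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1).eval()
bn = nn.BatchNorm2d(8).eval()

# Give batch norm non-trivial statistics and affine parameters.
with torch.no_grad():
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)

fused = fuse_conv_bn_eval(conv, bn)

x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    expected = bn(conv(x))  # two ops: convolution, then batch norm
    actual = fused(x)       # one convolution with folded weights and bias
```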

Pattern Matching with torch.compile

Now that we have our fusion logic, we need to register a pattern that torch.compile’s pattern matcher will recognize and replace during compilation.

```python
# Define the pattern we want to match: conv2d followed by batch_norm
def conv_bn_pattern(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
    conv_out = torch.nn.functional.conv2d(x, conv_weight, conv_bias)
    bn_out = torch.nn.functional.batch_norm(
        conv_out, bn_mean, bn_var, bn_weight, bn_bias,
        training=False, eps=1e-5
    )
    return bn_out


def conv_bn_replacement(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
    fused_weight, fused_bias = fuse_conv_bn_weights(
        conv_weight, conv_bias, bn_mean, bn_var, 1e-5, bn_weight, bn_bias
    )
    return torch.nn.functional.conv2d(x, fused_weight, fused_bias)


# Example inputs are needed to trace the pattern functions.
# The inputs should match the function signatures of conv_bn_pattern and conv_bn_replacement.
# These are used to trace the pattern functions to create the match template.
# IMPORTANT: The pattern matcher is shape-agnostic! The specific shapes you use here
# don't limit what shapes will be matched - any valid conv2d->batch_norm sequence
# will be matched regardless of channels, kernel size, or spatial dimensions.
# - x: input tensor (batch_size, channels, height, width)
# - conv_weight: (out_channels, in_channels, kernel_h, kernel_w)
# - conv_bias: (out_channels,)
# - bn_mean, bn_var, bn_weight, bn_bias: all have shape (num_features,) matching out_channels
example_inputs = [
    torch.randn(1, 1, 4, 4).to(device),  # x: input tensor
    torch.randn(1, 1, 1, 1).to(device),  # conv_weight: 1 output channel, 1 input channel, 1x1 kernel
    torch.randn(1).to(device),           # conv_bias: 1 output channel
    torch.randn(1).to(device),           # bn_mean: batch norm running mean
    torch.randn(1).to(device),           # bn_var: batch norm running variance
    torch.randn(1).to(device),           # bn_weight: batch norm weight (gamma)
    torch.randn(1).to(device),           # bn_bias: batch norm bias (beta)
]

from torch._inductor.pattern_matcher import PatternMatcherPass
from torch._inductor import config

# Create a pattern matcher pass and register our pattern
patterns = PatternMatcherPass()

register_replacement(
    conv_bn_pattern,
    conv_bn_replacement,
    example_inputs,
    pm.fwd_only,
    patterns,
)


# Create a custom pass function that applies our patterns
def conv_bn_fusion_pass(graph):
    return patterns.apply(graph)


# Set our custom pass in the config
config.post_grad_custom_post_pass = conv_bn_fusion_pass
```
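register_replacement does not match Python source. It traces conv_bn_pattern into an ATen-level FX graph and searches the post-grad graph for the same structure. To get a feel for that template, the sketch below traces a copy of the pattern with `make_fx` from `torch.fx.experimental.proxy_tensor`. This is only an approximation: `pm.fwd_only` also applies Inductor’s decompositions, so the exact ops in the real template may differ from what is printed here.

```python
import torch
import torch.nn.functional as F
from torch.fx.experimental.proxy_tensor import make_fx


# Copy of conv_bn_pattern from above so this sketch is self-contained.
def conv_bn_pattern(x, conv_weight, conv_bias, bn_mean, bn_var, bn_weight, bn_bias):
    conv_out = F.conv2d(x, conv_weight, conv_bias)
    return F.batch_norm(conv_out, bn_mean, bn_var, bn_weight, bn_bias,
                        training=False, eps=1e-5)


example_inputs = [
    torch.randn(1, 1, 4, 4),  # x
    torch.randn(1, 1, 1, 1),  # conv_weight
    torch.randn(1),           # conv_bias
    torch.randn(1),           # bn_mean
    torch.rand(1) + 0.5,      # bn_var: kept positive
    torch.randn(1),           # bn_weight
    torch.randn(1),           # bn_bias
]

pattern_gm = make_fx(conv_bn_pattern)(*example_inputs)
print(pattern_gm.graph)

# ATen ops recorded in the traced pattern (a convolution and a batch norm variant).
aten_targets = [str(n.target) for n in pattern_gm.graph.nodes if n.op == "call_function"]
print(aten_targets)
```

Because the template is a graph of ATen ops rather than module calls, it does not care whether the original conv and batch norm lived in a Sequential, a wrapper module, or the top-level model.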

Note

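We make some simplifications here for demonstration purposes, such as only matching 2D convolutions. The pattern is also traced from a conv2d call with default stride and padding and a batch norm with a hard-coded eps=1e-5, so layers configured differently may not be matched. The pattern matcher in torch.compile can handle more complex patterns.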

Testing out our Fusion Pass

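We can now run this fusion pass on our initial toy model and verify that our results match the eager model up to floating-point tolerance. In addition, we check Inductor’s pattern matcher counters to confirm that every conv/batch norm pair was replaced. To inspect the generated code directly and see that no batch norms remain, run the script with TORCH_LOGS="output_code".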

```python
from torch._dynamo.utils import counters

# Clear the counters before compilation
counters.clear()

# Ensure pattern matcher is enabled
config.pattern_matcher = True

fused_model = torch.compile(model, backend="inductor")
inp = torch.randn(5, 1, 1, 1).to(device)

# Run the model to trigger compilation and pattern matching
with torch.no_grad():
    output = fused_model(inp)
    expected = model(inp)
    torch.testing.assert_close(output, expected)

# Check how many patterns were matched
assert counters['inductor']['pattern_matcher_count'] == 3, "Expected 3 conv-bn patterns to be matched"

# Create a model with different shapes than our example_inputs
test_model_diff_shape = nn.Sequential(
    nn.Conv2d(3, 16, 5),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 32, 7),
    nn.BatchNorm2d(32),
).to(device).eval()

counters.clear()
compiled_diff_shape = torch.compile(test_model_diff_shape, backend="inductor")
test_input_diff_shape = torch.randn(1, 3, 28, 28).to(device)
with torch.no_grad():
    compiled_diff_shape(test_input_diff_shape)

# Check how many patterns were matched
assert counters['inductor']['pattern_matcher_count'] == 2, "Expected 2 conv-bn patterns to be matched"
```

Benchmarking our Fusion on ResNet18

We can test our fusion pass on a larger model like ResNet18 and see how much this pass improves inference performance.

```python
import torchvision.models as models
import time

rn18 = models.resnet18().to(device)
rn18.eval()

inp = torch.randn(10, 3, 224, 224).to(device)
output = rn18(inp)


def benchmark(model, iters=20):
    with torch.no_grad():
        # Warm up (this also triggers compilation for compiled models)
        for _ in range(10):
            model(inp)
        if device.type == "cuda":
            torch.cuda.synchronize()  # wait for warm-up kernels before starting the clock
        begin = time.perf_counter()
        for _ in range(iters):
            model(inp)
        if device.type == "cuda":
            torch.cuda.synchronize()  # CUDA launches are asynchronous; wait for them to finish
        return str(time.perf_counter() - begin)


# Benchmark original model
print("Original model time: ", benchmark(rn18))

# Compile with our custom pattern
compiled_with_pattern_matching = torch.compile(rn18, backend="inductor")

# Benchmark compiled model
print("\ntorch.compile (with conv-bn pattern matching and other fusions): ",
      benchmark(compiled_with_pattern_matching))
```

Conclusion

As we can see, torch.compile provides a powerful way to implement graph transformations and optimizations through pattern matching. By registering custom patterns, we can extend torch.compile’s optimization capabilities to handle domain-specific transformations.

The conv-bn fusion demonstrated here is just one example of what’s possible with torch.compile’s pattern matching system.