Creating Extensions Using NumPy and SciPy
Created On: Mar 24, 2017 | Last Updated: Apr 25, 2023 | Last Verified: Not Verified
Author: Adam Paszke
Updated by: Adam Dziedzic
In this tutorial, we will go through two tasks:

1. Create a neural network layer with no parameters.
   This calls into numpy as part of its implementation.
2. Create a neural network layer that has learnable weights.
   This calls into SciPy as part of its implementation.
import torch
from torch.autograd import Function
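Both layers below move data between PyTorch tensors and NumPy arrays. As a reminder, here is a minimal sketch of that round-trip for CPU tensors (an addition to the original example):

# detach() drops autograd tracking; .numpy() then shares memory with the tensor
t = torch.randn(4, 4, requires_grad=True)
a = t.detach().numpy()
# torch.from_numpy wraps the array back as a tensor, also without copying
b = torch.from_numpy(a)
print(b.shape, b.dtype)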
Parameter-less example
This layer doesn’t particularly do anything useful or mathematically correct.
It is aptly named BadFFTFunction.
Layer Implementation
from numpy.fft import rfft2, irfft2


class BadFFTFunction(Function):
    @staticmethod
    def forward(ctx, input):
        numpy_input = input.detach().numpy()
        result = abs(rfft2(numpy_input))
        return input.new(result)

    @staticmethod
    def backward(ctx, grad_output):
        numpy_go = grad_output.numpy()
        result = irfft2(numpy_go)
        return grad_output.new(result)

# since this layer does not have any parameters, we can
# simply declare this as a function, rather than as an ``nn.Module`` class


def incorrect_fft(input):
    return BadFFTFunction.apply(input)
Example usage of the created layer:
input = torch.randn(8, 8, requires_grad=True)
result = incorrect_fft(input)
print(result)
result.backward(torch.randn(result.size()))
print(input)
tensor([[ 5.3450,  5.5498, 11.5089,  2.8125,  4.3323],
        [ 2.9070,  6.2810,  9.4622,  8.6332,  7.7426],
        [ 8.5840,  9.2695,  6.4574, 13.0145, 10.4781],
        [10.9830,  6.6098,  8.7645, 14.7908,  3.9896],
        [ 3.2419, 12.1645, 10.5262,  8.1190,  6.3947],
        [10.9830,  1.9102,  6.5775,  8.1247,  3.9896],
        [ 8.5840,  5.3311,  8.3869,  7.0915, 10.4781],
        [ 2.9070,  2.9750,  8.6982,  2.8529,  7.7426]],
       grad_fn=<BadFFTFunctionBackward>)
tensor([[-0.0431, -0.3393, -0.7665, -0.0388,  0.9029,  1.1912,  1.7363,  0.9635],
        [ 0.3266, -0.2551, -0.8588, -0.4930, -1.0086,  0.1345, -0.6747, -0.7447],
        [ 1.4749, -2.3105,  0.8239,  1.8603, -0.2318,  1.1680, -1.7512,  2.1520],
        [ 0.1837,  0.3386,  0.1454, -0.8777,  1.9420, -1.4555,  1.3802, -0.0673],
        [ 1.0813, -2.0244, -0.8719,  0.7395,  1.6491, -0.2347,  0.0360,  0.1725],
        [ 0.3419, -0.0599,  0.6426, -0.2108, -0.5040,  0.5801, -0.3561, -0.1282],
        [-2.3215, -0.4499,  0.2578,  0.5038,  0.1573, -0.7633, -0.5024,  0.0724],
        [ 0.7510,  0.4329,  0.4210,  1.0757, -1.2639, -1.3631,  1.7391,  0.9374]],
       requires_grad=True)
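As a quick sanity check of the forward pass (an addition to the original example; it assumes PyTorch 1.8+, where the torch.fft module is available), the NumPy-based result can be compared against PyTorch's built-in FFT. Note that only the forward values match; the backward of BadFFTFunction is still not a correct gradient.

# magnitude of the 2D real FFT computed natively in PyTorch
builtin = torch.fft.rfft2(input).abs()
# should print True (up to float32 precision)
print(torch.allclose(result, builtin, atol=1e-5))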
Parametrized example
In the deep learning literature, this layer is confusingly referred to as convolution, while the actual operation is cross-correlation (the only difference is that the filter is flipped for convolution, which is not the case for cross-correlation).
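A small SciPy-only sketch of that difference (illustrative, not part of the original tutorial): correlating with a kernel is the same as convolving with the kernel flipped along both axes.

import numpy as np
from scipy.signal import convolve2d, correlate2d

x = np.random.randn(5, 5)
k = np.random.randn(3, 3)
# cross-correlation == convolution with the kernel flipped in both dimensions
print(np.allclose(correlate2d(x, k, mode='valid'),
                  convolve2d(x, np.flip(k), mode='valid')))  # True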
Implementation of a layer with learnable weights, where cross-correlation has a filter (kernel) that represents weights.
The backward pass computes the gradient with respect to the input and the gradient with respect to the filter.
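Before the implementation, here is a small standalone numerical check (illustrative only, using NumPy and SciPy directly; not part of the original tutorial) of the relationship the backward pass relies on: for a 'valid' cross-correlation, the gradient with respect to the input is a 'full' convolution of the upstream gradient with the filter.

import numpy as np
from scipy.signal import correlate2d, convolve2d

# for output = correlate2d(x, k, mode='valid') and upstream gradient g,
# d(sum(g * output)) / dx should equal convolve2d(g, k, mode='full')
rng = np.random.default_rng(0)
x = rng.standard_normal((5, 5))
k = rng.standard_normal((3, 3))
g = rng.standard_normal((3, 3))   # same shape as the 'valid' output

analytic = convolve2d(g, k, mode='full')

# central finite differences over every input element
eps = 1e-6
numeric = np.zeros_like(x)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        xp = x.copy(); xp[i, j] += eps
        xm = x.copy(); xm[i, j] -= eps
        fp = np.sum(g * correlate2d(xp, k, mode='valid'))
        fm = np.sum(g * correlate2d(xm, k, mode='valid'))
        numeric[i, j] = (fp - fm) / (2 * eps)

print(np.allclose(analytic, numeric, atol=1e-5))  # True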
from numpy import flip
import numpy as np
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter


class ScipyConv2dFunction(Function):
    @staticmethod
    def forward(ctx, input, filter, bias):
        # detach so we can cast to NumPy
        input, filter, bias = input.detach(), filter.detach(), bias.detach()
        result = correlate2d(input.numpy(), filter.numpy(), mode='valid')
        result += bias.numpy()
        ctx.save_for_backward(input, filter, bias)
        return torch.as_tensor(result, dtype=input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = grad_output.detach()
        input, filter, bias = ctx.saved_tensors
        grad_output = grad_output.numpy()
        grad_bias = np.sum(grad_output, keepdims=True)
        grad_input = convolve2d(grad_output, filter.numpy(), mode='full')
        # the previous line can be expressed equivalently as:
        # grad_input = correlate2d(grad_output, flip(flip(filter.numpy(), axis=0), axis=1), mode='full')
        grad_filter = correlate2d(input.numpy(), grad_output, mode='valid')
        return torch.from_numpy(grad_input), torch.from_numpy(grad_filter).to(torch.float), torch.from_numpy(grad_bias).to(torch.float)


class ScipyConv2d(Module):
    def __init__(self, filter_width, filter_height):
        super(ScipyConv2d, self).__init__()
        self.filter = Parameter(torch.randn(filter_width, filter_height))
        self.bias = Parameter(torch.randn(1, 1))

    def forward(self, input):
        return ScipyConv2dFunction.apply(input, self.filter, self.bias)
Example usage:
module = ScipyConv2d(3, 3)
print("Filter and bias: ", list(module.parameters()))
input = torch.randn(10, 10, requires_grad=True)
output = module(input)
print("Output from the convolution: ", output)
output.backward(torch.randn(8, 8))
print("Gradient for the input map: ", input.grad)
Filter and bias:  [Parameter containing:
tensor([[ 0.3812, -1.5893,  1.7878],
        [ 1.1881,  0.8153,  0.8211],
        [ 0.3540,  0.0449, -1.0425]], requires_grad=True), Parameter containing:
tensor([[1.2485]], requires_grad=True)]
Output from the convolution:  tensor([[ 0.3887,  0.9207,  3.5466,  4.8022, -6.1077, -2.8013,  0.1342, -1.3506],
        [ 4.9075, -1.2236,  5.0911, -2.2551, -1.7956,  5.3628,  3.8459,  2.2870],
        [ 2.7045, -2.1232, -4.2774,  2.5828,  3.1806,  0.3914,  0.2538,  3.8699],
        [-4.2609,  1.5699,  0.1625,  4.8202,  3.8256, -3.4648,  2.3548,  1.6205],
        [ 5.7046,  3.8783,  0.8075,  3.8384, -5.7635, -1.2634, -0.2252, -2.8102],
        [ 2.2825,  5.7868, -4.2896,  3.0919,  2.5739,  2.2220,  0.9839,  4.2843],
        [ 1.9240,  1.3728, -1.0431, -3.2084, -1.2085, -1.0898, -0.3533,  1.5252],
        [ 0.3076, -3.0035,  1.4617, -2.3310,  4.4440,  6.0350,  2.4367, -0.3053]],
       grad_fn=<ScipyConv2dFunctionBackward>)
Gradient for the input map:  tensor([[-0.5771,  2.6675, -4.1250,  2.6139, -1.6773,  0.3955, -0.5557,  0.0867,  1.1252, -1.3681],
        [-1.4075, -2.1109,  0.2959,  0.3476, -2.5655,  0.8364,  1.5913, -0.7115, -2.8672, -0.0379],
        [ 0.5016,  1.5688,  1.5273, -2.6035,  4.8218,  0.2319, -2.0005, -2.4703,  2.6784, -0.4663],
        [-0.3587, -0.0090,  1.0457, -1.3076,  3.9351, -0.8206,  0.7432,  0.6396, -1.7742,  1.4561],
        [-0.7220, -0.4305, -1.1705,  3.2283, -3.4887, -1.2411, -5.3865,  3.3820, -1.6556,  2.6371],
        [-0.3009,  0.7392,  1.2937, -2.7387,  0.1762, -1.4713,  1.9263, -0.3104,  4.0068, -3.6319],
        [-0.1364,  1.2278, -0.7634, -0.5907,  0.0560,  0.3107, -2.3470, -1.6941, -0.3301, -2.5376],
        [-0.7421,  0.0324, -0.3173, -2.6714, -5.9965, -1.7365,  0.8882, -1.0008, -4.6804,  2.5464],
        [-0.5831, -1.2099, -1.0323, -0.8480,  1.8486,  1.5325,  0.9417, -0.4623,  0.9867,  1.3658],
        [-0.1153, -0.3216,  0.2024,  1.1352,  0.6837, -0.9398, -1.3357,  1.0806,  0.7216, -0.9412]])
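Since PyTorch's built-in torch.nn.functional.conv2d also performs cross-correlation, the forward pass of this layer can be cross-checked against it (an illustrative addition, not part of the original tutorial):

import torch.nn.functional as F

# F.conv2d expects (N, C, H, W) shapes, so add batch and channel dimensions
reference = F.conv2d(input.view(1, 1, 10, 10),
                     module.filter.view(1, 1, 3, 3)) + module.bias
# should print True (up to float32 precision)
print(torch.allclose(output, reference.view(8, 8), atol=1e-5))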
Check the gradients:
from torch.autograd.gradcheck import gradcheck

moduleConv = ScipyConv2d(3, 3)

input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
test = gradcheck(moduleConv, input, eps=1e-6, atol=1e-4)
print("Are the gradients correct: ", test)
Are the gradients correct: True
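To see the learnable weights actually being trained, here is a hypothetical training sketch (not part of the original tutorial) that fits the filter of a ScipyConv2d module so it reproduces a fixed target kernel, using plain SGD on an MSE loss:

import torch.optim as optim

torch.manual_seed(0)
target_filter = torch.randn(3, 3)

module = ScipyConv2d(3, 3)
optimizer = optim.SGD(module.parameters(), lr=0.01)

for step in range(100):
    x = torch.randn(10, 10)
    # target produced by the fixed kernel (no bias), so module.bias should drift toward zero
    target = torch.from_numpy(correlate2d(x.numpy(), target_filter.numpy(), mode='valid'))
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(module(x), target)
    loss.backward()
    optimizer.step()

# the loss should have decreased substantially by now
print("final loss:", loss.item())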
Total running time of the script: (0 minutes 0.617 seconds)