# vmap Onboarding Lab
The goal of this lab is to give first-hand experience with writing vmap rules (a.k.a. batching rules) and adding them to PyTorch. An alternative to this lab is taking on real tasks to add batching rules to PyTorch.
The deliverable is a stack of PRs (which won't be merged into master) containing all of the code for the sections below. The task is considered finished when the PRs are accepted by the reviewers; you can close them at that point.
For this lab, we'll be writing a batching rule for a new operator, `simple_mul`, that is similar to `torch.mul` but has more restrictions on its input types. Concretely, `simple_mul` looks like the following:
```python
def simple_mul(x: Tensor, y: Tensor):
    return torch.mul(x, y)
```
For the sake of learning, please do not read the existing batching rule for `torch.mul`. When the lab is complete, the following test cases should work:
```python
import torch
from functorch import vmap

# The dimensions being vmapped over
B = 2
B1 = 3

for op in [torch.mul, torch.simple_mul]:
    # A) Simple case
    x = torch.randn(B)
    y = torch.randn(B)
    vmap(op)(x, y)

    # B) vmap over some Tensors
    x = torch.randn(3)
    y = torch.randn(B)
    vmap(op, in_dims=(None, 0))(x, y)

    # C) More complicated case
    x = torch.randn(4, 3)
    y = torch.randn(3, B)
    vmap(op, in_dims=(None, 1))(x, y)

    # D) Nested vmap
    x = torch.randn(B)
    y = torch.randn(B1)
    vmap(vmap(op, (0, None)), (None, 0))(x, y)
```
For example, in case A, `vmap(torch.mul)(x, y)` can be rewritten as `torch.mul(x, y)`.
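The other cases can be rewritten by hand too. Continuing from the snippet above, here is a sketch of how cases B and C reduce to plain `torch.mul` calls by moving the batch dim to the front and letting broadcasting handle the rest:

```python
# Case B: y is batched along dim 0, x is shared across the batch.
x = torch.randn(3)
y = torch.randn(B)
expected = vmap(torch.mul, in_dims=(None, 0))(x, y)
manual = torch.mul(x, y.unsqueeze(-1))  # (3,) * (B, 1) -> (B, 3)
assert torch.allclose(expected, manual)

# Case C: y's batch dim is dim 1; move it to the front first.
x = torch.randn(4, 3)
y = torch.randn(3, B)
expected = vmap(torch.mul, in_dims=(None, 1))(x, y)
manual = torch.mul(x, y.movedim(1, 0).unsqueeze(1))  # (4, 3) * (B, 1, 3) -> (B, 4, 3)
assert torch.allclose(expected, manual)
```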
Write a function in Python with the following signature:
```python
def mul_batched(in_dims: Tuple[Optional[int], Optional[int]],
                x: Tensor, y: Tensor) -> Tensor:
    pass
```
`mul_batched(in_dims, x, y)` should return the same thing as `vmap(simple_mul, in_dims)(x, y)`. Do not use vmap in the definition of `mul_batched`.
Test it out with the above cases.
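If you want a reference point, here is a minimal sketch of one way to structure `mul_batched` (not the only correct answer): move each batch dim to the front, pad the batched operands with size-1 dims so the batch dim never collides with a logical dim, and let broadcasting do the rest.

```python
import torch
from torch import Tensor
from typing import Optional, Tuple

def mul_batched(in_dims: Tuple[Optional[int], Optional[int]],
                x: Tensor, y: Tensor) -> Tensor:
    x_bdim, y_bdim = in_dims
    assert x_bdim is not None or y_bdim is not None

    # Move each batch dim (if present) to the front.
    if x_bdim is not None:
        x = x.movedim(x_bdim, 0)
    if y_bdim is not None:
        y = y.movedim(y_bdim, 0)

    # "Logical" rank = rank of the per-example slice that the op sees.
    x_logical_rank = x.dim() - (x_bdim is not None)
    y_logical_rank = y.dim() - (y_bdim is not None)
    max_logical_rank = max(x_logical_rank, y_logical_rank)

    # Insert size-1 dims right after the batch dim so that the batch dim
    # never lines up against a logical dim during broadcasting.
    if x_bdim is not None:
        while x.dim() < max_logical_rank + 1:
            x = x.unsqueeze(1)
    if y_bdim is not None:
        while y.dim() < max_logical_rank + 1:
            y = y.unsqueeze(1)

    # Ordinary broadcasting handles the rest; the result's batch dim is
    # at position 0, matching vmap's default out_dims=0.
    return torch.mul(x, y)
```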
Write a brand new CompositeImplicitAutograd function in `native_functions.yaml` called `simple_mul`.
Add an OpInfo entry for it to test it. Don't use `BinaryUfuncInfo` as the constructor (that's beyond the scope of this lab); just use `OpInfo` and provide the above test cases (and more, if you wish) as the sample inputs.
When this step is done, you'll have a `torch.simple_mul` operator that is callable from Python, as well as an OpInfo test for it. Run

```
pytest test/functorch/test_vmap.py -v -k "simple_mul"
```

to verify that the OpInfo is correctly hooked up to our vmap testing.
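For orientation, the OpInfo entry might look roughly like the sketch below. The exact constructor arguments vary between PyTorch versions, so treat the field names as approximate and check `common_methods_invocations.py` for the current ones.

```python
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import floating_types
from torch.testing._internal.common_methods_invocations import OpInfo, SampleInput

def sample_inputs_simple_mul(op_info, device, dtype, requires_grad, **kwargs):
    def make(shape):
        return make_tensor(shape, device=device, dtype=dtype,
                           requires_grad=requires_grad)
    yield SampleInput(make((2,)), args=(make((2,)),))    # same shape
    yield SampleInput(make((3,)), args=(make((2, 1)),))  # broadcasting
    yield SampleInput(make((4, 3)), args=(make((3,)),))  # rank mismatch

# Appended to op_db in common_methods_invocations.py:
OpInfo('simple_mul',
       dtypes=floating_types(),
       sample_inputs_func=sample_inputs_simple_mul,
       supports_out=False)
```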
Mark the operator as CompositeExplicitAutograd and add vmap support for `torch.simple_mul` in `BatchRulesBinaryOps.cpp`. The goal is to get the following tests to pass:
- `TestVmapOperators.test_vmap_exhaustive_simple_mul*`
- `TestVmapOperators.test_op_has_batch_rule_simple_mul*`

Please do this step by adding the following C++ function, which should be a transcription of your Python `mul_batched` into C++:
```cpp
std::tuple<Tensor, optional<int64_t>> simple_mul_batch_rule(
    const Tensor& x, optional<int64_t> x_bdim,
    const Tensor& y, optional<int64_t> y_bdim) {
  // your code here.
}
```
Avoid using the `BINARY_POINTWISE` macro (that would solve the problem trivially).
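Once the batch rule is in place, a quick sanity check from Python (assuming the earlier steps are done and the `mul_batched` reference from above is in scope) could look like this:

```python
import torch
from functorch import vmap

B = 2
x = torch.randn(4, 3)
y = torch.randn(3, B)

out = vmap(torch.simple_mul, in_dims=(None, 1))(x, y)
ref = mul_batched((None, 1), x, y)  # Python reference from the earlier step
assert torch.allclose(out, ref)
```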