vmap Onboarding Lab


The goal of this lab is to give you first-hand experience writing vmap rules (also known as batching rules) and adding them to PyTorch. An alternative to this lab is taking on real tasks that add batching rules to PyTorch.

The deliverable is a stack of PRs (which won't be merged into master) containing the code for each of the sections below. The task is finished once the reviewers accept the PRs; you can close them at that point.

Considered Function

For this lab, we'll be writing a batching rule for a new operator, `simple_mul`, that is similar to `torch.mul` but has more restrictions on its input types. Concretely, `simple_mul` looks like the following:

```python
def simple_mul(x: Tensor, y: Tensor):
    return torch.mul(x, y)
```

For the sake of learning, please do not read the existing batching rule for `torch.mul`. When the lab is complete, the following test cases should work:

```python
import torch
from functorch import vmap

# The dimensions being vmapped over
B = 2
B1 = 3

for op in [torch.mul, torch.simple_mul]:
    # A) Simple case
    x = torch.randn(B)
    y = torch.randn(B)
    vmap(op)(x, y)

    # B) vmap over some Tensors
    x = torch.randn(3)
    y = torch.randn(B)
    vmap(op, in_dims=(None, 0))(x, y)

    # C) More complicated case
    x = torch.randn(4, 3)
    y = torch.randn(3, B)
    vmap(op, in_dims=(None, 1))(x, y)

    # D) Nested vmap
    x = torch.randn(B)
    y = torch.randn(B1)
    vmap(vmap(op, (0, None)), (None, 0))(x, y)
```

I) Determine how to rewrite each of the above expressions without using `vmap`.

For example, in case A, `vmap(torch.mul)(x, y)` can be rewritten as `torch.mul(x, y)`.
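If it helps to see the pattern, below is one possible rewrite for case B (shapes as in the snippet above). This is a hedged sketch of one valid answer, not the only one; try cases C and D yourself before comparing.

```python
import torch
from functorch import vmap

B = 2
x = torch.randn(3)
y = torch.randn(B)

# vmap(torch.mul, in_dims=(None, 0))(x, y) multiplies each scalar y[i]
# by all of x, producing a (B, 3) result. The same computation can be
# written with ordinary broadcasting by giving y a trailing size-1 dim:
expected = vmap(torch.mul, in_dims=(None, 0))(x, y)
rewritten = torch.mul(x, y.unsqueeze(-1))  # (3,) * (B, 1) -> (B, 3)
assert torch.allclose(expected, rewritten)
```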

II) Write the batching rule for `simple_mul` in Python

Write a function in Python with the following signature:

```python
from typing import Optional, Tuple

from torch import Tensor

def mul_batched(in_dims: Tuple[Optional[int], Optional[int]],
                x: Tensor, y: Tensor) -> Tensor:
    pass
```

`mul_batched(in_dims, x, y)` should return the same thing as `vmap(simple_mul, in_dims)(x, y)`. Do not use `vmap` in the definition of `mul_batched`.

Test it out with the above cases.
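If you get stuck, here is a minimal sketch of one way `mul_batched` could look, assuming the usual "move the batch dim to the front, then broadcast" strategy. It is not the reference solution; attempt the function yourself first.

```python
from typing import Optional, Tuple

import torch
from torch import Tensor

def mul_batched(in_dims: Tuple[Optional[int], Optional[int]],
                x: Tensor, y: Tensor) -> Tensor:
    x_bdim, y_bdim = in_dims
    # Move each batch dim (if any) to the front.
    if x_bdim is not None:
        x = x.movedim(x_bdim, 0)
    if y_bdim is not None:
        y = y.movedim(y_bdim, 0)
    # Logical rank = rank excluding the batch dim.
    x_logical = x.dim() - (1 if x_bdim is not None else 0)
    y_logical = y.dim() - (1 if y_bdim is not None else 0)
    max_logical = max(x_logical, y_logical)
    # Left-pad the logical dims with size-1 dims (inserted just after the
    # batch dim, if present) so they broadcast like an unbatched torch.mul.
    for _ in range(max_logical - x_logical):
        x = x.unsqueeze(1 if x_bdim is not None else 0)
    for _ in range(max_logical - y_logical):
        y = y.unsqueeze(1 if y_bdim is not None else 0)
    # An unbatched operand gets a size-1 batch dim at the front.
    if x_bdim is None:
        x = x.unsqueeze(0)
    if y_bdim is None:
        y = y.unsqueeze(0)
    return torch.mul(x, y)
```

For each of cases A through C, `mul_batched(in_dims, x, y)` should then match `vmap(op, in_dims)(x, y)` exactly.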

III) Write a native composite function and an OpInfo

Write a brand-new CompositeImplicitAutograd function in `native_functions.yaml` called `simple_mul`.
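For orientation, the entry could look roughly like the sketch below. The exact schema string is an assumption; copy the style of neighboring entries. An entry with no `dispatch:` section whose kernel is written in terms of other ATen ops is registered as CompositeImplicitAutograd by default.

```yaml
# Hypothetical sketch of the native_functions.yaml entry; with no
# dispatch: section, the operator defaults to CompositeImplicitAutograd.
- func: simple_mul(Tensor self, Tensor other) -> Tensor
  variants: function
```

The backing kernel is then a small C++ function (the location is an assumption; binary-op kernels conventionally live under `aten/src/ATen/native/`):

```cpp
// Hypothetical composite kernel for the entry above: defined in terms of
// at::mul, so autograd (and, for now, vmap) support can be derived from it.
Tensor simple_mul(const Tensor& self, const Tensor& other) {
  return at::mul(self, other);
}
```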

Add an OpInfo entry for it to test it. Don't use `BinaryUfuncInfo` as the constructor (that's beyond the scope of this lab); just use `OpInfo`, and provide the above test cases (and more, if you wish) as the sample inputs.
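As a rough sketch, the entry and its sample-inputs function could look like the following. The kwargs shown are assumptions; the real entry goes into the `op_db` list in `torch/testing/_internal/common_methods_invocations.py`, where `OpInfo` and `SampleInput` are already in scope, so mirror a simple neighboring entry rather than copying this verbatim.

```python
# Hypothetical OpInfo sketch; exact kwargs are assumptions to verify
# against the file you're editing.
from functools import partial

from torch.testing import make_tensor

def sample_inputs_simple_mul(op_info, device, dtype, requires_grad, **kwargs):
    make_arg = partial(make_tensor, device=device, dtype=dtype,
                       requires_grad=requires_grad)
    # Broadcastable (x, y) shape pairs, echoing the unbatched calls that
    # test cases A-C above reduce to.
    for shape_x, shape_y in (((2,), (2,)), ((3,), (1,)), ((4, 3), (3,))):
        yield SampleInput(make_arg(shape_x), args=(make_arg(shape_y),))

# Appended to op_db:
OpInfo(
    'simple_mul',
    sample_inputs_func=sample_inputs_simple_mul,
    supports_out=False,
),
```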

When this step is done, you'll have a `torch.simple_mul` operator that is callable from Python, as well as an OpInfo test for it. Run

```
pytest test/functorch/test_vmap.py -v -k "simple_mul"
```

to verify that the OpInfo is correctly hooked up to our vmap testing.

IV) Write a custom batching rule.

Mark the operator as CompositeExplicitAutograd and add vmap support for `torch.simple_mul` in `BatchRulesBinaryOps.cpp`. The goal is to get the following tests to pass:

  • TestVmapOperators.test_vmap_exhaustive_simple_mul*
  • TestVmapOperators.test_op_has_batch_rule_simple_mul*

Please do this step by adding the following C++ function, which should be a transcription of your Python `mul_batched` into C++:
```cpp
std::tuple<Tensor, optional<int64_t>> simple_mul_batch_rule(
    const Tensor& x, optional<int64_t> x_bdim,
    const Tensor& y, optional<int64_t> y_bdim) {
  // your code here.
}
```

Avoid using the `BINARY_POINTWISE` macro (that would solve the problem trivially).
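For comparison after you've attempted it, here is a hypothetical transcription that uses helpers from functorch's `BatchRulesHelper.h` (`moveBatchDimToFront`, `rankWithoutBatchDim`, `maybePadToLogicalRank`). Verify the helper names and the registration macro against the file you're editing rather than trusting this sketch.

```cpp
// Hypothetical transcription of the Python mul_batched sketch; helper
// names come from BatchRulesHelper.h, but treat the details as
// assumptions, not the reference solution.
std::tuple<Tensor, optional<int64_t>> simple_mul_batch_rule(
    const Tensor& x, optional<int64_t> x_bdim,
    const Tensor& y, optional<int64_t> y_bdim) {
  // Move any batch dim to the front of its tensor.
  auto x_ = moveBatchDimToFront(x, x_bdim);
  auto y_ = moveBatchDimToFront(y, y_bdim);
  // Pad the logical (non-batch) ranks with size-1 dims so the non-batch
  // dims broadcast the same way an unbatched torch.mul would.
  auto logical_rank = std::max(rankWithoutBatchDim(x, x_bdim),
                               rankWithoutBatchDim(y, y_bdim));
  x_ = maybePadToLogicalRank(x_, x_bdim, logical_rank);
  y_ = maybePadToLogicalRank(y_, y_bdim, logical_rank);
  // The result carries its batch dim at the front (dim 0).
  return std::make_tuple(at::mul(x_, y_), 0);
}

// Registration inside the file's TORCH_LIBRARY_IMPL block, e.g.:
// VMAP_SUPPORT(simple_mul, simple_mul_batch_rule);
```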

Next

Unit 8: function transforms / Training Loops (Optional) - composable function transforms (aka torch.func, functorch)

I would love to contribute to PyTorch!


