Movatterモバイル変換


[0]ホーム

URL:


Skip to content

Navigation Menu

Sign in
Appearance settings

Search code, repositories, users, issues, pull requests...

Provide feedback

We read every piece of feedback, and take your input very seriously.

Saved searches

Use saved searches to filter your results more quickly

Sign up
Appearance settings

empty_permute decomposition#2698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to ourterms of service andprivacy statement. We’ll occasionally send you account related emails.

Already on GitHub?Sign in to your account

Merged
apbose merged 1 commit intomainfromempty_permuted_decomposition
Apr 17, 2024
Merged

Conversation

@apbose
Copy link
Collaborator

This is an extension to support aten::empty_like.

@github-actionsgithub-actionsbot added component: testsIssues re: Tests component: loweringIssues re: The lowering / preprocessing passes component: api [Python]Issues re: Python API component: dynamoIssues relating to the `torch.compile` or `torch._dynamo.export` paths labelsMar 19, 2024
@apboseapboseforce-pushed theempty_permuted_decomposition branch fromdcfe61d to6abe7ceCompareApril 5, 2024 00:15
Comment on lines +443 to +450
fx_graph=torch.fx.symbolic_trace(emptyLike())
unexpected_ops_seen,expected_ops_unseen=lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Could you show a printout of what the original and final graphs look like in this case? I want to verify that there is not a circular issue whereempty_permuted generatesempty_like, and vice versa

Copy link
CollaboratorAuthor

@apboseapboseApr 12, 2024
edited
Loading

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

With theempty_permute decomposition the graph is this
Pre-AOT Autograd graph:=============

graph():   %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]   %add : [num_users=2] = call_function[target=torch.ops.aten.add](args = (%l_x_, %l_x_), kwargs = {})   %empty_like_default : [num_users=1] = call_function[target=torch.ops.aten.empty_like.default](args = (%add,), kwargs = {})   %add_1 : [num_users=1] = call_function[target=operator.add](args = (%empty_like_default, %add), kwargs = {})   return (add_1,)

Post-AOT Autograd graph:=======

graph():   %arg0_1 : [num_users=1] = placeholder[target=arg0_1]   %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %clone), kwargs = {})   %empty : [num_users=1] = call_function[target=torch.ops.aten.empty.memory_format](args = ([3, 2],), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})   %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%empty, [0, 1]), kwargs = {})   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%permute, %add), kwargs = {})   return (add_1,)

Graph after constant folding:

graph():   %arg0_1 : [num_users=1] = placeholder[target=arg0_1]   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})   %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})   return (add_1,)

Post-lowering passes Autograd graph:=======

graph():   %arg0_1 : [num_users=1] = placeholder[target=arg0_1]   %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})   %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]   %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})   return (add_1,)

Without the decomposition, the graph is
Pre-AOT Autograd graph:=============

graph():    %l_x_ : torch.Tensor [num_users=1] = placeholder[target=L_x_]    %add : [num_users=2] = call_function[target=torch.ops.aten.add](args = (%l_x_, %l_x_), kwargs = {})    %empty_like_default : [num_users=1] = call_function[target=torch.ops.aten.empty_like.default](args = (%add,), kwargs = {})    %add_1 : [num_users=1] = call_function[target=operator.add](args = (%empty_like_default, %add), kwargs = {})    return (add_1,)

Post-AOT Autograd graph:=======

graph():    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]    %clone : [num_users=1] = call_function[target=torch.ops.aten.clone.default](args = (%arg0_1,), kwargs = {})    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%clone, %clone), kwargs = {})    %empty_permuted : [num_users=1] = call_function[target=torch.ops.aten.empty_permuted.default](args = ([3, 2], [0, 1]), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%empty_permuted, %add), kwargs = {})    return (add_1,)

Graph after constant folding:

graph():    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})    return (add_1,)

Post-lowering passes Autograd graph:=======

graph():    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %arg0_1), kwargs = {})    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]    %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%_frozen_param0, %add), kwargs = {})    return (add_1,)

Soempty_like decomposes intoempty_permute which decomposes intoempty.memory_format. The above test does not give error, even thoughempty.memory_format is not supported since constant folding removes the op.

I am working on empty.memory_format in PR#2745

gs-olive reacted with thumbs up emoji
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

In the above example, the Pre-AOT graph shows:

   %empty_like_default : [num_users=1] = call_function[target=torch.ops.aten.empty_like.default](args = (%add,), kwargs = {})

Since there is only one argument inargs, what isempty_permute = args[1] defined as in the decomposition for that case?

Copy link
CollaboratorAuthor

@apboseapboseApr 16, 2024
edited
Loading

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

In the above case with the AOT decomposition, the above operation decomposes to

 %empty_permuted : [num_users=1] = call_function[target=torch.ops.aten.empty_permuted.default](args = ([3, 2], [0, 1]), kwargs = {dtype: torch.float32, layout: torch.strided, device: cuda:0, pin_memory: False})

The args[1] in this case is[0,1] since it keeps the shapes in the original form.
Not sure how it gets the [0,1] exact, but I assume it must be the internal AOT lowering heuristics?

gs-olive reacted with thumbs up emoji
Copy link
Contributor

@gs-olivegs-olive left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

Overall looks good to me - added one clarifying question

Comment on lines +443 to +450
fx_graph=torch.fx.symbolic_trace(emptyLike())
unexpected_ops_seen,expected_ops_unseen=lower_graph_testing(
fx_graph,
inputs,
expected_ops=expected_ops,
unexpected_ops=unexpected_ops,
min_block_size=1,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others.Learn more.

In the above example, the Pre-AOT graph shows:

   %empty_like_default : [num_users=1] = call_function[target=torch.ops.aten.empty_like.default](args = (%add,), kwargs = {})

Since there is only one argument inargs, what isempty_permute = args[1] defined as in the decomposition for that case?

@apboseapbose merged commit0b29987 intomainApr 17, 2024
peri044 pushed a commit that referenced this pull requestApr 19, 2024
laikhtewari pushed a commit that referenced this pull requestMay 24, 2024
Sign up for freeto join this conversation on GitHub. Already have an account?Sign in to comment

Reviewers

1 more reviewer

@gs-olivegs-olivegs-olive approved these changes

Reviewers whose approvals may not affect merge requirements

Assignees

No one assigned

Labels

cla signedcomponent: api [Python]Issues re: Python APIcomponent: dynamoIssues relating to the `torch.compile` or `torch._dynamo.export` pathscomponent: loweringIssues re: The lowering / preprocessing passescomponent: testsIssues re: Tests

Projects

None yet

Milestone

No milestone

Development

Successfully merging this pull request may close these issues.

4 participants

@apbose@gs-olive@facebook-github-bot

[8]ページ先頭

©2009-2025 Movatter.jp