Commit c770435

anshul-si authored and pytorchmergebot committed

[dtensor][partial] redistributes _NormPartial to replicate for necessary pointwise_ops (#170035)

**Summary:** While linearity holds for regular Partial tensors, it does not apply to _NormPartial tensors. We fix this by redistributing to Replicate for all pointwise ops that require it when the placement is _NormPartial. The math showing that scalar mul and div ops do not need redistribution when the scalar value is non-negative is shown below. The iterative process that led to this PR can be viewed in #167813.

(Image attached in the original PR: derivation of why scalar mul/div preserve _NormPartial for non-negative scalars.)

**Test Cases**
1. pytest test/distributed/tensor/test_pointwise_ops.py -k test_mul_div_scalar_norm_partial
2. pytest test/distributed/tensor/test_pointwise_ops.py -k test_add_scalar_norm_partial

Pull Request resolved: #170035
Approved by: https://github.com/wconstab
ghstack dependencies: #170030

1 parent e270faa commit c770435
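
A minimal sketch of the identity involved, assuming _NormPartial combines shard-local p-norm contributions $v_i$ by reducing $(\sum_i |v_i|^p)^{1/p}$ (the full derivation is in the PR's attached image):

$$
\Big(\sum_i |c\,v_i|^p\Big)^{1/p} = |c|\,\Big(\sum_i |v_i|^p\Big)^{1/p} = |c|\,\lVert x \rVert_p
$$

For $c \ge 0$ this equals $c\,\lVert x \rVert_p$, so scalar mul/div can be applied to the partial values directly and the placement stays _NormPartial; for $c < 0$ the sign would be lost, so the tensor is redistributed to Replicate first.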

File tree

2 files changed: +59 −7 lines changed


test/distributed/tensor/test_pointwise_ops.py

Lines changed: 27 additions & 0 deletions

```diff
@@ -483,6 +483,33 @@ def test_mul_div_scalar_partial(self):
         expected = expected / 4.0
         self.assertEqual(res, expected)
 
+    @with_comms
+    def test_mul_div_scalar_norm_partial(self):
+        mesh = self.build_device_mesh()
+        aten = torch.ops.aten
+        local_tensor = torch.tensor([1.0, 1.0, 7.0, 7.0])
+        dt = distribute_tensor(local_tensor, mesh, [Shard(0)])
+
+        norm = dt.norm()
+        self.assertTrue(isinstance(norm._spec.placements[0], _NormPartial))
+
+        res = aten.mul.Scalar(norm, 2)
+        self.assertTrue(isinstance(res._spec.placements[0], _NormPartial))
+        res = res.redistribute(dt.device_mesh, placements=[Replicate()])
+        self.assertEqual(res, 20)
+
+        res = aten.div.Scalar(norm, 2)
+        self.assertTrue(isinstance(res._spec.placements[0], _NormPartial))
+        res = res.redistribute(dt.device_mesh, placements=[Replicate()])
+        self.assertEqual(res, 5)
+
+        res = aten.mul.Scalar(norm, -2)
+        self.assertTrue(res._spec.placements[0].is_replicate())
+
+        res = aten.div.Scalar(norm, -2)
+        self.assertEqual(res, -5)
+        self.assertTrue(res._spec.placements[0].is_replicate())
+
     @with_comms
     def test_add_scalar_partial(self):
         mesh = self.build_device_mesh()
```
torch/distributed/tensor/_ops/_pointwise_ops.py

Lines changed: 32 additions & 7 deletions

```diff
@@ -12,6 +12,7 @@
     StrategyType,
     TupleStrategy,
 )
+from torch.distributed.tensor._ops._math_ops import _NormPartial
 from torch.distributed.tensor._ops.registration import register_op_strategy
 from torch.distributed.tensor._ops.utils import (
     generate_redistribute_costs,
@@ -569,9 +570,21 @@ def common_pointwise_strategy(
             out_placements.append(Shard(new_shard_dim))
         elif isinstance(placement, Partial):
             is_scalar_arg = any(isinstance(arg, _Number) for arg in args_schema)
-            propagate_partial = not (
-                op in redistribute_partial_ops and is_scalar_arg
-            )
+            propagate_partial = False
+
+            # ordering matters here since NormPartial is a subclass of Partial
+            if isinstance(placement, _NormPartial):
+                # explanation for args_schema[1] >= 0 can be found in summary
+                # https://github.com/pytorch/pytorch/pull/170035
+                propagate_partial = (
+                    op in norm_partial_avoidable_redistribute_ops
+                    and args_schema[1] >= 0  # pyre-ignore[unsupported-operation]
+                )
+
+            elif isinstance(placement, Partial):
+                propagate_partial = not (
+                    op in p_sum_scalar_redistribute_ops and is_scalar_arg
+                )
 
             # Check if this partial type should be preserved
             if preserve_partial is not None and placement.is_partial(
@@ -682,12 +695,24 @@ def common_pointwise_strategy(
     return pointwise_strategy
 
 
-redistribute_partial_ops = {aten.add.Tensor, aten.add_.Tensor}
+p_sum_scalar_redistribute_ops = {aten.add.Tensor, aten.add_.Tensor}
+
+norm_partial_avoidable_redistribute_ops = {
+    aten.div.Scalar,
+    aten.div_.Scalar,
+    aten.mul.Scalar,
+    aten.mul_.Scalar,
+}
 
 for op in linear_pointwise_ops:
-    register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
-        linear_pointwise_strategy
-    )
+    if op in norm_partial_avoidable_redistribute_ops:
+        register_op_strategy(
+            op, schema_info=RuntimeSchemaInfo(1, static_kwargkey=["out"])
+        )(linear_pointwise_strategy)
+    else:
+        register_op_strategy(
+            op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"])
+        )(linear_pointwise_strategy)
 
 for op in partial_preserving_ops:
     register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))(
```
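
Reading the two hunks together, here is an illustrative standalone sketch of the decision they implement (the string op names and the helper are hypothetical, for exposition only; the real code compares aten overloads and DTensor placements):

```python
# Illustrative sketch only: mirrors the propagate_partial decision above,
# not the actual DTensor API.
NORM_PARTIAL_SAFE_SCALAR_OPS = {"mul.Scalar", "mul_.Scalar", "div.Scalar", "div_.Scalar"}
SUM_PARTIAL_SCALAR_ADD_OPS = {"add.Tensor", "add_.Tensor"}

def propagate_partial(op: str, is_norm_partial: bool, scalar) -> bool:
    """Return True if a Partial placement may pass through `op` unchanged."""
    if is_norm_partial:
        # Norm partials survive only scalar mul/div with a non-negative scalar,
        # since the p-norm reduction recovers |c| * ||x||_p, not c * ||x||_p.
        return op in NORM_PARTIAL_SAFE_SCALAR_OPS and scalar is not None and scalar >= 0
    # Plain sum partials are linear, except adding a scalar: the constant would
    # be applied once per shard and then summed, over-counting it.
    return not (op in SUM_PARTIAL_SCALAR_ADD_OPS and isinstance(scalar, (int, float)))
```

On the registration change: the scalar mul/div overloads now get RuntimeSchemaInfo(1, ...); my reading is that the leading 1 marks arguments from index 1 onward as static, so the sharding-propagation cache is keyed on the scalar value that the strategy now inspects via args_schema[1].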
