Commit b6df043

bowangbj authored and facebook-github-bot committed

Add torch.nn.init.uniform_ operator to ShardedTensor. (#63997)

Summary:
Pull Request resolved: #63997

Use __torch_function__ to extend torch.nn.init.uniform_. The init is done in SPMD fashion on each rank. Note that ideally we would aggregate the sharded tensors into a global tensor, init it, and reshard; running it SPMD is fine here because uniform_ draws i.i.d. (independent and identically distributed) samples.

Also enable the unit test in test_linear.py for OSS testing.

Test Plan:
a) Unit tests:
(pytorch) ... $ python test/distributed/_sharded_tensor/ops/test_init.py TestShardedTensorNNInit --v
(pytorch) ... $ python test/distributed/_sharded_tensor/ops/test_linear.py --v (before this change, that command was a no-op)

or b) Manual run, instructions here: https://docs.google.com/document/d/1_m1Hdo5w51-hhPlZ_F8Y6PIWrN7UgJZqiSpARYvhsaE/edit#

Imported from OSS

Reviewed By: pritamdamania87, anjali411

Differential Revision: D30563017

fbshipit-source-id: d1859f7682235bcb44515efc69ca92bc5e34fce1

1 parent bdb889a commit b6df043
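
In effect, after this change torch.nn.init.uniform_ can be called directly on a ShardedTensor, and each rank fills only its local shards. A minimal usage sketch, assuming a 4-GPU process group like the one set up by the new test (the spec and shapes mirror that test; the snippet is illustrative, not part of the commit):

import torch
from torch.distributed import _sharded_tensor
from torch.distributed._sharding_spec import ChunkShardingSpec

# Shard dim 0 of an 8x2 tensor across four ranks, one chunk per GPU.
spec = ChunkShardingSpec(
    dim=0,
    placements=[
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ],
)
st = _sharded_tensor.empty(spec, 8, 2, dtype=torch.double)

# Dispatches through ShardedTensor.__torch_function__ to the new sharded
# uniform_ op; each rank fills its local shards in place.
torch.nn.init.uniform_(st, a=10, b=20)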

File tree

7 files changed: +107 −2 lines changed
test/distributed/_sharded_tensor/ops/test_init.py

Lines changed: 65 additions & 0 deletions

import sys
import torch

from torch.distributed import _sharded_tensor
from torch.distributed._sharding_spec import (
    ChunkShardingSpec,
)
from torch.testing._internal.common_distributed import (
    requires_nccl,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.distributed._sharded_tensor import (
    ShardedTensorTestBase,
    with_comms,
)
from torch.testing._internal.common_utils import (
    TEST_WITH_DEV_DBG_ASAN,
    run_tests,
)

if TEST_WITH_DEV_DBG_ASAN:
    print("Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr)
    sys.exit(0)

class TestShardedTensorNNInit(ShardedTensorTestBase):
    """ Testing torch.nn.init functions for ShardedTensor """

    @with_comms
    @skip_if_lt_x_gpu(4)
    @requires_nccl()
    def test_init_sharded_tensor_with_uniform(self):
        """ Test torch.nn.init.uniform_(ShardedTensor, a, b) """

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        h, w = 8, 2
        expected_h = 2
        expected_device = torch.device(f"cuda:{self.rank}")
        a, b = 10, 20

        seed = 1234
        dtype = torch.double

        sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype)
        self.assertEqual(1, len(sharded_tensor.local_shards()))

        # Clone local tensor to ensure torch.nn.init starts from the same input
        local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor)
        torch.manual_seed(seed)
        torch.nn.init.uniform_(sharded_tensor, a=a, b=b)

        torch.manual_seed(seed)
        torch.nn.init.uniform_(local_tensor_clone, a=a, b=b)
        self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor)


if __name__ == '__main__':
    run_tests()

test/distributed/_sharded_tensor/ops/test_linear.py

Lines changed: 4 additions & 0 deletions

@@ -12,6 +12,7 @@
 )
 from torch.testing._internal.common_utils import (
     TEST_WITH_DEV_DBG_ASAN,
+    run_tests,
 )
 from torch.testing._internal.distributed._sharded_tensor import (
     ShardedTensorTestBase,
@@ -85,3 +86,6 @@ def test_sharded_linear_rowwise(self):
         # Test uneven split.
         self._run_sharded_linear(spec, [5, 19], [19, 11], 1)
         self._run_sharded_linear(spec, [5, 21], [21, 11], 1)
+
+
+if __name__ == '__main__':
+    run_tests()

test/run_test.py

Lines changed: 3 additions & 0 deletions

@@ -199,6 +199,7 @@ def skip_test_p(name: str) -> bool:
     "distributed/elastic/multiprocessing/api_test",
     "distributed/_sharded_tensor/test_sharded_tensor",
     "distributed/_sharded_tensor/ops/test_embedding",
+    "distributed/_sharded_tensor/ops/test_init",
     "distributed/_sharded_tensor/ops/test_linear",
 ] + FSDP_TEST

@@ -209,6 +210,7 @@ def skip_test_p(name: str) -> bool:
     "distributed/rpc/cuda/test_tensorpipe_agent",
     "distributed/_sharded_tensor/test_sharded_tensor",
     "distributed/_sharded_tensor/ops/test_embedding",
+    "distributed/_sharded_tensor/ops/test_init",
     "distributed/_sharded_tensor/ops/test_linear",
     "test_determination",
     "test_multiprocessing",

@@ -345,6 +347,7 @@ def skip_test_p(name: str) -> bool:
     "distributed/_sharding_spec/test_sharding_spec",
     "distributed/_sharded_tensor/test_sharded_tensor",
     "distributed/_sharded_tensor/ops/test_embedding",
+    "distributed/_sharded_tensor/ops/test_init",
     "distributed/_sharded_tensor/ops/test_linear",
 ] + [test for test in TESTS if test.startswith("distributed/fsdp")]

torch/distributed/_sharded_tensor/api.py

Lines changed: 3 additions & 2 deletions

@@ -27,7 +27,7 @@
     get_chunked_dim_size,
 )
 from torch.types import Number
-from .ops import sharded_embedding, sharded_linear
+from .ops import sharded_embedding, sharded_linear, uniform_

 # Tracking for sharded tensor objects.
 _sharded_tensor_lock = threading.Lock()

@@ -638,7 +638,8 @@ def __torch_function__(self, func, types, args=(), kwargs=None):
             return sharded_linear(types, args, kwargs, self._process_group)
         if func == torch.nn.functional.embedding:
             return sharded_embedding(types, args, kwargs, self._process_group)
-
+        elif func == torch.nn.init.uniform_:
+            return uniform_(types, args, kwargs)
         raise RuntimeError(
             f"torch function '{func.__name__}', with args: {args} and "
             f"kwargs: {kwargs} not supported for ShardedTensor!")
torch/distributed/_sharded_tensor/ops/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,2 +1,3 @@
+from .init import uniform_
 from .linear import sharded_linear
 from .embedding import sharded_embedding
torch/distributed/_sharded_tensor/ops/init.py

Lines changed: 26 additions & 0 deletions

import torch

def validate_param(param, param_name):
    if param is None:
        raise ValueError(f"param: {param_name} shouldn't be None!")

def uniform_(types, args=(), kwargs=None):
    r"""
    Fills the Tensor in sharded_tensor.local_shards with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.
    Args:
        sharded_tensor: tensor sharded across devices
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
    """
    validate_param(kwargs, "kwargs")
    sharded_tensor = kwargs["tensor"]
    validate_param(sharded_tensor, "sharded_tensor")
    a = kwargs['a']
    validate_param(a, "a")
    b = kwargs['b']
    validate_param(b, "b")

    for shard in sharded_tensor.local_shards():
        torch.nn.init.uniform_(shard.tensor, a=a, b=b)
    return sharded_tensor
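
Note that this op reads its inputs from kwargs rather than args: the hook added to torch/nn/init.py (below) forwards every parameter by keyword through handle_torch_function. A small self-contained sketch of the resulting contract, assuming this commit's module layout and using hypothetical _FakeShard/_FakeShardedTensor stand-ins for the real Shard and ShardedTensor types:

import torch
from torch.distributed._sharded_tensor.ops import uniform_

class _FakeShard:
    # Hypothetical stand-in exposing .tensor, the only attribute uniform_ touches.
    def __init__(self, tensor):
        self.tensor = tensor

class _FakeShardedTensor:
    # Hypothetical stand-in exposing local_shards(), the only method uniform_ calls.
    def __init__(self, shards):
        self._shards = shards

    def local_shards(self):
        return self._shards

st = _FakeShardedTensor([_FakeShard(torch.empty(2, 2))])
# Mirrors the call ShardedTensor.__torch_function__ makes after
# handle_torch_function has packed everything into kwargs:
uniform_(types=(), args=(), kwargs={"tensor": st, "a": 0.0, "b": 1.0})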

torch/nn/init.py

Lines changed: 5 additions & 0 deletions

@@ -4,6 +4,9 @@
 from torch import Tensor
 import torch

+from ..overrides import (
+    has_torch_function_variadic,
+    handle_torch_function)

 # These no_grad_* functions are necessary as wrappers around the parts of these
 # functions that use `with torch.no_grad()`. The JIT doesn't support context

@@ -132,6 +135,8 @@ def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
     >>> w = torch.empty(3, 5)
     >>> nn.init.uniform_(w)
     """
+    if has_torch_function_variadic(tensor, a, b):
+        return handle_torch_function(uniform_, (tensor, a, b), tensor=tensor, a=a, b=b)
     return _no_grad_uniform_(tensor, a, b)
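
has_torch_function_variadic returns True when any of the listed arguments defines a __torch_function__ override, and handle_torch_function then re-dispatches the call to that override; this generic protocol is what ShardedTensor plugs into. A toy, self-contained illustration with a hypothetical wrapper class (not part of the commit):

import torch

class UniformLogger:
    # Hypothetical wrapper type; any class defining __torch_function__ is
    # picked up by has_torch_function_variadic / handle_torch_function.
    def __init__(self, data):
        self.data = data

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.nn.init.uniform_:
            # The patched uniform_ forwards everything by keyword:
            # tensor=..., a=..., b=...
            wrapper = kwargs["tensor"]
            print(f"intercepted uniform_(a={kwargs['a']}, b={kwargs['b']})")
            torch.nn.init.uniform_(wrapper.data, a=kwargs["a"], b=kwargs["b"])
            return wrapper
        return NotImplemented

w = UniformLogger(torch.empty(3, 5))
torch.nn.init.uniform_(w, a=0.0, b=1.0)  # routed to UniformLogger.__torch_function__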
