Exploring TorchRec sharding#
Created On: May 10, 2022 | Last Updated: May 13, 2022 | Last Verified: Nov 05, 2024
This tutorial covers the sharding schemes of embedding tables via the EmbeddingPlanner and DistributedModelParallel APIs, and explores the benefits of different sharding schemes for the embedding tables by explicitly configuring them.
Installation#
Requirements:
- python >= 3.7

We highly recommend CUDA when using TorchRec. If using CUDA:
- cuda >= 11.0
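As an optional sanity check of our own (not part of the original tutorial), the snippet below prints the Python version and, if a torch build is already present in the runtime, its CUDA version:

# Optional sanity check (not part of the original tutorial):
# verify the runtime meets the python >= 3.7 and cuda >= 11.0 requirements.
import sys
print(sys.version_info)

try:
    import torch
    print(torch.version.cuda, torch.cuda.is_available())
except ImportError:
    print("torch is not installed yet; it is installed in the next steps")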
# install conda to make installing pytorch with cudatoolkit 11.3 easier.
!sudo rm Miniconda3-py37_4.9.2-Linux-x86_64.sh Miniconda3-py37_4.9.2-Linux-x86_64.sh.*
!sudo wget https://repo.anaconda.com/miniconda/Miniconda3-py37_4.9.2-Linux-x86_64.sh
!sudo chmod +x Miniconda3-py37_4.9.2-Linux-x86_64.sh
!sudo bash ./Miniconda3-py37_4.9.2-Linux-x86_64.sh -b -f -p /usr/local
# install pytorch with cudatoolkit 11.3
!sudo conda install pytorch cudatoolkit=11.3 -c pytorch-nightly -y
Installing TorchRec will also install FBGEMM, a collection of CUDA kernels and GPU-enabled operations that TorchRec runs on.
# install torchrec
!pip3 install torchrec-nightly
Install multiprocess, which works with ipython, for multi-processing programming within Colab.
!pip3 install multiprocess
The following steps are needed for the Colab runtime to detect the added shared libraries. The runtime searches for shared libraries in /usr/lib, so we copy over the libraries that were installed in /usr/local/lib/. This step is necessary only in the Colab runtime.
!sudo cp /usr/local/lib/lib* /usr/lib/
Restart your runtime at this point for the newly installed packages to be seen. Run the step below immediately after restarting so that python knows where to look for packages. Always run this step after restarting the runtime.
import sys
sys.path = [
    '',
    '/env/python',
    '/usr/local/lib/python37.zip',
    '/usr/local/lib/python3.7',
    '/usr/local/lib/python3.7/lib-dynload',
    '/usr/local/lib/python3.7/site-packages',
    './.local/lib/python3.7/site-packages',
]
Distributed Setup#
Due to the notebook environment, we cannot run an SPMD program here, but we can do multiprocessing inside the notebook to mimic the setup. Users are responsible for setting up their own SPMD launcher when using TorchRec. We set up our environment so that the torch.distributed-based communication backend can work.
import os
import torch
import torchrec

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
Constructing our embedding model#
Here we use TorchRec's EmbeddingBagCollection to construct our embedding bag model with embedding tables.
Here, we create an EmbeddingBagCollection (EBC) with four embedding bags. We have two types of tables: large tables and small tables, differentiated by their number of rows: 4096 vs 1024. Each table is still represented by a 64-dimensional embedding.
We configure the ParameterConstraints data structure for the tables, which provides hints for the model parallel API to help decide the sharding and placement strategy for the tables. In TorchRec, we support:

- table-wise: place the entire table on one device;
- row-wise: shard the table evenly by row dimension and place one shard on each device of the communication world;
- column-wise: shard the table evenly by embedding dimension, and place one shard on each device of the communication world;
- table-row-wise: special sharding optimized for intra-host communication, for available fast intra-machine device interconnect, e.g. NVLink;
- data_parallel: replicate the tables for every device;
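To make the schemes above concrete, here is a small illustrative sketch of our own (not part of the tutorial's code) that computes the per-device shard shapes of a single 4096 x 64 table for a 2-device world; table-row-wise is omitted because it only makes sense with a multi-host topology:

# Illustrative only: how one (num_embeddings x embedding_dim) table is split
# across `world_size` devices under each sharding scheme.
num_embeddings, embedding_dim, world_size = 4096, 64, 2

shard_shapes = {
    # the whole table lands on a single device; the other devices hold nothing
    "table_wise": [(num_embeddings, embedding_dim)] + [(0, 0)] * (world_size - 1),
    # rows are split evenly across all devices
    "row_wise": [(num_embeddings // world_size, embedding_dim)] * world_size,
    # the embedding dimension is split evenly across all devices
    "column_wise": [(num_embeddings, embedding_dim // world_size)] * world_size,
    # every device keeps a full replica of the table
    "data_parallel": [(num_embeddings, embedding_dim)] * world_size,
}

for scheme, shapes in shard_shapes.items():
    print(scheme, shapes)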
Note how we initially allocate the EBC on device "meta". This tells the EBC not to allocate memory yet.
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.types import ShardingType
from typing import Dict

large_table_cnt = 2
small_table_cnt = 2
large_tables = [
    torchrec.EmbeddingBagConfig(
        name="large_table_" + str(i),
        embedding_dim=64,
        num_embeddings=4096,
        feature_names=["large_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    ) for i in range(large_table_cnt)
]
small_tables = [
    torchrec.EmbeddingBagConfig(
        name="small_table_" + str(i),
        embedding_dim=64,
        num_embeddings=1024,
        feature_names=["small_table_feature_" + str(i)],
        pooling=torchrec.PoolingType.SUM,
    ) for i in range(small_table_cnt)
]

def gen_constraints(
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
) -> Dict[str, ParameterConstraints]:
    large_table_constraints = {
        "large_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        ) for i in range(large_table_cnt)
    }
    small_table_constraints = {
        "small_table_" + str(i): ParameterConstraints(
            sharding_types=[sharding_type.value],
        ) for i in range(small_table_cnt)
    }
    constraints = {**large_table_constraints, **small_table_constraints}
    return constraints
# allocate on "meta": no memory is allocated until the model is sharded
ebc = torchrec.EmbeddingBagCollection(
    device="meta",
    tables=large_tables + small_tables,
)
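As a side note of ours (not from the tutorial), the sketch below shows what allocation on the "meta" device means in plain PyTorch: the module carries parameter shapes and dtypes but no backing storage, so no memory is spent until DistributedModelParallel materializes the shards:

# Minimal illustration of the "meta" device using a plain nn.EmbeddingBag.
import torch

meta_bag = torch.nn.EmbeddingBag(num_embeddings=4096, embedding_dim=64, device="meta")
print(meta_bag.weight.shape)    # torch.Size([4096, 64]) -- shape metadata exists
print(meta_bag.weight.is_meta)  # True -- no memory has been allocated yet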
DistributedModelParallel in multiprocessing#
Now, we have a single process execution function for mimicking one rank's work during SPMD execution.
This code will shard the model collectively with the other processes and allocate memory accordingly. It first sets up the process group, then does embedding table placement using the planner, and generates the sharded model using DistributedModelParallel.
def single_rank_execution(
    rank: int,
    world_size: int,
    constraints: Dict[str, ParameterConstraints],
    module: torch.nn.Module,
    backend: str,
) -> None:
    import os
    import torch
    import torch.distributed as dist
    from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
    from torchrec.distributed.model_parallel import DistributedModelParallel
    from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
    from torchrec.distributed.types import ModuleSharder, ShardingEnv, ShardingPlan
    from typing import cast

    def init_distributed_single_host(
        rank: int,
        world_size: int,
        backend: str,
        # pyre-fixme[11]: Annotation `ProcessGroup` is not defined as a type.
    ) -> dist.ProcessGroup:
        os.environ["RANK"] = f"{rank}"
        os.environ["WORLD_SIZE"] = f"{world_size}"
        dist.init_process_group(rank=rank, world_size=world_size, backend=backend)
        return dist.group.WORLD

    if backend == "nccl":
        device = torch.device(f"cuda:{rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
    topology = Topology(world_size=world_size, compute_device="cuda")
    # set up the process group for collective communication
    pg = init_distributed_single_host(rank, world_size, backend)
    # generate the sharding plan collectively with the other ranks
    planner = EmbeddingShardingPlanner(
        topology=topology,
        constraints=constraints,
    )
    sharders = [cast(ModuleSharder[torch.nn.Module], EmbeddingBagCollectionSharder())]
    plan: ShardingPlan = planner.collective_plan(module, sharders, pg)

    # shard the module according to the plan and place shards on this rank's device
    sharded_model = DistributedModelParallel(
        module,
        env=ShardingEnv.from_process_group(pg),
        plan=plan,
        sharders=sharders,
        device=device,
    )
    print(f"rank:{rank}, sharding plan: {plan}")
    return sharded_model
Multiprocessing Execution#
Now let's execute the code in multiple processes representing multiple GPU ranks.
import multiprocess

def spmd_sharing_simulation(
    sharding_type: ShardingType = ShardingType.TABLE_WISE,
    world_size=2,
):
    ctx = multiprocess.get_context("spawn")
    processes = []
    for rank in range(world_size):
        p = ctx.Process(
            target=single_rank_execution,
            args=(
                rank,
                world_size,
                gen_constraints(sharding_type),
                ebc,
                "nccl"
            ),
        )
        p.start()
        processes.append(p)

    for p in processes:
        p.join()
        assert 0 == p.exitcode
Table Wise Sharding#
Now let's execute the code in two processes for 2 GPUs. We can see in the printed plan how our tables are sharded across the GPUs. Each rank holds one large table and one small table, which shows that the planner tries to load balance the embedding tables. Table-wise is the de facto go-to sharding scheme for many small to medium sized tables, balancing the load over the devices.
spmd_sharing_simulation(ShardingType.TABLE_WISE)
rank:1, sharding plan: {'': {
    'large_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])),
    'large_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:1/cuda:1)])),
    'small_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:0/cuda:0)])),
    'small_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:1/cuda:1)]))}}
rank:0, sharding plan: {'': {
    'large_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])),
    'large_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:1/cuda:1)])),
    'small_table_0': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:0/cuda:0)])),
    'small_table_1': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 64], placement=rank:1/cuda:1)]))}}
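If you prefer to inspect a plan programmatically rather than read the printed repr, something like the sketch below could be added inside single_rank_execution after collective_plan. It is our own helper and assumes the ShardingPlan object exposes its per-module table-to-ParameterSharding mapping via plan.plan, mirroring the structure of the printed output above:

# Hypothetical helper (not part of the tutorial): summarize where each table
# was placed, assuming plan.plan maps module path -> {table name: ParameterSharding}.
def summarize_plan(plan) -> None:
    for module_path, module_plan in plan.plan.items():
        for table_name, param_sharding in module_plan.items():
            print(
                f"{module_path or '<root>'}/{table_name}: "
                f"{param_sharding.sharding_type} on ranks {param_sharding.ranks}"
            )

# e.g. summarize_plan(plan) would print lines like
# "<root>/large_table_0: table_wise on ranks [0]"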
Explore other sharding modes#
We have initially explored what table-wise sharding looks like and how it balances table placement. Now we explore a sharding mode with a finer focus on load balance: row-wise. Row-wise specifically addresses large tables that a single device cannot hold because of the memory required by their large number of embedding rows; it can handle the placement of the super large tables in your models. Users can see in the shard_sizes section of the printed plan log that the tables are halved along the row dimension and distributed onto the two GPUs.
spmd_sharing_simulation(ShardingType.ROW_WISE)
rank:1, sharding plan: {'': {
    'large_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])),
    'large_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])),
    'small_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)])),
    'small_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)]))}}
rank:0, sharding plan: {'': {
    'large_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])),
    'large_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2048, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[2048, 0], shard_sizes=[2048, 64], placement=rank:1/cuda:1)])),
    'small_table_0': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)])),
    'small_table_1': ParameterSharding(sharding_type='row_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[512, 64], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[512, 0], shard_sizes=[512, 64], placement=rank:1/cuda:1)]))}}
Column-wise, on the other hand, addresses the load imbalance problem for tables with large embedding dimensions, by splitting the table vertically. Users can see in the shard_sizes section of the printed plan log that the tables are halved along the embedding dimension and distributed onto the two GPUs.
spmd_sharing_simulation(ShardingType.COLUMN_WISE)
rank:0, sharding plan: {'': {
    'large_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])),
    'large_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])),
    'small_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)])),
    'small_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)]))}}
rank:1, sharding plan: {'': {
    'large_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])),
    'large_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[4096, 32], placement=rank:1/cuda:1)])),
    'small_table_0': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)])),
    'small_table_1': ParameterSharding(sharding_type='column_wise', compute_kernel='batched_fused', ranks=[0, 1], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1024, 32], placement=rank:0/cuda:0), ShardMetadata(shard_offsets=[0, 32], shard_sizes=[1024, 32], placement=rank:1/cuda:1)]))}}
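As a quick sanity check of our own on the two plans above (not tutorial code), the shard_sizes follow directly from halving one dimension of each table across the 2-GPU world:

# Expected shard sizes for world_size = 2, matching the plans printed above.
world_size = 2
tables = {"large_table": (4096, 64), "small_table": (1024, 64)}

for name, (num_rows, emb_dim) in tables.items():
    # row-wise halves the row count and keeps the embedding dimension intact
    print(name, "row_wise    ->", [num_rows // world_size, emb_dim])
    # column-wise halves the embedding dimension and keeps the row count intact
    print(name, "column_wise ->", [num_rows, emb_dim // world_size])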
For table-row-wise, unfortunately we cannot simulate it here because it operates under a multi-host setup. We will present a Python SPMD example in the future that trains models with table-row-wise sharding.
With data parallel, we replicate the tables on all devices.
spmd_sharing_simulation(ShardingType.DATA_PARALLEL)
rank:0, sharding plan: {'': {
    'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None),
    'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None),
    'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None),
    'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}
rank:1, sharding plan: {'': {
    'large_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None),
    'large_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None),
    'small_table_0': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None),
    'small_table_1': ParameterSharding(sharding_type='data_parallel', compute_kernel='batched_dense', ranks=[0, 1], sharding_spec=None)}}