diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 0da8df7d8..8500e21fa 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -211,6 +211,42 @@ class ShardParams: local_metadata: List[ShardMetadata] embedding_weights: List[torch.Tensor] + def get_optimizer_single_value_shard_metadata_and_global_metadata( + table_global_metadata: ShardedTensorMetadata, + optimizer_state: torch.Tensor, + ) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]: + table_global_shards_metadata: List[ShardMetadata] = ( + table_global_metadata.shards_metadata + ) + + table_shard_metadata_to_optimizer_shard_metadata = {} + for offset, table_shard_metadata in enumerate(table_global_shards_metadata): + table_shard_metadata_to_optimizer_shard_metadata[ + table_shard_metadata + ] = ShardMetadata( + shard_sizes=[1], # single value optimizer state + shard_offsets=[offset], # offset increases by 1 for each shard + placement=table_shard_metadata.placement, + ) + + tensor_properties = TensorProperties( + dtype=optimizer_state.dtype, + layout=optimizer_state.layout, + requires_grad=False, + ) + single_value_optimizer_st_metadata = ShardedTensorMetadata( + shards_metadata=list( + table_shard_metadata_to_optimizer_shard_metadata.values() + ), + size=torch.Size([len(table_global_shards_metadata)]), + tensor_properties=tensor_properties, + ) + + return ( + table_shard_metadata_to_optimizer_shard_metadata, + single_value_optimizer_st_metadata, + ) + def get_optimizer_rowwise_shard_metadata_and_global_metadata( table_global_metadata: ShardedTensorMetadata, optimizer_state: torch.Tensor, @@ -356,7 +392,10 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata( if optimizer_states: optimizer_state_values = tuple(optimizer_states.values()) for optimizer_state_value in optimizer_state_values: - assert table_config.local_rows == optimizer_state_value.size(0) + assert ( + table_config.local_rows == optimizer_state_value.size(0) + or optimizer_state_value.nelement() == 1 # single value state + ) optimizer_states_keys_by_table[table_config.name] = list( optimizer_states.keys() ) @@ -435,29 +474,35 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor: momentum_local_shards: List[Shard] = [] optimizer_sharded_tensor_metadata: ShardedTensorMetadata - is_rowwise_optimizer_state: bool = ( - # pyre-ignore - shard_params.optimizer_states[0][momentum_idx - 1].dim() - == 1 - ) - - if is_rowwise_optimizer_state: + optim_state = shard_params.optimizer_states[0][momentum_idx - 1] # pyre-ignore[16] + if optim_state.nelement() == 1: + # single value state: one value per table + ( + table_shard_metadata_to_optimizer_shard_metadata, + optimizer_sharded_tensor_metadata, + ) = get_optimizer_single_value_shard_metadata_and_global_metadata( + table_config.global_metadata, + optim_state, + ) + elif optim_state.dim() == 1: + # rowwise state: param.shape[0] == state.shape[0], state.shape[1] == 1 ( table_shard_metadata_to_optimizer_shard_metadata, optimizer_sharded_tensor_metadata, ) = get_optimizer_rowwise_shard_metadata_and_global_metadata( table_config.global_metadata, - shard_params.optimizer_states[0][momentum_idx - 1], + optim_state, sharding_dim, is_grid_sharded, ) else: + # pointwise state: param.shape == state.shape ( table_shard_metadata_to_optimizer_shard_metadata, optimizer_sharded_tensor_metadata, ) = get_optimizer_pointwise_shard_metadata_and_global_metadata( 
table_config.global_metadata, - shard_params.optimizer_states[0][momentum_idx - 1], + optim_state, ) for optimizer_state, table_shard_local_metadata in zip( diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py index e9347715b..099389aa3 100644 --- a/torchrec/distributed/embedding_types.py +++ b/torchrec/distributed/embedding_types.py @@ -419,6 +419,7 @@ def sharding_types(self, compute_device_type: str) -> List[str]: types += [ ShardingType.ROW_WISE.value, ShardingType.TABLE_ROW_WISE.value, + ShardingType.GRID_SHARD.value, ] return types diff --git a/torchrec/distributed/planner/enumerators.py b/torchrec/distributed/planner/enumerators.py index 7554b74be..4845a52a8 100644 --- a/torchrec/distributed/planner/enumerators.py +++ b/torchrec/distributed/planner/enumerators.py @@ -245,7 +245,6 @@ def _filter_sharding_types( filtered_sharding_types = list( set(constrained_sharding_types) & set(allowed_sharding_types) ) - if not filtered_sharding_types: logger.warn( "No available sharding types after applying user provided " @@ -380,6 +379,7 @@ def get_partition_by_type(sharding_type: str) -> str: ShardingType.ROW_WISE.value, ShardingType.DATA_PARALLEL.value, } + multi_host_sharding_types = {ShardingType.GRID_SHARD.value} if sharding_type in device_sharding_types: return PartitionByType.DEVICE.value @@ -387,6 +387,8 @@ def get_partition_by_type(sharding_type: str) -> str: return PartitionByType.HOST.value elif sharding_type in uniform_sharding_types: return PartitionByType.UNIFORM.value + elif sharding_type in multi_host_sharding_types: + return PartitionByType.MULTI_HOST.value raise ValueError( f"Unrecognized or unsupported sharding type provided: {sharding_type}" diff --git a/torchrec/distributed/planner/partitioners.py b/torchrec/distributed/planner/partitioners.py index cb9c5c1cd..b397c5064 100644 --- a/torchrec/distributed/planner/partitioners.py +++ b/torchrec/distributed/planner/partitioners.py @@ -9,6 +9,7 @@ import copy import heapq +import itertools import logging from dataclasses import dataclass from enum import Enum @@ -242,6 +243,14 @@ def partition( for sharding_option_group in sharding_option_groups: if ( + sharding_option_group.sharding_options[0].partition_by + == PartitionByType.MULTI_HOST.value + ): + self._multi_hosts_partition(sharding_option_group, _host_level_devices) + # _multi_hosts_partition invalidates minheap_devices, force rebuild before using + minheap_devices = None + + elif ( sharding_option_group.sharding_options[0].partition_by == PartitionByType.HOST.value ): @@ -321,6 +330,136 @@ def _device_partition( minheap_devices.extend(tmp_heap) heapq.heapify(minheap_devices) + @classmethod + def _multi_hosts_partition( + cls, + sharding_option_group: ShardingOptionGroup, + _host_level_devices: List[List[DeviceHardware]], + ) -> None: + """ + Partition shards across multiple hosts. This is a greedy algorithm that tries to complete the partitioning on multiple hosts (sorted by perf). + First we do column-wise sharding among hosts, then table-wise-row-wise sharding within each host. + There are two cases, depending on the number of hosts needed to partition the shards. + + Case one: `num_host_to_allocate >= len(sorted_host_level_devices)` + We'll try to partition only once. Hosts might be selected multiple times in a circular manner. + E.g., we have 3 hosts and `num_host_to_allocate` = 4. We sort all devices at the host level. The devices of hosts [0, 1, 2, 0] will be selected for uniform partitioning. 
+ We'll update device information if success, otherwise raise a `PlannerError`. + + Case two: `num_host_to_allocate < len(sorted_host_level_devices)` + We'll try to partition with hosts `[host_index, host_index + num_host_to_allocate]` iteratively with host_index incremented by 1 each time. + 1) We sort all devices on host level. Set `host_index` = 0 + 2) We select hosts`[host_index, host_index + num_host_to_allocate]` if indexes are within range. + 3) We do uniform partitioning over all devices of the selected hosts. If we cannot partition, then we increase `host_index` by 1 and go to 2); Otherwise we go to 4) + 4) Update device information if success, otherwise raise a `PlannerError`. + + Keyword arguments: + sharding_option_group -- grouped sharding options + _host_level_devices -- devices + + Example:: + sharding_option_group.sharding_options = [ + ShardingOption(partition_by="multi_host", + shards=[ + Shards(storage=1, perf=1), + Shards(storage=1, perf=1), + Shards(storage=1, perf=1), + Shards(storage=1, perf=1), + ]), + ] + topology = Topology(world_size=6, local_world_size=2) + + # sharding_options[0] will be placed on host 1 and host 2 with the multi_hosts strategy, resulting in + + topology.devices[0].perf.total = (1,1) + topology.devices[1].perf.total = (1,1) + topology.devices[2].perf.total = (1,1) + topology.devices[3].perf.total = (1,1) + topology.devices[4].perf.total = (0,0) + topology.devices[5].perf.total = (0,0) + + """ + # TODO: for now assume just one option for multi_hosts. + if len(sharding_option_group.sharding_options) != 1: + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message=f"Unexpected length for sharding options: {len(sharding_option_group.sharding_options)}. Length needs to be 1", + ) + num_shards = sharding_option_group.sharding_options[0].num_shards + + if _host_level_devices is None: + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message="host level devices is None", + ) + + local_world_size = len(_host_level_devices[0]) + num_host_to_allocate, remainder = divmod(num_shards, local_world_size) + + if remainder > 0: + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message=f"Grid Sharding is unable to place shards equally over hosts without overlapping. 
{num_shards=} % {local_world_size=} != 0", + ) + + sorted_host_level_devices = _sort_devices_by_perf(_host_level_devices) + host_index = 0 + all_hosts_used = False + while True: + if num_host_to_allocate >= len(sorted_host_level_devices): + # case one: we need to use all hosts + all_hosts_used = True + devices = [] + for i in range(num_host_to_allocate): + devices.extend( + sorted_host_level_devices[i % len(sorted_host_level_devices)] + ) + else: + # case two: we can use some hosts + devices = list( + itertools.chain( + *sorted_host_level_devices[ + host_index : host_index + num_host_to_allocate + ] + ) + ) + host_index += 1 # shift to next host + host_devices = copy.deepcopy(devices) + success = True + sharding_option = sharding_option_group.sharding_options[0] + try: + if sharding_option.sharding_type == ShardingType.GRID_SHARD.value: + cls._uniform_partition([sharding_option], host_devices) + else: + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message=f"unexpected multi_host sharding type: {sharding_option.sharding_type}", + ) + except PlannerError: + success = False + if success: + # successfully found some hosts and partitioned on these hosts + # need to update the devices + for device, host_device in zip(devices, host_devices): + # check that devices and host_devices are in the same order + if device.rank != host_device.rank: + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message=f"device rank {device.rank} is not the same as device_copy rank {host_device.rank}", + ) + device.storage = host_device.storage + device.perf = host_device.perf + return + + if ( + host_index + num_host_to_allocate > len(sorted_host_level_devices) + ) or all_hosts_used: + break + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message=f"can't find hosts for sharding option group {sharding_option_group}", + ) + @classmethod def _cohost_partition( cls, @@ -358,8 +497,9 @@ def _cohost_partition( ) cls._device_partition(sharding_option, minheap_devices) else: - raise RuntimeError( - f"unexpected cohost sharding type: {sharding_option.sharding_type}" + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message=f"unexpected cohost sharding type: {sharding_option.sharding_type}", ) except PlannerError: success = False @@ -395,8 +535,9 @@ def _uniform_partition( ) -> None: for sharding_option in sharding_options: if sharding_option.num_shards != len(devices): - raise RuntimeError( - f"For a uniform partition, the number of shards ({sharding_option.num_shards}) must equal the number of devices ({len(devices)})" + raise PlannerError( + error_type=PlannerErrorType.PARTITION, + message=f"For a uniform partition, the number of shards ({sharding_option.num_shards}) must equal the number of devices ({len(devices)})", ) for i in range(len(devices)): storage_needed = cast(Storage, sharding_option.shards[i].storage) diff --git a/torchrec/distributed/planner/shard_estimators.py b/torchrec/distributed/planner/shard_estimators.py index 100a9c66e..84db2ff29 100644 --- a/torchrec/distributed/planner/shard_estimators.py +++ b/torchrec/distributed/planner/shard_estimators.py @@ -386,7 +386,10 @@ def perf_func_emb_wall_time( expected_cache_fetches=expected_cache_fetches, is_inference=is_inference, ) - elif sharding_type == ShardingType.TABLE_ROW_WISE.value: + elif ( + sharding_type == ShardingType.TABLE_ROW_WISE.value + or sharding_type == ShardingType.GRID_SHARD.value + ): shard_perf = cls._get_twrw_sharding_perf( batch_sizes=batch_sizes, world_size=world_size, @@ 
-729,7 +732,7 @@ def _get_twrw_sharding_perf( if is_weighted: bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier - # for table-wise-row-wise, expected_cache_fetches per shard is / local_world_size + # for table-wise-row-wise or grid_shard, expected_cache_fetches per shard is / local_world_size prefetch_compute = cls._get_expected_cache_prefetch_time( ddr_mem_bw, expected_cache_fetches / local_world_size, @@ -984,7 +987,6 @@ def estimate( if hasattr(sharder, "fused_params") and sharder.fused_params else None ) - shard_storages = calculate_shard_storages( sharder=sharder, sharding_type=sharding_option.sharding_type, @@ -1006,7 +1008,6 @@ def estimate( is_inference=self._is_inference, multipass_prefetch_max_pass=mpp_conf.num_passes if mpp_conf else None, ) - for shard, storage in zip(sharding_option.shards, shard_storages): shard.storage = storage @@ -1256,7 +1257,10 @@ def _calculate_shard_io_sizes( num_poolings=num_poolings, is_pooled=is_pooled, ) - elif sharding_type == ShardingType.TABLE_ROW_WISE.value: + elif ( + sharding_type == ShardingType.TABLE_ROW_WISE.value + or sharding_type == ShardingType.GRID_SHARD.value # same as table row wise + ): return _calculate_twrw_shard_io_sizes( batch_sizes=batch_sizes, world_size=world_size, diff --git a/torchrec/distributed/planner/stats.py b/torchrec/distributed/planner/stats.py index 302d4a235..9455b7549 100644 --- a/torchrec/distributed/planner/stats.py +++ b/torchrec/distributed/planner/stats.py @@ -413,6 +413,7 @@ def log( f"{so.tensor.shape[1]} ({so.shards[0].size[1]})" if so.sharding_type == ShardingType.COLUMN_WISE.value or so.sharding_type == ShardingType.TABLE_COLUMN_WISE.value + or so.sharding_type == ShardingType.GRID_SHARD.value else f"{so.tensor.shape[1]}" ) sharder_cache_load_factor = ( @@ -875,6 +876,8 @@ def _get_sharding_type_abbr(sharding_type: str) -> str: return "TWRW" elif sharding_type == ShardingType.TABLE_COLUMN_WISE.value: return "TWCW" + elif sharding_type == ShardingType.GRID_SHARD.value: + return "GS" else: raise ValueError( f"Unrecognized or unsupported sharding type provided: {sharding_type}" diff --git a/torchrec/distributed/planner/tests/test_proposers.py b/torchrec/distributed/planner/tests/test_proposers.py index 9d17b1ba9..e9ca17905 100644 --- a/torchrec/distributed/planner/tests/test_proposers.py +++ b/torchrec/distributed/planner/tests/test_proposers.py @@ -169,16 +169,16 @@ def test_greedy_two_table(self) -> None: ("table_1", "row_wise", "fused"), ], [ + ("table_0", "grid_shard", "fused"), ("table_1", "row_wise", "fused"), - ("table_0", "data_parallel", "dense"), ], [ - ("table_1", "table_row_wise", "fused"), + ("table_1", "row_wise", "fused"), ("table_0", "data_parallel", "dense"), ], [ + ("table_1", "table_row_wise", "fused"), ("table_0", "data_parallel", "dense"), - ("table_1", "data_parallel", "dense"), ], ] @@ -349,10 +349,9 @@ def test_grid_search_three_table(self) -> None: - fused_uvm DP will have 1 possible compute kernel: dense So the total number of pruned options will be: - (num_sharding_types - 1) * 3 + 1 = 16 + (num_sharding_types - 1) * 3 + 1 = 19 """ - # NOTE - remove -2 from sharding type length once grid sharding in planner is added - num_pruned_options = (len(ShardingType) - 2) * 3 + 1 + num_pruned_options = (len(ShardingType) - 1) * 3 + 1 self.grid_search_proposer.load(search_space) for ( sharding_options diff --git a/torchrec/distributed/planner/tests/test_shard_estimators.py b/torchrec/distributed/planner/tests/test_shard_estimators.py index 335f08c78..ce52702cb 100644 --- 
a/torchrec/distributed/planner/tests/test_shard_estimators.py +++ b/torchrec/distributed/planner/tests/test_shard_estimators.py @@ -248,6 +248,49 @@ def test_1_table_perf(self) -> None: bwd_comms=0.004316567291897281, ), ], + # grid_shard is the same as table_row_wise + ("fused", "grid_shard"): [ + Perf( + fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525, + ), + Perf( + fwd_compute=6.804365245261984e-05, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.0001360873049052397, + bwd_comms=0.00016798276699240525, + ), + ], + ("fused_uvm", "grid_shard"): [ + Perf( + fwd_compute=0.011967677696078432, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.023935355392156864, + bwd_comms=0.018426483752680762, + ), + Perf( + fwd_compute=0.011967677696078432, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.023935355392156864, + bwd_comms=0.018426483752680762, + ), + ], + ("fused_uvm_caching", "grid_shard"): [ + Perf( + fwd_compute=0.0027718054609445954, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.005543610921889191, + bwd_comms=0.004316567291897281, + ), + Perf( + fwd_compute=0.0027718054609445954, + fwd_comms=6.357828776041667e-05, + bwd_compute=0.005543610921889191, + bwd_comms=0.004316567291897281, + ), + ], } perfs = { @@ -313,6 +356,13 @@ def test_1_table_perf_with_fp8_comm(self) -> None: 0.007906921553027076, 0.007906921553027076, ], + # grid_shard is the same as table_row_wise + ("fused", "grid_shard"): [0.0002561205605599394, 0.0002561205605599394], + ("fused_uvm", "grid_shard"): [0.03392836626838235, 0.03392836626838235], + ("fused_uvm_caching", "grid_shard"): [ + 0.007906921553027076, + 0.007906921553027076, + ], } total_perfs = { @@ -506,12 +556,17 @@ def cacheability(self) -> float: 0.007304490781297871, 0.007304490781297871, ], + ("table_0", "fused_uvm_caching", "grid_shard"): [ + 0.007304490781297871, + 0.007304490781297871, + ], ("table_0", "fused_uvm_caching", "table_wise"): [0.014608981562595743], ("table_1", "fused", "column_wise"): [0.0], ("table_1", "fused", "row_wise"): [0.0, 0.0], ("table_1", "fused", "table_column_wise"): [0.0], ("table_1", "fused", "table_row_wise"): [0.0, 0.0], ("table_1", "fused", "table_wise"): [0.0], + ("table_1", "fused", "grid_shard"): [0.0, 0.0], } prefetch_computes = { diff --git a/torchrec/distributed/planner/types.py b/torchrec/distributed/planner/types.py index a31efa23e..44d7e092c 100644 --- a/torchrec/distributed/planner/types.py +++ b/torchrec/distributed/planner/types.py @@ -537,6 +537,8 @@ class PartitionByType(Enum): HOST = "host" # Uniform, (ie. 
fixed layout) UNIFORM = "uniform" + # Partitioning based on multiple hosts + MULTI_HOST = "multi_host" @dataclass diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py index 6eb3468e8..787523805 100644 --- a/torchrec/distributed/sharding_plan.py +++ b/torchrec/distributed/sharding_plan.py @@ -124,12 +124,40 @@ def calculate_shard_sizes_and_offsets( or sharding_type == ShardingType.TABLE_COLUMN_WISE.value ): return _calculate_cw_shard_sizes_and_offsets(columns, rows, col_wise_shard_dim) + elif sharding_type == ShardingType.GRID_SHARD.value: + return _calculate_grid_shard_sizes_and_offsets( + rows, local_world_size, columns, col_wise_shard_dim + ) raise ValueError( f"Unrecognized or unsupported sharding type provided: {sharding_type}" ) +def _calculate_grid_shard_sizes_and_offsets( + hash_size: int, + num_device: int, + columns: int, + col_wise_shard_dim: Optional[int] = None, +) -> Tuple[List[List[int]], List[List[int]]]: + """ + Similar to the row-wise case, but also splits the columns into blocks of size `col_wise_shard_dim`. + """ + row_shard_sizes, row_shard_offsets = _calculate_rw_shard_sizes_and_offsets( + hash_size, num_device, columns + ) + block_size = _get_block_size_for_cw_shard(columns, col_wise_shard_dim) + num_col_wise_nodes, _residual = divmod(columns, block_size) + shard_sizes: List[List[int]] = [] + shard_offsets: List[List[int]] = [] + + for node in range(num_col_wise_nodes): + for row_shard_size, row_shard_offset in zip(row_shard_sizes, row_shard_offsets): + shard_sizes.append([row_shard_size[0], block_size]) + shard_offsets.append([row_shard_offset[0], block_size * node]) + return shard_sizes, shard_offsets + + def _calculate_rw_shard_sizes_and_offsets( hash_size: int, num_devices: int, columns: int ) -> Tuple[List[List[int]], List[List[int]]]: @@ -201,6 +229,28 @@ def _find_base_dim(lower_bound: int, dim: int) -> int: return dim +def _get_block_size_for_cw_shard( + columns: int, column_wise_shard_dim: Optional[int] +) -> int: + block_size: int = min( + ( + _find_base_dim(column_wise_shard_dim, columns) + if column_wise_shard_dim + else _find_base_dim(MIN_CW_DIM, columns) + ), + columns, + ) + + if columns % block_size != 0: + warnings.warn( + f"Dim of {columns} cannot be evenly divided with column wise shard " + f"dim {column_wise_shard_dim}, overriding block_size to embedding_dim={columns}", + UserWarning, + ) + block_size = columns + return block_size + + def _calculate_cw_shard_sizes_and_offsets( columns: int, rows: int, @@ -654,6 +704,58 @@ def _parameter_sharding_generator( return _parameter_sharding_generator +def grid_shard( + host_indexes: List[int], +) -> ParameterShardingGenerator: + """ + Returns a generator of ParameterShardingPlan for `ShardingType::GRID_SHARD` for construct_module_sharding_plan. + + Args: + host_indexes (List[int]): indexes of the hosts (nodes) across which to place the row-wise shards of each column block + + Example:: + + ebc = EmbeddingBagCollection(...) 
+ plan = construct_module_sharding_plan( + ebc, + { + "table_4": grid_shard(host_indexes=[1,2]), + }, + ) + """ + + def _parameter_sharding_generator( + param: nn.Parameter, + local_size: int, + world_size: int, + device_type: str, + sharder: ModuleSharder[nn.Module], + ) -> ParameterSharding: + size_and_offsets = _get_parameter_size_offsets( + param, + ShardingType.GRID_SHARD, + local_size, + world_size, + ) + size_offset_ranks = [] + for host_count, host_index in enumerate(host_indexes): + for rank in range(local_size): + (size, offset) = size_and_offsets[host_count * local_size + rank] + rank_offset = host_index * local_size + size_offset_ranks.append((size, offset, rank_offset + rank)) + + return _get_parameter_sharding( + param, + ShardingType.GRID_SHARD.value, + size_offset_ranks, + local_size, + device_type, + sharder, + ) + + return _parameter_sharding_generator + + def apply_to_all( module: nn.Module, parameter_sharding_generator: ParameterShardingGenerator, diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py index b51453391..d5ba9e774 100644 --- a/torchrec/distributed/tests/test_sharding_plan.py +++ b/torchrec/distributed/tests/test_sharding_plan.py @@ -24,6 +24,7 @@ FeatureProcessedEmbeddingBagCollectionSharder, FusedEmbeddingBagCollectionSharder, get_module_to_default_sharders, + grid_shard, ManagedCollisionEmbeddingBagCollectionSharder, ManagedCollisionEmbeddingCollectionSharder, ParameterShardingGenerator, @@ -240,7 +241,7 @@ def test_construct_module_sharding_plan(self) -> None: embedding_dim=256, num_embeddings=32 * 32, ) - for idx in range(5) + for idx in range(6) ] expected = { @@ -567,6 +568,95 @@ def test_construct_module_sharding_plan(self) -> None: ] ), ), + "table_5": ParameterSharding( + sharding_type="grid_shard", + compute_kernel="dense", + ranks=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], + sharding_spec=EnumerableShardingSpec( + shards=[ + ShardMetadata( + shard_offsets=[0, 0], + shard_sizes=[128, 128], + placement="rank:0/cuda:0", + ), + ShardMetadata( + shard_offsets=[128, 0], + shard_sizes=[128, 128], + placement="rank:1/cuda:1", + ), + ShardMetadata( + shard_offsets=[256, 0], + shard_sizes=[128, 128], + placement="rank:2/cuda:2", + ), + ShardMetadata( + shard_offsets=[384, 0], + shard_sizes=[128, 128], + placement="rank:3/cuda:3", + ), + ShardMetadata( + shard_offsets=[512, 0], + shard_sizes=[128, 128], + placement="rank:4/cuda:4", + ), + ShardMetadata( + shard_offsets=[640, 0], + shard_sizes=[128, 128], + placement="rank:5/cuda:5", + ), + ShardMetadata( + shard_offsets=[768, 0], + shard_sizes=[128, 128], + placement="rank:6/cuda:6", + ), + ShardMetadata( + shard_offsets=[896, 0], + shard_sizes=[128, 128], + placement="rank:7/cuda:7", + ), + ShardMetadata( + shard_offsets=[0, 128], + shard_sizes=[128, 128], + placement="rank:8/cuda:0", + ), + ShardMetadata( + shard_offsets=[128, 128], + shard_sizes=[128, 128], + placement="rank:9/cuda:1", + ), + ShardMetadata( + shard_offsets=[256, 128], + shard_sizes=[128, 128], + placement="rank:10/cuda:2", + ), + ShardMetadata( + shard_offsets=[384, 128], + shard_sizes=[128, 128], + placement="rank:11/cuda:3", + ), + ShardMetadata( + shard_offsets=[512, 128], + shard_sizes=[128, 128], + placement="rank:12/cuda:4", + ), + ShardMetadata( + shard_offsets=[640, 128], + shard_sizes=[128, 128], + placement="rank:13/cuda:5", + ), + ShardMetadata( + shard_offsets=[768, 128], + shard_sizes=[128, 128], + placement="rank:14/cuda:6", + ), + ShardMetadata( + 
shard_offsets=[896, 128], + shard_sizes=[128, 128], + placement="rank:15/cuda:7", + ), + ] + ), + ), } module_sharding_plan = construct_module_sharding_plan( @@ -577,6 +667,7 @@ def test_construct_module_sharding_plan(self) -> None: "table_2": row_wise(), "table_3": column_wise(ranks=[8, 9]), "table_4": table_row_wise(host_index=3), + "table_5": grid_shard(host_indexes=[0, 1]), }, local_size=8, world_size=32, diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py index ba36a01f7..3173e710d 100644 --- a/torchrec/distributed/train_pipeline/train_pipelines.py +++ b/torchrec/distributed/train_pipeline/train_pipelines.py @@ -67,6 +67,14 @@ logger: logging.Logger = logging.getLogger(__name__) +# This is required to support older torch package export for older models +try: + from torchrec.distributed.comm_ops import torchrec_use_sync_collectives +except ImportError: + logger.warning("torchrec_use_sync_collectives is not available") + +torch.ops.import_module("fbgemm_gpu.sparse_ops") + class ModelDetachedException(Exception): pass @@ -1553,6 +1561,8 @@ def __init__( custom_model_fwd, ) + torch._logging.set_logs(compiled_autograd_verbose=True) + # it will check this path on model to inject configuration other than # the default one. self.compiled_autograd_options: Dict[str, Union[str, bool]] = getattr( @@ -1564,8 +1574,6 @@ def __init__( "fullgraph": True, }, ) - - torch._dynamo.config.optimize_ddp = "python_reducer" torch._dynamo.config.inline_inbuilt_nn_modules = True torch._dynamo.config.skip_fsdp_hooks = False torch._functorch.config.recompute_views = True @@ -1593,16 +1601,49 @@ def get_compiled_autograd_ctx( torch.compile(**self.compiled_autograd_options) ) - @contextmanager - def sync_collectives_ctx(self) -> Iterator[None]: - try: - if is_torchdynamo_compiling(): - torchrec.distributed.comm_ops.set_use_sync_collectives(True) - yield - finally: - torchrec.distributed.comm_ops.set_use_sync_collectives(False) - def progress(self, dataloader_iter: Iterator[In]) -> Out: + if not self._model_attached: + self.attach(self._model) + + self.fill_pipeline(dataloader_iter) + if not self.batches: + raise StopIteration + + # TODO: Remove once Bulk Eval migrated (needed for bwd compat, this class only) + self._set_module_context(self.contexts[0]) + + if self._model.training: + with record_function("## zero_grad ##"): + self._optimizer.zero_grad() + + with record_function("## wait_for_batch ##"): + _wait_for_batch(cast(In, self.batches[0]), self._data_dist_stream) + + if len(self.batches) >= 2: + self.start_sparse_data_dist(self.batches[1], self.contexts[1]) + + # batch i+2 + self.enqueue_batch(dataloader_iter) + + # forward + ctx = self.get_compiled_autograd_ctx() + with ctx, torchrec_use_sync_collectives(), record_function("## forward ##"): + losses, output = self._model_fwd(self.batches[0]) + + if len(self.batches) >= 2: + self.wait_sparse_data_dist(self.contexts[1]) - with self.get_compiled_autograd_ctx(), self.sync_collectives_ctx(): - return super().progress(dataloader_iter) + if self._model.training: + # backward + ctx = self.get_compiled_autograd_ctx() + with ctx, torchrec_use_sync_collectives(), record_function( + "## backward ##" + ): + torch.sum(losses, dim=0).backward() + + # update + with record_function("## optimizer ##"): + self._optimizer.step() + + self.dequeue_batch() + return output diff --git a/torchrec/sparse/tests/test_jagged_tensor.py b/torchrec/sparse/tests/test_jagged_tensor.py index b341cd584..782728a81 100644 
--- a/torchrec/sparse/tests/test_jagged_tensor.py +++ b/torchrec/sparse/tests/test_jagged_tensor.py @@ -17,6 +17,7 @@ from torch.testing import FileCheck from torchrec.fx import symbolic_trace from torchrec.sparse.jagged_tensor import ( + _fbgemm_permute_pooled_embs, _kt_regroup_arguments, _regroup_keyed_tensors, ComputeJTDictToKJT, @@ -2342,6 +2343,7 @@ def test_regroup_multiple_kt(self) -> None: KeyedTensor.regroup, regroup_kts, permute_multi_embedding, + _fbgemm_permute_pooled_embs, ], device_str=["cpu", "cuda", "meta"], ) @@ -2376,6 +2378,7 @@ def test_regroup_kts( KeyedTensor.regroup, regroup_kts, permute_multi_embedding, + _fbgemm_permute_pooled_embs, ], device_str=["cpu", "cuda", "meta"], ) @@ -2446,18 +2449,33 @@ def test_regroup_backward_skips_and_duplicates(self) -> None: torch.allclose(actual_kt_0_grad, expected_kt_0_grad) torch.allclose(actual_kt_1_grad, expected_kt_1_grad) - def test_regroup_backward(self) -> None: + @repeat_test( + regroup_func=[ + KeyedTensor.regroup, + regroup_kts, + permute_multi_embedding, + _fbgemm_permute_pooled_embs, + ], + device_str=["cpu", "cuda"], + ) + def test_regroup_backward( + self, regroup_func: Callable[..., List[torch.Tensor]], device_str: str + ) -> None: + if device_str == "cuda" and not torch.cuda.is_available(): + return + else: + device = torch.device(device_str) kts = build_kts( dense_features=20, sparse_features=20, dim_dense=64, dim_sparse=128, batch_size=128, - device=torch.device("cpu"), + device=device, run_backward=True, ) groups = build_groups(kts=kts, num_groups=2, skips=False, duplicates=False) - labels = torch.randint(0, 1, (128,), device=torch.device("cpu")).float() + labels = torch.randint(0, 1, (128,), device=device).float() tensor_groups = KeyedTensor.regroup(kts, groups) pred0 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1)) @@ -2473,7 +2491,7 @@ def test_regroup_backward(self) -> None: kts[0].values().grad = None kts[1].values().grad = None - tensor_groups = _regroup_keyed_tensors(kts, groups) + tensor_groups = regroup_func(kts, groups) pred1 = tensor_groups[0].sum(dim=1).mul(tensor_groups[1].sum(dim=1)) loss = torch.nn.functional.l1_loss(pred1, labels).sum() expected_kt_0_grad = torch.autograd.grad(
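
Reader aid for the grid sharding changes in this patch: the shard layout produced by the new _calculate_grid_shard_sizes_and_offsets helper, and expected for "table_5" in test_sharding_plan.py (a 1024 x 256 table placed on 2 hosts with 8 ranks each), can be reproduced with the standalone sketch below. This is a simplification, not the torchrec implementation: block_size is taken directly as columns // num_hosts here, whereas the patch derives it via _get_block_size_for_cw_shard (MIN_CW_DIM or a user-provided col_wise_shard_dim), and only the evenly divisible case is handled. The function name grid_shard_sizes_and_offsets is a hypothetical placeholder for illustration.

    # Sketch of the grid-shard layout: rows are split row-wise across the ranks
    # of one host (local_world_size), and columns are split into column-wise
    # blocks, one block per host. Assumes rows % local_world_size == 0 and
    # columns % num_hosts == 0 (the even case exercised by table_5).
    from typing import List, Tuple


    def grid_shard_sizes_and_offsets(
        rows: int, columns: int, local_world_size: int, num_hosts: int
    ) -> Tuple[List[List[int]], List[List[int]]]:
        row_size = rows // local_world_size
        block_size = columns // num_hosts  # in the patch: _get_block_size_for_cw_shard
        sizes, offsets = [], []
        for host in range(num_hosts):  # one column block per host
            for rank in range(local_world_size):  # row-wise split within a host
                sizes.append([row_size, block_size])
                offsets.append([rank * row_size, host * block_size])
        return sizes, offsets


    sizes, offsets = grid_shard_sizes_and_offsets(
        rows=32 * 32, columns=256, local_world_size=8, num_hosts=2
    )
    assert sizes[0] == [128, 128] and offsets[0] == [0, 0]
    assert offsets[8] == [0, 128]  # shard 8 starts the second column block
    print(len(sizes), "shards")  # 16, matching the table_5 expectation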