2024-10-11 nightly release (b6e784e)
pytorchbot committed Oct 11, 2024
1 parent aa7c886 commit efcd28d
Showing 13 changed files with 548 additions and 44 deletions.
65 changes: 55 additions & 10 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -211,6 +211,42 @@ class ShardParams:
local_metadata: List[ShardMetadata]
embedding_weights: List[torch.Tensor]

def get_optimizer_single_value_shard_metadata_and_global_metadata(
table_global_metadata: ShardedTensorMetadata,
optimizer_state: torch.Tensor,
) -> Tuple[Dict[ShardMetadata, ShardMetadata], ShardedTensorMetadata]:
table_global_shards_metadata: List[ShardMetadata] = (
table_global_metadata.shards_metadata
)

table_shard_metadata_to_optimizer_shard_metadata = {}
for offset, table_shard_metadata in enumerate(table_global_shards_metadata):
table_shard_metadata_to_optimizer_shard_metadata[
table_shard_metadata
] = ShardMetadata(
shard_sizes=[1], # single value optimizer state
shard_offsets=[offset], # offset increases by 1 for each shard
placement=table_shard_metadata.placement,
)

tensor_properties = TensorProperties(
dtype=optimizer_state.dtype,
layout=optimizer_state.layout,
requires_grad=False,
)
single_value_optimizer_st_metadata = ShardedTensorMetadata(
shards_metadata=list(
table_shard_metadata_to_optimizer_shard_metadata.values()
),
size=torch.Size([len(table_global_shards_metadata)]),
tensor_properties=tensor_properties,
)

return (
table_shard_metadata_to_optimizer_shard_metadata,
single_value_optimizer_st_metadata,
)
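
As context for the new helper above, here is a minimal standalone sketch of the metadata it builds, using the public torch.distributed._shard classes with hypothetical shard sizes and placements (the two-shard table below is illustrative, not taken from this commit):

import torch
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensorMetadata, TensorProperties

# A hypothetical table split row-wise into two shards of 1000 rows x 64 columns.
table_shards = [
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1000, 64], placement="rank:0/cuda:0"),
    ShardMetadata(shard_offsets=[1000, 0], shard_sizes=[1000, 64], placement="rank:1/cuda:1"),
]

# Each table shard maps to a length-1 optimizer shard whose offset is the shard's index.
opt_shards = [
    ShardMetadata(shard_offsets=[i], shard_sizes=[1], placement=shard.placement)
    for i, shard in enumerate(table_shards)
]

# Global view: a 1-D sharded tensor with one scalar optimizer value per table shard.
opt_global_metadata = ShardedTensorMetadata(
    shards_metadata=opt_shards,
    size=torch.Size([len(table_shards)]),
    tensor_properties=TensorProperties(
        dtype=torch.float32, layout=torch.strided, requires_grad=False
    ),
)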

def get_optimizer_rowwise_shard_metadata_and_global_metadata(
table_global_metadata: ShardedTensorMetadata,
optimizer_state: torch.Tensor,
@@ -356,7 +392,10 @@ def get_optimizer_pointwise_shard_metadata_and_global_metadata(
if optimizer_states:
optimizer_state_values = tuple(optimizer_states.values())
for optimizer_state_value in optimizer_state_values:
- assert table_config.local_rows == optimizer_state_value.size(0)
assert (
table_config.local_rows == optimizer_state_value.size(0)
or optimizer_state_value.nelement() == 1 # single value state
)
optimizer_states_keys_by_table[table_config.name] = list(
optimizer_states.keys()
)
@@ -435,29 +474,35 @@ def get_sharded_optim_state(momentum_idx: int) -> ShardedTensor:
momentum_local_shards: List[Shard] = []
optimizer_sharded_tensor_metadata: ShardedTensorMetadata

- is_rowwise_optimizer_state: bool = (
-     # pyre-ignore
-     shard_params.optimizer_states[0][momentum_idx - 1].dim()
-     == 1
- )
-
- if is_rowwise_optimizer_state:
optim_state = shard_params.optimizer_states[0][momentum_idx - 1] # pyre-ignore[16]
if optim_state.nelement() == 1:
# single value state: one value per table
(
table_shard_metadata_to_optimizer_shard_metadata,
optimizer_sharded_tensor_metadata,
) = get_optimizer_single_value_shard_metadata_and_global_metadata(
table_config.global_metadata,
optim_state,
)
elif optim_state.dim() == 1:
# rowwise state: param.shape[0] == state.shape[0], state.shape[1] == 1
(
table_shard_metadata_to_optimizer_shard_metadata,
optimizer_sharded_tensor_metadata,
) = get_optimizer_rowwise_shard_metadata_and_global_metadata(
table_config.global_metadata,
- shard_params.optimizer_states[0][momentum_idx - 1],
optim_state,
sharding_dim,
is_grid_sharded,
)
else:
# pointwise state: param.shape == state.shape
(
table_shard_metadata_to_optimizer_shard_metadata,
optimizer_sharded_tensor_metadata,
) = get_optimizer_pointwise_shard_metadata_and_global_metadata(
table_config.global_metadata,
- shard_params.optimizer_states[0][momentum_idx - 1],
optim_state,
)

for optimizer_state, table_shard_local_metadata in zip(
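
The new branch above dispatches purely on the shape of the optimizer state tensor. A self-contained sketch of the same rule, with an illustrative helper name rather than torchrec API:

import torch

def classify_optimizer_state(state: torch.Tensor) -> str:
    # Mirrors the dispatch above: scalar state -> single value, 1-D -> rowwise, otherwise pointwise.
    if state.nelement() == 1:
        return "single_value"  # one value for the whole table (e.g. a shared step count)
    if state.dim() == 1:
        return "rowwise"       # one value per embedding row: state.shape == (param.shape[0],)
    return "pointwise"         # same shape as the parameter: state.shape == param.shape

param = torch.empty(1000, 64)  # a 1000-row embedding table
assert classify_optimizer_state(torch.tensor(3.0)) == "single_value"
assert classify_optimizer_state(torch.empty(1000)) == "rowwise"
assert classify_optimizer_state(torch.empty_like(param)) == "pointwise"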
1 change: 1 addition & 0 deletions torchrec/distributed/embedding_types.py
@@ -419,6 +419,7 @@ def sharding_types(self, compute_device_type: str) -> List[str]:
types += [
ShardingType.ROW_WISE.value,
ShardingType.TABLE_ROW_WISE.value,
ShardingType.GRID_SHARD.value,
]

return types
4 changes: 3 additions & 1 deletion torchrec/distributed/planner/enumerators.py
@@ -245,7 +245,6 @@ def _filter_sharding_types(
filtered_sharding_types = list(
set(constrained_sharding_types) & set(allowed_sharding_types)
)

if not filtered_sharding_types:
logger.warn(
"No available sharding types after applying user provided "
@@ -380,13 +379,16 @@ def get_partition_by_type(sharding_type: str) -> str:
ShardingType.ROW_WISE.value,
ShardingType.DATA_PARALLEL.value,
}
multi_host_sharding_types = {ShardingType.GRID_SHARD.value}

if sharding_type in device_sharding_types:
return PartitionByType.DEVICE.value
elif sharding_type in host_sharding_types:
return PartitionByType.HOST.value
elif sharding_type in uniform_sharding_types:
return PartitionByType.UNIFORM.value
elif sharding_type in multi_host_sharding_types:
return PartitionByType.MULTI_HOST.value

raise ValueError(
f"Unrecognized or unsupported sharding type provided: {sharding_type}"
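
Assuming the usual torchrec import paths (they are not shown in this diff), the new mapping can be exercised as below; grid sharding is the only type that resolves to the multi-host partitioner:

from torchrec.distributed.planner.enumerators import get_partition_by_type
from torchrec.distributed.planner.types import PartitionByType
from torchrec.distributed.types import ShardingType

# Grid sharding spans several hosts, so it maps to the new MULTI_HOST partition type,
# while table-row-wise sharding still maps to a single HOST.
assert get_partition_by_type(ShardingType.GRID_SHARD.value) == PartitionByType.MULTI_HOST.value
assert get_partition_by_type(ShardingType.TABLE_ROW_WISE.value) == PartitionByType.HOST.value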
149 changes: 145 additions & 4 deletions torchrec/distributed/planner/partitioners.py
@@ -9,6 +9,7 @@

import copy
import heapq
import itertools
import logging
from dataclasses import dataclass
from enum import Enum
@@ -242,6 +243,14 @@ def partition(

for sharding_option_group in sharding_option_groups:
if (
sharding_option_group.sharding_options[0].partition_by
== PartitionByType.MULTI_HOST.value
):
self._multi_hosts_partition(sharding_option_group, _host_level_devices)
# _multi_hosts_partition invalidates minheap_devices, force rebuild before using
minheap_devices = None

elif (
sharding_option_group.sharding_options[0].partition_by
== PartitionByType.HOST.value
):
@@ -321,6 +330,136 @@ def _device_partition(
minheap_devices.extend(tmp_heap)
heapq.heapify(minheap_devices)

@classmethod
def _multi_hosts_partition(
cls,
sharding_option_group: ShardingOptionGroup,
_host_level_devices: List[List[DeviceHardware]],
) -> None:
"""
Partition shards across multiple hosts. This is a greedy algorithm that tries to complete the partitioning across multiple hosts (sorted by perf).
We first shard column-wise across hosts, then table-wise-row-wise within each host.
There are two cases, depending on the number of hosts needed to place the shards.
Case one: `num_host_to_allocate >= len(sorted_host_level_devices)`
We try to partition only once; hosts may be selected multiple times in a circular manner.
E.g., with 3 hosts and `num_host_to_allocate` = 4, we sort all devices at the host level and select the devices of hosts [0, 1, 2, 0] for uniform partitioning.
On success we update the device information, otherwise we raise a `PlannerError`.
Case two: `num_host_to_allocate < len(sorted_host_level_devices)`
We try to partition over hosts `[host_index, host_index + num_host_to_allocate]` iteratively, incrementing `host_index` by 1 each time:
1) Sort all devices at the host level and set `host_index` = 0.
2) Select hosts `[host_index, host_index + num_host_to_allocate]` if the indexes are within range.
3) Uniformly partition over all devices of the selected hosts. If partitioning fails, increase `host_index` by 1 and go to 2); otherwise go to 4).
4) On success update the device information, otherwise raise a `PlannerError`.
Keyword arguments:
sharding_option_group -- grouped sharding options
_host_level_devices -- devices grouped by host
Example::
sharding_option_group.sharding_options = [
ShardingOption(partition_by="multi_host",
shards=[
Shards(storage=1, perf=1),
Shards(storage=1, perf=1),
Shards(storage=1, perf=1),
Shards(storage=1, perf=1),
]),
]
topology = Topology(world_size=6, local_world_size=2)
# sharding_options[0] will be placed on host 1 and host 2 with the multi_hosts strategy, resulting in
topology.devices[0].perf.total = (1,1)
topology.devices[1].perf.total = (1,1)
topology.devices[2].perf.total = (1,1)
topology.devices[3].perf.total = (1,1)
topology.devices[4].perf.total = (0,0)
topology.devices[5].perf.total = (0,0)
"""
# TODO: for now assume just one option for multi_hosts.
if len(sharding_option_group.sharding_options) != 1:
raise PlannerError(
error_type=PlannerErrorType.PARTITION,
message=f"Unexpected length for sharding options: {len(sharding_option_group.sharding_options)}. Length needs to be 1",
)
num_shards = sharding_option_group.sharding_options[0].num_shards

if _host_level_devices is None:
raise PlannerError(
error_type=PlannerErrorType.PARTITION,
message="host level devices is None",
)

local_world_size = len(_host_level_devices[0])
num_host_to_allocate, remainder = divmod(num_shards, local_world_size)

if remainder > 0:
raise PlannerError(
error_type=PlannerErrorType.PARTITION,
message=f"Grid Sharding is unable to place shards equally over hosts without overlapping. {num_shards=} % {local_world_size=} != 0",
)

sorted_host_level_devices = _sort_devices_by_perf(_host_level_devices)
host_index = 0
all_hosts_used = False
while True:
if num_host_to_allocate >= len(sorted_host_level_devices):
# case one: we need to use all hosts
all_hosts_used = True
devices = []
for i in range(num_host_to_allocate):
devices.extend(
sorted_host_level_devices[i % len(sorted_host_level_devices)]
)
else:
# case two: we can use some hosts
devices = list(
itertools.chain(
*sorted_host_level_devices[
host_index : host_index + num_host_to_allocate
]
)
)
host_index += 1 # shift to next host
host_devices = copy.deepcopy(devices)
success = True
sharding_option = sharding_option_group.sharding_options[0]
try:
if sharding_option.sharding_type == ShardingType.GRID_SHARD.value:
cls._uniform_partition([sharding_option], host_devices)
else:
raise PlannerError(
error_type=PlannerErrorType.PARTITION,
message=f"unexpected multi_host sharding type: {sharding_option.sharding_type}",
)
except PlannerError:
success = False
if success:
# successfully found some hosts and partitioned on these hosts
# need to update the devices
for device, host_device in zip(devices, host_devices):
# check that devices and host_devices are in the same order
if device.rank != host_device.rank:
raise PlannerError(
error_type=PlannerErrorType.PARTITION,
message=f"device rank {device.rank} is not the same as device_copy rank {host_device.rank}",
)
device.storage = host_device.storage
device.perf = host_device.perf
return

if (
host_index + num_host_to_allocate > len(sorted_host_level_devices)
) or all_hosts_used:
break
raise PlannerError(
error_type=PlannerErrorType.PARTITION,
message=f"can't find hosts for sharding option group {sharding_option_group}",
)
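
For intuition, a stripped-down sketch of the host-selection order the docstring above describes, using plain integers in place of DeviceHardware and a hypothetical helper name:

import itertools
from typing import List

def select_hosts(
    sorted_hosts: List[List[int]], num_hosts_needed: int, host_index: int
) -> List[int]:
    # Case one: more hosts needed than available -> wrap around, so hosts may repeat.
    if num_hosts_needed >= len(sorted_hosts):
        return [
            device
            for i in range(num_hosts_needed)
            for device in sorted_hosts[i % len(sorted_hosts)]
        ]
    # Case two: take a sliding window of num_hosts_needed consecutive hosts.
    return list(itertools.chain(*sorted_hosts[host_index : host_index + num_hosts_needed]))

hosts = [[0, 1], [2, 3], [4, 5]]      # 3 hosts, local_world_size = 2
print(select_hosts(hosts, 4, 0))      # case one  -> [0, 1, 2, 3, 4, 5, 0, 1]
print(select_hosts(hosts, 2, 1))      # case two  -> [2, 3, 4, 5]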

@classmethod
def _cohost_partition(
cls,
@@ -358,8 +497,9 @@ def _cohost_partition(
)
cls._device_partition(sharding_option, minheap_devices)
else:
- raise RuntimeError(
-     f"unexpected cohost sharding type: {sharding_option.sharding_type}"
raise PlannerError(
error_type=PlannerErrorType.PARTITION,
message=f"unexpected cohost sharding type: {sharding_option.sharding_type}",
)
except PlannerError:
success = False
@@ -395,8 +535,9 @@ def _uniform_partition(
) -> None:
for sharding_option in sharding_options:
if sharding_option.num_shards != len(devices):
- raise RuntimeError(
-     f"For a uniform partition, the number of shards ({sharding_option.num_shards}) must equal the number of devices ({len(devices)})"
raise PlannerError(
error_type=PlannerErrorType.PARTITION,
message=f"For a uniform partition, the number of shards ({sharding_option.num_shards}) must equal the number of devices ({len(devices)})",
)
for i in range(len(devices)):
storage_needed = cast(Storage, sharding_option.shards[i].storage)
14 changes: 9 additions & 5 deletions torchrec/distributed/planner/shard_estimators.py
@@ -386,7 +386,10 @@ def perf_func_emb_wall_time(
expected_cache_fetches=expected_cache_fetches,
is_inference=is_inference,
)
- elif sharding_type == ShardingType.TABLE_ROW_WISE.value:
elif (
sharding_type == ShardingType.TABLE_ROW_WISE.value
or sharding_type == ShardingType.GRID_SHARD.value
):
shard_perf = cls._get_twrw_sharding_perf(
batch_sizes=batch_sizes,
world_size=world_size,
@@ -729,7 +732,7 @@ def _get_twrw_sharding_perf(
if is_weighted:
bwd_compute = bwd_compute * weighted_feature_bwd_compute_multiplier

- # for table-wise-row-wise, expected_cache_fetches per shard is / local_world_size
# for table-wise-row-wise or grid_shard, expected_cache_fetches per shard is / local_world_size
prefetch_compute = cls._get_expected_cache_prefetch_time(
ddr_mem_bw,
expected_cache_fetches / local_world_size,
@@ -984,7 +987,6 @@ def estimate(
if hasattr(sharder, "fused_params") and sharder.fused_params
else None
)

shard_storages = calculate_shard_storages(
sharder=sharder,
sharding_type=sharding_option.sharding_type,
@@ -1006,7 +1008,6 @@
is_inference=self._is_inference,
multipass_prefetch_max_pass=mpp_conf.num_passes if mpp_conf else None,
)

for shard, storage in zip(sharding_option.shards, shard_storages):
shard.storage = storage

@@ -1256,7 +1257,10 @@ def _calculate_shard_io_sizes(
num_poolings=num_poolings,
is_pooled=is_pooled,
)
- elif sharding_type == ShardingType.TABLE_ROW_WISE.value:
elif (
sharding_type == ShardingType.TABLE_ROW_WISE.value
or sharding_type == ShardingType.GRID_SHARD.value # same as table row wise
):
return _calculate_twrw_shard_io_sizes(
batch_sizes=batch_sizes,
world_size=world_size,
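
A small illustration of the dispatch added above (the helper below is illustrative, not torchrec API): grid sharding reuses the table-wise-row-wise perf and I/O estimation paths, and expected cache fetches are split across the local world size.

from torchrec.distributed.types import ShardingType

def uses_twrw_estimation_path(sharding_type: str) -> bool:
    # Matches the two elif branches above: GRID_SHARD is estimated like TABLE_ROW_WISE.
    return sharding_type in (
        ShardingType.TABLE_ROW_WISE.value,
        ShardingType.GRID_SHARD.value,
    )

assert uses_twrw_estimation_path(ShardingType.GRID_SHARD.value)

# Per-shard prefetch work with hypothetical numbers: fetches are divided across the host.
expected_cache_fetches, local_world_size = 4096.0, 8
per_shard_fetches = expected_cache_fetches / local_world_size  # -> 512.0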
3 changes: 3 additions & 0 deletions torchrec/distributed/planner/stats.py
@@ -413,6 +413,7 @@ def log(
f"{so.tensor.shape[1]} ({so.shards[0].size[1]})"
if so.sharding_type == ShardingType.COLUMN_WISE.value
or so.sharding_type == ShardingType.TABLE_COLUMN_WISE.value
or so.sharding_type == ShardingType.GRID_SHARD.value
else f"{so.tensor.shape[1]}"
)
sharder_cache_load_factor = (
@@ -875,6 +876,8 @@ def _get_sharding_type_abbr(sharding_type: str) -> str:
return "TWRW"
elif sharding_type == ShardingType.TABLE_COLUMN_WISE.value:
return "TWCW"
elif sharding_type == ShardingType.GRID_SHARD.value:
return "GS"
else:
raise ValueError(
f"Unrecognized or unsupported sharding type provided: {sharding_type}"
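
Assuming _get_sharding_type_abbr is importable from the stats module as the hunk header suggests, the new abbreviation can be checked like this:

from torchrec.distributed.planner.stats import _get_sharding_type_abbr
from torchrec.distributed.types import ShardingType

# Grid-sharded tables are abbreviated "GS" in the planner's stats tables.
print(_get_sharding_type_abbr(ShardingType.GRID_SHARD.value))      # GS
print(_get_sharding_type_abbr(ShardingType.TABLE_ROW_WISE.value))  # TWRW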
(The remaining changed files are not shown.)
