2024-09-21 nightly release (7c301de)
pytorchbot committed Sep 21, 2024
1 parent e239cfd commit a1cff45
Showing 7 changed files with 291 additions and 59 deletions.
170 changes: 122 additions & 48 deletions torchrec/distributed/quant_embedding.py
@@ -58,13 +58,15 @@
EmbeddingConfig,
)
from torchrec.quant.embedding_modules import (
_get_batching_hinted_output,
EmbeddingCollection as QuantEmbeddingCollection,
MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
)
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
from torchrec.streamable import Multistreamable

torch.fx.wrap("len")
torch.fx.wrap("_get_batching_hinted_output")

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
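
A toy illustration of what the torch.fx.wrap registrations above do (a standalone sketch, not torchrec code; add_one and M are made-up names): a wrapped function appears in the traced graph as a single call_function node instead of being traced through, which is presumably why data-dependent helpers such as _get_batching_hinted_output are registered here.

import torch
import torch.fx

@torch.fx.wrap  # same effect as calling torch.fx.wrap("add_one") at module scope
def add_one(x: torch.Tensor) -> torch.Tensor:
    return x + 1

class M(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return add_one(x) * 2

# the printed graph contains one opaque call_function node for add_one
print(torch.fx.symbolic_trace(M()).graph)
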
@@ -146,18 +148,41 @@ def _fx_trec_unwrap_optional_tensor(optional: Optional[torch.Tensor]) -> torch.T


@torch.fx.wrap
def _fx_trec_wrap_length_tolist(length: torch.Tensor) -> List[int]:
return length.long().tolist()


@torch.fx.wrap
def _get_unbucketize_tensor_via_length_alignment(
lengths: torch.Tensor,
bucketize_length: torch.Tensor,
bucketize_permute_tensor: torch.Tensor,
bucket_mapping_tensor: torch.Tensor,
) -> torch.Tensor:
return bucketize_permute_tensor


@torch.fx.wrap
def _fx_trec_get_feature_length(
features: KeyedJaggedTensor, embedding_names: List[str]
) -> torch.Tensor:
torch._assert(
len(embedding_names) == len(features.keys()),
"embedding output and features mismatch",
)
return features.lengths()


def _construct_jagged_tensors_tw(
embeddings: List[torch.Tensor],
embedding_names_per_rank: List[List[str]],
features: KJTList,
need_indices: bool,
) -> Dict[str, JaggedTensor]:
ret: Dict[str, JaggedTensor] = {}
for i in range(len(embeddings)):
for i in range(len(embedding_names_per_rank)):
embeddings_i: torch.Tensor = embeddings[i]
features_i: KeyedJaggedTensor = features[i]
if features_i.lengths().numel() == 0:
# No table on the rank, skip.
continue

lengths = features_i.lengths().view(-1, features_i.stride())
values = features_i.values()
@@ -168,48 +193,48 @@ def _construct_jagged_tensors_tw(
lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
if need_indices:
values_list = torch.split(values, length_per_key)
for i, key in enumerate(features_i.keys()):
for j, key in enumerate(embedding_names_per_rank[i]):
ret[key] = JaggedTensor(
lengths=lengths_tuple[i],
values=embeddings_list[i],
weights=values_list[i],
lengths=lengths_tuple[j],
values=embeddings_list[j],
weights=values_list[j],
)
else:
for i, key in enumerate(features_i.keys()):
for j, key in enumerate(embedding_names_per_rank[i]):
ret[key] = JaggedTensor(
lengths=lengths_tuple[i],
values=embeddings_list[i],
lengths=lengths_tuple[j],
values=embeddings_list[j],
weights=None,
)
return ret
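
A simplified standalone sketch of the table-wise reconstruction above (plain tensors and made-up inputs rather than the torchrec KJTList/JaggedTensor types): iterate per rank over embedding_names_per_rank, skip ranks whose features are empty, and index lengths and embeddings by the per-rank key index j instead of the outer loop index.

import torch
from typing import Dict, List

def construct_tw_sketch(
    embeddings: List[torch.Tensor],             # per-rank embedding rows
    embedding_names_per_rank: List[List[str]],
    lengths_per_rank: List[torch.Tensor],       # per-rank flat lengths, key-major
    stride: int,                                # batch size per key
) -> Dict[str, torch.Tensor]:
    ret: Dict[str, torch.Tensor] = {}
    for i in range(len(embedding_names_per_rank)):
        if lengths_per_rank[i].numel() == 0:    # no table on this rank, skip
            continue
        lengths = lengths_per_rank[i].view(-1, stride)
        length_per_key: List[int] = lengths.sum(dim=1).tolist()
        embeddings_list = torch.split(embeddings[i], length_per_key, dim=0)
        for j, key in enumerate(embedding_names_per_rank[i]):
            ret[key] = embeddings_list[j]
    return ret

out = construct_tw_sketch(
    embeddings=[torch.randn(3, 4), torch.empty(0, 4)],       # rank 1 has no rows
    embedding_names_per_rank=[["t0"], ["t1"]],
    lengths_per_rank=[torch.tensor([1, 2]), torch.empty(0, dtype=torch.long)],
    stride=2,
)
print({k: v.shape for k, v in out.items()})  # {'t0': torch.Size([3, 4])}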


@torch.fx.wrap
def _construct_jagged_tensors_rw(
embeddings: List[torch.Tensor],
features_before_input_dist: KeyedJaggedTensor,
feature_keys: List[str],
feature_length: torch.Tensor,
feature_indices: Optional[torch.Tensor],
need_indices: bool,
unbucketize_tensor: torch.Tensor,
) -> Dict[str, JaggedTensor]:
ret: Dict[str, JaggedTensor] = {}
unbucketized_embs = torch.concat(embeddings, dim=0).index_select(
0, unbucketize_tensor
)
embs_split_per_key = unbucketized_embs.split(
features_before_input_dist.length_per_key(), dim=0
)
stride = features_before_input_dist.stride()
lengths_list = torch.unbind(
features_before_input_dist.lengths().view(-1, stride), dim=0
feature_length_2d = feature_length.view(len(feature_keys), -1)
length_per_key: List[int] = _fx_trec_wrap_length_tolist(
torch.sum(feature_length_2d, dim=1)
)
embs_split_per_key = unbucketized_embs.split(length_per_key, dim=0)
lengths_list = torch.unbind(feature_length_2d, dim=0)
values_list: List[torch.Tensor] = []
if need_indices:
# pyre-ignore
values_list = torch.split(
features_before_input_dist.values(),
features_before_input_dist.length_per_key(),
_fx_trec_unwrap_optional_tensor(feature_indices),
length_per_key,
)
for i, key in enumerate(features_before_input_dist.keys()):
for i, key in enumerate(feature_keys):
ret[key] = JaggedTensor(
values=embs_split_per_key[i],
lengths=lengths_list[i],
@@ -282,10 +307,13 @@ def _construct_jagged_tensors(
sharding_type: str,
embeddings: List[torch.Tensor],
features: KJTList,
embedding_names: List[str],
embedding_names_per_rank: List[List[str]],
features_before_input_dist: KeyedJaggedTensor,
need_indices: bool,
rw_unbucketize_tensor: Optional[torch.Tensor],
rw_bucket_mapping_tensor: Optional[torch.Tensor],
rw_feature_length_after_bucketize: Optional[torch.Tensor],
cw_features_to_permute_indices: Dict[str, torch.Tensor],
key_to_feature_permuted_coordinates: Dict[str, torch.Tensor],
) -> Dict[str, JaggedTensor]:
@@ -303,11 +331,28 @@
raise ValueError("rw_unbucketize_tensor is required for row-wise sharding")

if sharding_type == ShardingType.ROW_WISE.value:
features_before_input_dist_length = _fx_trec_get_feature_length(
features_before_input_dist, embedding_names
)
embeddings = [
_get_batching_hinted_output(
_fx_trec_get_feature_length(features[i], embedding_names_per_rank[i]),
embeddings[i],
)
for i in range(len(embedding_names_per_rank))
]
return _construct_jagged_tensors_rw(
embeddings,
features_before_input_dist,
embedding_names,
features_before_input_dist_length,
features_before_input_dist.values() if need_indices else None,
need_indices,
_fx_trec_unwrap_optional_tensor(rw_unbucketize_tensor),
_get_unbucketize_tensor_via_length_alignment(
features_before_input_dist_length,
rw_feature_length_after_bucketize,
rw_unbucketize_tensor,
rw_bucket_mapping_tensor,
),
)
elif sharding_type == ShardingType.COLUMN_WISE.value:
return _construct_jagged_tensors_cw(
Expand All @@ -319,7 +364,9 @@ def _construct_jagged_tensors(
key_to_feature_permuted_coordinates,
)
else: # sharding_type == ShardingType.TABLE_WISE.value
return _construct_jagged_tensors_tw(embeddings, features, need_indices)
return _construct_jagged_tensors_tw(
embeddings, embedding_names_per_rank, features, need_indices
)


# Wrap the annotation in a separate function with input parameter so that it won't be dropped during symbolic trace.
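
A minimal standalone sketch of the row-wise reconstruction used above (toy tensors and a made-up permutation, not the torchrec API): per-rank embeddings are concatenated and reordered with the unbucketize permutation, per-key lengths are derived by viewing the pre-bucketization lengths as (num_keys, stride) and summing over the batch dimension, and the rows are then split back per feature key.

import torch

feature_keys = ["f0", "f1"]                     # hypothetical feature names
# flat lengths, key-major: [f0_b0, f0_b1, f1_b0, f1_b1]
feature_length = torch.tensor([1, 2, 2, 1])

# per-rank embedding shards and a made-up unbucketize permutation
embeddings = [torch.randn(3, 4), torch.randn(3, 4)]
unbucketize_tensor = torch.tensor([0, 3, 1, 4, 2, 5])

unbucketized_embs = torch.cat(embeddings, dim=0).index_select(0, unbucketize_tensor)

feature_length_2d = feature_length.view(len(feature_keys), -1)
length_per_key = torch.sum(feature_length_2d, dim=1).long().tolist()   # [3, 3]

embs_split_per_key = unbucketized_embs.split(length_per_key, dim=0)
lengths_list = torch.unbind(feature_length_2d, dim=0)
print([t.shape for t in embs_split_per_key])    # [torch.Size([3, 4]), torch.Size([3, 4])]
print([t.tolist() for t in lengths_list])       # [[1, 2], [2, 1]]
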
@@ -639,6 +686,8 @@ def input_dist(
input_dist_result_list,
features_by_sharding,
unbucketize_permute_tensor_list,
bucket_mapping_tensor_list,
bucketized_length_list,
) = self._input_dist(features)

with torch.no_grad():
Expand All @@ -649,6 +698,8 @@ def input_dist(
features=input_dist_result_list[i],
features_before_input_dist=features_by_sharding[i],
unbucketize_permute_tensor=unbucketize_permute_tensor_list[i],
bucket_mapping_tensor=bucket_mapping_tensor_list[i],
bucketized_length=bucketized_length_list[i],
)
)
return input_dist_result_list
@@ -680,6 +731,8 @@ def output_dist(
features_before_input_dist_per_sharding: List[KeyedJaggedTensor] = []
features_per_sharding: List[KJTList] = []
unbucketize_tensors: List[Optional[torch.Tensor]] = []
bucket_mapping_tensors: List[Optional[torch.Tensor]] = []
bucketized_lengths: List[Optional[torch.Tensor]] = []
for sharding_output_dist, embeddings, sharding_ctx in zip(
self._output_dists,
output,
@@ -695,6 +748,16 @@
if sharding_ctx.unbucketize_permute_tensor is not None
else None
)
bucket_mapping_tensors.append(
sharding_ctx.bucket_mapping_tensor
if sharding_ctx.bucket_mapping_tensor is not None
else None
)
bucketized_lengths.append(
sharding_ctx.bucketized_length
if sharding_ctx.bucketized_length is not None
else None
)
features_before_input_dist_per_sharding.append(
# pyre-ignore
sharding_ctx.features_before_input_dist
@@ -704,6 +767,8 @@
features_per_sharding=features_per_sharding,
features_before_input_dist_per_sharding=features_before_input_dist_per_sharding,
unbucketize_tensors=unbucketize_tensors,
bucket_mapping_tensors=bucket_mapping_tensors,
bucketized_lengths=bucketized_lengths,
)

def output_jt_dict(
@@ -712,10 +777,9 @@ def output_jt_dict(
features_per_sharding: List[KJTList],
features_before_input_dist_per_sharding: List[KeyedJaggedTensor],
unbucketize_tensors: List[Optional[torch.Tensor]],
bucket_mapping_tensors: List[Optional[torch.Tensor]],
bucketized_lengths: List[Optional[torch.Tensor]],
) -> Dict[str, JaggedTensor]:
jt_values: Dict[str, torch.Tensor] = {}
jt_lengths: Dict[str, torch.Tensor] = {}
jt_weights: Dict[str, torch.Tensor] = {}
jt_dict_res: Dict[str, JaggedTensor] = {}
for (
(sharding_type, _),
@@ -725,6 +789,8 @@
embedding_names_per_rank,
features_before_input_dist,
unbucketize_tensor,
bucket_mapping_tensor,
bucketized_length,
key_to_feature_permuted_coordinates,
) in zip(
self._sharding_type_device_group_to_sharding.keys(),
@@ -734,13 +800,15 @@
self._embedding_names_per_rank_per_sharding,
features_before_input_dist_per_sharding,
unbucketize_tensors,
bucket_mapping_tensors,
bucketized_lengths,
self._key_to_feature_permuted_coordinates_per_sharding,
):

jt_dict = _construct_jagged_tensors(
sharding_type=sharding_type,
embeddings=emb_sharding,
features=features_sharding,
embedding_names=embedding_names,
embedding_names_per_rank=embedding_names_per_rank,
features_before_input_dist=features_before_input_dist,
need_indices=self._need_indices,
Expand All @@ -750,25 +818,21 @@ def output_jt_dict(
if sharding_type == ShardingType.ROW_WISE.value
else None
),
rw_bucket_mapping_tensor=(
_fx_trec_unwrap_optional_tensor(bucket_mapping_tensor)
if sharding_type == ShardingType.ROW_WISE.value
else None
),
rw_feature_length_after_bucketize=(
_fx_trec_unwrap_optional_tensor(bucketized_length)
if sharding_type == ShardingType.ROW_WISE.value
else None
),
cw_features_to_permute_indices=self._features_to_permute_indices,
key_to_feature_permuted_coordinates=key_to_feature_permuted_coordinates,
)
for n in embedding_names:
# For graph split at the end of TW and TW separately,
# we need to unwrap JT to Tensors since JT cannot be par of boundary type.
jt_values[n] = jt_dict[n].values()
jt_lengths[n] = jt_dict[n].lengths()
if self._need_indices:
jt_weights[n] = jt_dict[n].weights()
for n in self._all_embedding_names:
if self._need_indices:
jt_dict_res[n] = JaggedTensor(
values=jt_values[n], lengths=jt_lengths[n], weights=jt_weights[n]
)
else:
jt_dict_res[n] = JaggedTensor(
values=jt_values[n], lengths=jt_lengths[n]
)
for embedding_name in embedding_names:
jt_dict_res[embedding_name] = jt_dict[embedding_name]

return jt_dict_res

@@ -926,12 +990,18 @@ def __init__(
persistent=False,
)

def forward(
self, features: KeyedJaggedTensor
) -> Tuple[List[KJTList], List[KeyedJaggedTensor], List[Optional[torch.Tensor]]]:
def forward(self, features: KeyedJaggedTensor) -> Tuple[
List[KJTList],
List[KeyedJaggedTensor],
List[Optional[torch.Tensor]],
List[Optional[torch.Tensor]],
List[Optional[torch.Tensor]],
]:
with torch.no_grad():
ret: List[KJTList] = []
unbucketize_permute_tensor = []
bucket_mapping_tensor = []
bucketized_lengths = []
if self._features_order:
features = input_dist_permute(
features,
Expand All @@ -953,9 +1023,13 @@ def forward(
unbucketize_permute_tensor.append(
input_dist_result.unbucketize_permute_tensor
)
bucket_mapping_tensor.append(input_dist_result.bucket_mapping_tensor)
bucketized_lengths.append(input_dist_result.bucketized_length)

return (
ret,
features_by_sharding,
unbucketize_permute_tensor,
bucket_mapping_tensor,
bucketized_lengths,
)
23 changes: 13 additions & 10 deletions torchrec/distributed/quant_embedding_kernel.py
@@ -161,15 +161,14 @@ def _unwrap_kjt(


def _unwrap_kjt_for_cpu(
features: KeyedJaggedTensor,
features: KeyedJaggedTensor, weighted: bool
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
indices = features.values()
offsets = features.offsets()
return (
indices,
offsets.type(indices.dtype),
features.weights_or_none(),
)
offsets = features.offsets().type(indices.dtype)
if weighted:
return indices, offsets, features.weights()
else:
return indices, offsets, None


@torch.fx.wrap
@@ -272,7 +271,9 @@ def get_tbes_to_register(
def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
# Important: _unwrap_kjt regex for FX tracing TAGing
if self._runtime_device.type == "cpu":
indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(features)
indices, offsets, per_sample_weights = _unwrap_kjt_for_cpu(
features, self._config.is_weighted
)
else:
indices, offsets, per_sample_weights = _unwrap_kjt(features)

@@ -459,8 +460,10 @@ def split_embedding_weights(

def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
if self._runtime_device.type == "cpu":
# To distinguish with QEBC for fx tracing on CPU embedding.
values, offsets, _ = _unwrap_kjt_for_cpu(features)
# To distinguish fx tracing on CPU embedding.
values, offsets, _ = _unwrap_kjt_for_cpu(
features, weighted=self._config.is_weighted
)
else:
values, offsets, _ = _unwrap_kjt(features)
