Add QuantManagedCollisionEmbeddingCollection #2351

Open · wants to merge 1 commit into main
2 changes: 1 addition & 1 deletion torchrec/modules/mc_embedding_modules.py
@@ -23,7 +23,7 @@

def evict(
evictions: Dict[str, Optional[torch.Tensor]],
ebc: Union[EmbeddingBagCollection, EmbeddingCollection],
ebc: nn.Module,
) -> None:
# TODO: write function
return
135 changes: 134 additions & 1 deletion torchrec/quant/embedding_modules.py
@@ -44,7 +44,10 @@
from torchrec.modules.fp_embedding_modules import (
FeatureProcessedEmbeddingBagCollection as OriginalFeatureProcessedEmbeddingBagCollection,
)

from torchrec.modules.mc_embedding_modules import (
ManagedCollisionEmbeddingCollection as OriginalManagedCollisionEmbeddingCollection,
)
from torchrec.modules.mc_modules import ManagedCollisionCollection
from torchrec.modules.utils import construct_jagged_tensors_inference
from torchrec.sparse.jagged_tensor import (
ComputeKJTToJTDict,
@@ -976,3 +979,133 @@ def output_dtype(self) -> torch.dtype:
@property
def device(self) -> torch.device:
return self._device


class QuantManagedCollisionEmbeddingCollection(nn.Module):
"""
QuantManagedCollisionEmbeddingCollection represents a quantized EC module and a set of managed collision modules.
The inputs into the MC-EC will first be modified by the managed collision module before being passed into the embedding collection.

Args:
tables (List[EmbeddingConfig]): A list of EmbeddingConfig objects representing the embedding tables in the collection.
device (torch.device): The device on which the embedding collection will be allocated.
need_indices (bool, optional): Whether to return the indices along with the embeddings. Defaults to False.
output_dtype (torch.dtype, optional): The data type of the output embeddings. Defaults to torch.float.
table_name_to_quantized_weights (Dict[str, Tuple[Tensor, Tensor]], optional): A dictionary mapping table names to their corresponding quantized weights. Defaults to None.
register_tbes (bool, optional): Whether to register the TBEs in the model. Defaults to False.
quant_state_dict_split_scale_bias (bool, optional): Whether to split the scale and bias parameters when saving the quantized state dict. Defaults to False.
row_alignment (int, optional): The alignment of rows in the quantized weights. Defaults to DEFAULT_ROW_ALIGNMENT.
managed_collision_collection (ManagedCollisionCollection, optional): The managed collision collection to use for managing collisions. Defaults to None.
return_remapped_features (bool, optional): Whether to return the remapped input features in addition to the embeddings. Defaults to False.
"""

def __init__(
self,
tables: List[EmbeddingConfig],
device: torch.device,
need_indices: bool = False,
output_dtype: torch.dtype = torch.float,
table_name_to_quantized_weights: Optional[
Dict[str, Tuple[Tensor, Tensor]]
] = None,
register_tbes: bool = False,
quant_state_dict_split_scale_bias: bool = False,
row_alignment: int = DEFAULT_ROW_ALIGNMENT,
managed_collision_collection: Optional[ManagedCollisionCollection] = None,
return_remapped_features: bool = False,
) -> None:
super().__init__()
assert (
managed_collision_collection
), "Managed collision collection cannot be None"
self._managed_collision_collection: ManagedCollisionCollection = (
managed_collision_collection
)
self._return_remapped_features = return_remapped_features
self._embedding_module = EmbeddingCollection(
tables,
device,
need_indices,
output_dtype,
table_name_to_quantized_weights,
register_tbes,
quant_state_dict_split_scale_bias,
row_alignment,
)

assert str(self._embedding_module.embedding_configs()) == str(
self._managed_collision_collection.embedding_configs()
), "Embedding Collection and Managed Collision Collection must contain the same Embedding Configs"

def forward(
self,
features: KeyedJaggedTensor,
) -> Tuple[
Union[KeyedTensor, Dict[str, JaggedTensor]], Optional[KeyedJaggedTensor]
]:
features = self._managed_collision_collection(features)

embedding_res = self._embedding_module(features)

if not self._return_remapped_features:
return embedding_res, None
return embedding_res, features

def _get_name(self) -> str:
return "QuantManagedCollisionEmbeddingCollection"

@classmethod
def from_float(
cls,
module: OriginalManagedCollisionEmbeddingCollection,
return_remapped_features: bool = False,
) -> "QuantManagedCollisionEmbeddingCollection":
mc_ec = module
ec = module._embedding_module
assert hasattr(
module, "qconfig"
), "QuantManagedCollisionEmbeddingCollection input float module must have qconfig defined"
qconfig = module.qconfig

embedding_configs = copy.deepcopy(ec.embedding_configs())
_update_embedding_configs(
cast(List[BaseEmbeddingConfig], embedding_configs),
qconfig,
)
_update_embedding_configs(
mc_ec._managed_collision_collection._embedding_configs,
qconfig,
)

pruning_dict: Dict[str, torch.Tensor] = getattr(
module, MODULE_ATTR_EMB_CONFIG_NAME_TO_PRUNING_INDICES_REMAPPING_DICT, {}
)

for config in embedding_configs:
if config.name in pruning_dict:
pruning_indices_remapping = pruning_dict[config.name]
config.num_embeddings = pruned_num_embeddings(pruning_indices_remapping)
config.pruning_indices_remapping = pruning_indices_remapping

table_name_to_quantized_weights: Dict[str, Tuple[Tensor, Tensor]] = {}
device = quantize_state_dict(
ec,
table_name_to_quantized_weights,
{table.name: table.data_type for table in embedding_configs},
pruning_dict,
)
return cls(
embedding_configs,
device=device,
output_dtype=qconfig.activation().dtype,
table_name_to_quantized_weights=table_name_to_quantized_weights,
register_tbes=getattr(module, MODULE_ATTR_REGISTER_TBES_BOOL, False),
quant_state_dict_split_scale_bias=getattr(
ec, MODULE_ATTR_QUANT_STATE_DICT_SPLIT_SCALE_BIAS, False
),
row_alignment=getattr(
ec, MODULE_ATTR_ROW_ALIGNMENT_INT, DEFAULT_ROW_ALIGNMENT
),
managed_collision_collection=mc_ec._managed_collision_collection,
return_remapped_features=mc_ec._return_remapped_features,
)
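
A minimal end-to-end sketch of the conversion path this module enables, assuming a single table remapped by an MCHManagedCollisionModule; the table name, feature names, and sizes are illustrative, and the flow mirrors the test added in torchrec/quant/tests/test_embedding_modules.py below:

import torch
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
from torchrec.modules.mc_modules import (
    DistanceLFU_EvictionPolicy,
    ManagedCollisionCollection,
    MCHManagedCollisionModule,
)
from torchrec.quant.embedding_modules import QuantManagedCollisionEmbeddingCollection
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

device = torch.device("cpu")
zch_size = 20  # illustrative managed-collision (ZCH) table size

tables = [
    EmbeddingConfig(
        name="t1", embedding_dim=8, num_embeddings=zch_size, feature_names=["f1"]
    )
]
ec = EmbeddingCollection(tables=tables, device=device)

mcc = ManagedCollisionCollection(
    managed_collision_modules={
        "t1": MCHManagedCollisionModule(
            zch_size=zch_size,
            device=device,
            eviction_interval=2,
            eviction_policy=DistanceLFU_EvictionPolicy(),
        ).eval()  # inference mode, as in the added test
    },
    embedding_configs=tables,
)

# Float MC-EC: managed-collision remapping followed by the embedding lookup.
mc_ec = ManagedCollisionEmbeddingCollection(ec, mcc, return_remapped_features=True)

# Attach a qconfig and convert to the quantized module.
mc_ec.qconfig = torch.quantization.QConfig(
    activation=torch.quantization.PlaceholderObserver.with_args(dtype=torch.float),
    weight=torch.quantization.PlaceholderObserver.with_args(dtype=torch.qint8),
)
qmc_ec = QuantManagedCollisionEmbeddingCollection.from_float(mc_ec)

features = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1"],
    values=torch.arange(1000, 1010, dtype=torch.int64),
    lengths=torch.ones((10,), dtype=torch.int64),
)
embeddings, remapped = qmc_ec(features)  # Dict[str, JaggedTensor], Optional[KJT]
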
152 changes: 151 additions & 1 deletion torchrec/quant/tests/test_embedding_modules.py
@@ -8,10 +8,12 @@
# pyre-strict

import unittest
from copy import deepcopy
from dataclasses import replace
from typing import Dict, List, Optional, Type
from typing import cast, Dict, List, Optional, Type

import hypothesis.strategies as st

import torch
from hypothesis import given, settings, Verbosity
from torchrec import inference as trec_infer
@@ -27,13 +29,23 @@
EmbeddingBagCollection,
EmbeddingCollection,
)
from torchrec.modules.mc_embedding_modules import ManagedCollisionEmbeddingCollection
from torchrec.modules.mc_modules import (
DistanceLFU_EvictionPolicy,
ManagedCollisionCollection,
ManagedCollisionModule,
MCHManagedCollisionModule,
)

from torchrec.quant.embedding_modules import (
_fx_trec_unwrap_kjt,
EmbeddingBagCollection as QuantEmbeddingBagCollection,
EmbeddingCollection as QuantEmbeddingCollection,
quant_prep_enable_quant_state_dict_split_scale_bias,
QuantManagedCollisionEmbeddingCollection,
)
from torchrec.sparse.jagged_tensor import (
ComputeJTDictToKJT,
ComputeKJTToJTDict,
JaggedTensor,
KeyedJaggedTensor,
@@ -836,3 +848,141 @@ def test_fx_unwrap_unsharded_vs_sharded_in_sync(

self.assertEqual(indices.dtype, sharded_indices.dtype)
self.assertEqual(offsets.dtype, sharded_offsets.dtype)

# pyre-fixme[56]
@given(
data_type=st.sampled_from(
[
DataType.FP32,
DataType.INT8,
]
),
quant_type=st.sampled_from(
[
torch.half,
torch.qint8,
]
),
output_type=st.sampled_from(
[
torch.half,
torch.float,
]
),
device_type=st.sampled_from(
[
"cpu",
"cuda",
]
),
quant_state_dict_split_scale_bias=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None)
def test_qmcec(
self,
data_type: DataType,
quant_type: torch.dtype,
output_type: torch.dtype,
device_type: str,
quant_state_dict_split_scale_bias: bool,
) -> None:
device = torch.device("cpu" if device_type == "cpu" else "cuda:0")
zch_size = 20
update_interval = 2
update_size = 10

embedding_configs = [
EmbeddingConfig(
name="t1",
embedding_dim=8,
num_embeddings=zch_size,
feature_names=["f1", "f2"],
data_type=data_type,
),
]
ec = EmbeddingCollection(
tables=embedding_configs,
device=device,
)

mc_modules = {
"t1": cast(
ManagedCollisionModule,
MCHManagedCollisionModule(
zch_size=zch_size,
device=device,
eviction_interval=update_interval,
eviction_policy=DistanceLFU_EvictionPolicy(),
),
),
}

for _, value in mc_modules.items():
value.train(False)

mcc_ec = ManagedCollisionCollection(
managed_collision_modules=deepcopy(mc_modules),
# pyre-ignore[6]
embedding_configs=embedding_configs,
)

mc_ec = ManagedCollisionEmbeddingCollection(
ec,
mcc_ec,
return_remapped_features=True,
)

update_one = KeyedJaggedTensor.from_lengths_sync(
keys=["f1", "f2"],
values=torch.concat(
[
torch.arange(1000, 1000 + update_size, dtype=torch.int64),
torch.arange(
1000 + update_size,
1000 + 2 * update_size,
dtype=torch.int64,
),
]
),
lengths=torch.ones((2 * update_size,), dtype=torch.int64),
weights=None,
)

update_one = update_one.to(device)

out1, remapped_kjt1 = mc_ec.forward(update_one)

if quant_state_dict_split_scale_bias:
quant_prep_enable_quant_state_dict_split_scale_bias(ec)

mc_ec.qconfig = torch.quantization.QConfig(
activation=torch.quantization.PlaceholderObserver.with_args(
dtype=output_type
),
weight=torch.quantization.PlaceholderObserver.with_args(dtype=quant_type),
)

qmc_ec = QuantManagedCollisionEmbeddingCollection.from_float(mc_ec)

out2, remapped_kjt2 = qmc_ec.forward(update_one)

self._comp_ec_output(
cast(Dict[str, JaggedTensor], out1), cast(Dict[str, JaggedTensor], out2)
)

from torchrec.fx import symbolic_trace

gm = symbolic_trace(qmc_ec, leaf_modules=[ComputeJTDictToKJT.__name__])
out3, remapped_kjt3 = gm.forward(update_one)

self._comp_ec_output(
cast(Dict[str, JaggedTensor], out1),
cast(Dict[str, JaggedTensor], out3),
)
scripted_module = torch.jit.script(gm)
out4, remapped_kjt4 = scripted_module(update_one)
self._comp_ec_output(
cast(Dict[str, JaggedTensor], out3),
cast(Dict[str, JaggedTensor], out4),
atol=0,
)