diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py
index 792378647..a8da499f5 100644
--- a/torchrec/quant/embedding_modules.py
+++ b/torchrec/quant/embedding_modules.py
@@ -259,6 +259,8 @@ def _fx_trec_unwrap_kjt(
 
 
 class EmbeddingBagCollection(EmbeddingBagCollectionInterface, ModuleNoCopyMixin):
+    _emb_modules: List[IntNBitTableBatchedEmbeddingBagsCodegen]
+    _key_to_tables: Dict[Tuple[PoolingType, DataType], List[EmbeddingBagConfig]]
     """
     EmbeddingBagCollection represents a collection of pooled embeddings (EmbeddingBags).
     This EmbeddingBagCollection is quantized for lower precision. It relies on fbgemm quantized ops and provides
@@ -342,13 +344,17 @@ def __init__(
         self._embedding_bag_configs: List[EmbeddingBagConfig] = tables
         self._key_to_tables: Dict[
             Tuple[PoolingType, DataType], List[EmbeddingBagConfig]
-        ] = defaultdict(list)
+        ] = torch.jit.annotate(
+            Dict[Tuple[PoolingType, DataType], List[EmbeddingBagConfig]], {}
+        )
         self._feature_names: List[str] = []
         self._feature_splits: List[int] = []
         self._length_per_key: List[int] = []
         # Registering in a List instead of ModuleList because we want don't want them to be auto-registered.
         # Their states will be modified via self.embedding_bags
-        self._emb_modules: List[nn.Module] = []
+        self._emb_modules: List[IntNBitTableBatchedEmbeddingBagsCodegen] = (
+            torch.jit.annotate(List[IntNBitTableBatchedEmbeddingBagsCodegen], [])
+        )
         self._output_dtype = output_dtype
         self._device: torch.device = device
         self._table_name_to_quantized_weights: Optional[
@@ -363,7 +369,10 @@ def __init__(
                 raise ValueError(f"Duplicate table name {table.name}")
             table_names.add(table.name)
             key = (table.pooling, table.data_type)
-            self._key_to_tables[key].append(table)
+            if key not in self._key_to_tables:
+                self._key_to_tables[key] = [table]
+            else:
+                self._key_to_tables[key].append(table)
 
         location = (
             EmbeddingLocation.HOST if device.type == "cpu" else EmbeddingLocation.DEVICE
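
The rationale behind all three hunks is TorchScript compatibility: `torch.jit.script` cannot compile `collections.defaultdict`, and without an annotation it infers an empty `{}` as `Dict[str, Tensor]` and an empty `[]` as `List[Tensor]`. The diff therefore pins the container types with `torch.jit.annotate` and replaces `defaultdict`'s implicit insert-on-miss with an explicit membership check. Below is a minimal runnable sketch of the same pattern; the `ScriptableGrouping` module and its string keys are hypothetical simplifications for illustration, not torchrec code:

```python
from typing import Dict, List

import torch
from torch import nn


class ScriptableGrouping(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # torch.jit.annotate tells the TorchScript compiler the container's
        # element types; initializing with defaultdict(list) here would fail
        # to script, and a bare {} would be inferred as Dict[str, Tensor].
        self._key_to_tables: Dict[str, List[str]] = torch.jit.annotate(
            Dict[str, List[str]], {}
        )

    def forward(self, key: str, table: str) -> int:
        # An explicit membership check replaces defaultdict's
        # auto-vivification, mirroring the `if key not in
        # self._key_to_tables` hunk in the diff above.
        if key not in self._key_to_tables:
            self._key_to_tables[key] = [table]
        else:
            self._key_to_tables[key].append(table)
        return len(self._key_to_tables[key])


scripted = torch.jit.script(ScriptableGrouping())  # compiles; defaultdict would not
assert scripted("sum_fp16", "t1") == 1
assert scripted("sum_fp16", "t2") == 2
```

In the same spirit, narrowing `_emb_modules` from `List[nn.Module]` to `List[IntNBitTableBatchedEmbeddingBagsCodegen]` presumably lets scripted call sites use fbgemm-specific methods on the elements without a cast, since TorchScript resolves attribute access from the declared static type.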