From f68aacbf47d62f928460d5db0ec9bade577199c1 Mon Sep 17 00:00:00 2001 From: Ze Sheng Date: Tue, 27 Aug 2024 10:39:42 -0700 Subject: [PATCH] Unify arg names to align across kernels & acc & aten ops (Frontend Diff) Summary: As title. Differential Revision: D61856064 --- torchrec/distributed/quant_embedding_kernel.py | 8 ++++---- torchrec/quant/embedding_modules.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py index 947986a1e..748b0f71f 100644 --- a/torchrec/distributed/quant_embedding_kernel.py +++ b/torchrec/distributed/quant_embedding_kernel.py @@ -214,11 +214,11 @@ def __init__( fused_params ) - index_remapping = [ + index_remappings = [ table.pruning_indices_remapping for table in config.embedding_tables ] - if all(v is None for v in index_remapping): - index_remapping = None + if all(v is None for v in index_remappings): + index_remappings = None self._runtime_device: torch.device = _get_runtime_device(device, config) # 16 for CUDA, 1 for others like CPU and MTIA. self._tbe_row_alignment: int = 16 if self._runtime_device.type == "cuda" else 1 @@ -245,7 +245,7 @@ def __init__( ], device=device, # pyre-ignore - index_remapping=index_remapping, + index_remappings=index_remappings, pooling_mode=self._pooling, feature_table_map=self._feature_table_map, row_alignment=self._tbe_row_alignment, diff --git a/torchrec/quant/embedding_modules.py b/torchrec/quant/embedding_modules.py index 792378647..5e2edbae2 100644 --- a/torchrec/quant/embedding_modules.py +++ b/torchrec/quant/embedding_modules.py @@ -407,7 +407,7 @@ def __init__( row_alignment=row_alignment, feature_table_map=feature_table_map, # pyre-ignore - index_remapping=( + index_remappings=( index_remappings if index_remappings_non_none_count > 0 else None ), )