diff --git a/torchrec/distributed/fp_embeddingbag.py b/torchrec/distributed/fp_embeddingbag.py index 98fc3b22d..872fa6aa6 100644 --- a/torchrec/distributed/fp_embeddingbag.py +++ b/torchrec/distributed/fp_embeddingbag.py @@ -172,7 +172,10 @@ def __init__( ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) self._ebc_sharder: EmbeddingBagCollectionSharder = ( - ebc_sharder or EmbeddingBagCollectionSharder(self.qcomm_codecs_registry) + ebc_sharder + or EmbeddingBagCollectionSharder( + qcomm_codecs_registry=self.qcomm_codecs_registry + ) ) def shard( diff --git a/torchrec/distributed/itep_embeddingbag.py b/torchrec/distributed/itep_embeddingbag.py index daf5b99e8..d8daa4bb3 100644 --- a/torchrec/distributed/itep_embeddingbag.py +++ b/torchrec/distributed/itep_embeddingbag.py @@ -159,7 +159,10 @@ def __init__( ) -> None: super().__init__(qcomm_codecs_registry=qcomm_codecs_registry) self._ebc_sharder: EmbeddingBagCollectionSharder = ( - ebc_sharder or EmbeddingBagCollectionSharder(self.qcomm_codecs_registry) + ebc_sharder + or EmbeddingBagCollectionSharder( + qcomm_codecs_registry=self.qcomm_codecs_registry + ) ) def shard( diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py index 933ec313d..f757d4ad7 100644 --- a/torchrec/metrics/metric_module.py +++ b/torchrec/metrics/metric_module.py @@ -43,6 +43,7 @@ from torchrec.metrics.multiclass_recall import MulticlassRecallMetric from torchrec.metrics.ndcg import NDCGMetric from torchrec.metrics.ne import NEMetric +from torchrec.metrics.ne_positive import NEPositiveMetric from torchrec.metrics.output import OutputMetric from torchrec.metrics.precision import PrecisionMetric from torchrec.metrics.rauc import RAUCMetric @@ -64,6 +65,7 @@ REC_METRICS_MAPPING: Dict[RecMetricEnumBase, Type[RecMetric]] = { RecMetricEnum.NE: NEMetric, + RecMetricEnum.NE_POSITIVE: NEPositiveMetric, RecMetricEnum.SEGMENTED_NE: SegmentedNEMetric, RecMetricEnum.CTR: CTRMetric, RecMetricEnum.CALIBRATION: CalibrationMetric, diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py index 5b0d179c5..6875f2907 100644 --- a/torchrec/metrics/metrics_config.py +++ b/torchrec/metrics/metrics_config.py @@ -20,6 +20,7 @@ class RecMetricEnumBase(StrValueMixin, Enum): class RecMetricEnum(RecMetricEnumBase): NE = "ne" + NE_POSITIVE = "ne_positive" SEGMENTED_NE = "segmented_ne" LOG_LOSS = "log_loss" CTR = "ctr" diff --git a/torchrec/metrics/metrics_namespace.py b/torchrec/metrics/metrics_namespace.py index 6fff895bd..20e257d6d 100644 --- a/torchrec/metrics/metrics_namespace.py +++ b/torchrec/metrics/metrics_namespace.py @@ -40,6 +40,7 @@ class MetricName(MetricNameBase): DEFAULT = "" NE = "ne" + NE_POSITIVE = "ne_positive" SEGMENTED_NE = "segmented_ne" LOG_LOSS = "logloss" THROUGHPUT = "throughput" @@ -83,6 +84,7 @@ class MetricNamespace(MetricNamespaceBase): DEFAULT = "" NE = "ne" + NE_POSITIVE = "ne_positive" SEGMENTED_NE = "segmented_ne" THROUGHPUT = "throughput" CTR = "ctr" diff --git a/torchrec/metrics/ne_positive.py b/torchrec/metrics/ne_positive.py new file mode 100644 index 000000000..2d2147f3d --- /dev/null +++ b/torchrec/metrics/ne_positive.py @@ -0,0 +1,185 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from typing import Any, cast, Dict, List, Optional, Type + +import torch +from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix +from torchrec.metrics.rec_metric import ( + MetricComputationReport, + RecMetric, + RecMetricComputation, + RecMetricException, +) + + +def compute_cross_entropy_positive( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + eta: float, +) -> torch.Tensor: + predictions = predictions.double() + predictions.clamp_(min=eta, max=1 - eta) + cross_entropy_positive = -weights * labels * torch.log2(predictions) + return cross_entropy_positive + + +def _compute_cross_entropy_norm( + mean_label: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, +) -> torch.Tensor: + mean_label = mean_label.double() + mean_label.clamp_(min=eta, max=1 - eta) + return -pos_labels * torch.log2(mean_label) - neg_labels * torch.log2( + 1.0 - mean_label + ) + + +@torch.fx.wrap +def compute_ne_positive( + ce_positive_sum: torch.Tensor, + weighted_num_samples: torch.Tensor, + pos_labels: torch.Tensor, + neg_labels: torch.Tensor, + eta: float, + allow_missing_label_with_zero_weight: bool = False, +) -> torch.Tensor: + if allow_missing_label_with_zero_weight and not weighted_num_samples.all(): + # If nan were to occur, return a dummy value instead of nan if + # allow_missing_label_with_zero_weight is True + return torch.tensor([eta]) + + # Goes into this block if all elements in weighted_num_samples > 0 + weighted_num_samples = weighted_num_samples.double().clamp(min=eta) + mean_label = pos_labels / weighted_num_samples + ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta) + return ce_positive_sum / ce_norm + + +def get_ne_positive_states( + labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor, eta: float +) -> Dict[str, torch.Tensor]: + cross_entropy_positive = compute_cross_entropy_positive( + labels, + predictions, + weights, + eta, + ) + return { + "cross_entropy_positive_sum": torch.sum(cross_entropy_positive, dim=-1), + "weighted_num_samples": torch.sum(weights, dim=-1), + "pos_labels": torch.sum(weights * labels, dim=-1), + "neg_labels": torch.sum(weights * (1.0 - labels), dim=-1), + } + + +class NEPositiveMetricComputation(RecMetricComputation): + r""" + This class implements the RecMetricComputation for NE positive, i.e. Normalized Entropy where label = 1 + + The constructor arguments are defined in RecMetricComputation. + See the docstring of RecMetricComputation for more detail. + """ + + def __init__( + self, + *args: Any, + allow_missing_label_with_zero_weight: bool = False, + **kwargs: Any, + ) -> None: + self._allow_missing_label_with_zero_weight: bool = ( + allow_missing_label_with_zero_weight + ) + super().__init__(*args, **kwargs) + self._add_state( + "cross_entropy_positive_sum", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "weighted_num_samples", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "pos_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self._add_state( + "neg_labels", + torch.zeros(self._n_tasks, dtype=torch.double), + add_window_state=True, + dist_reduce_fx="sum", + persistent=True, + ) + self.eta = 1e-12 + + def update( + self, + *, + predictions: Optional[torch.Tensor], + labels: torch.Tensor, + weights: Optional[torch.Tensor], + **kwargs: Dict[str, Any], + ) -> None: + if predictions is None or weights is None: + raise RecMetricException( + "Inputs 'predictions' and 'weights' should not be None for NEMetricComputation update" + ) + states = get_ne_positive_states(labels, predictions, weights, self.eta) + num_samples = predictions.shape[-1] + + for state_name, state_value in states.items(): + state = getattr(self, state_name) + state += state_value + self._aggregate_window_state(state_name, state_value, num_samples) + + def _compute(self) -> List[MetricComputationReport]: + reports = [ + MetricComputationReport( + name=MetricName.NE_POSITIVE, + metric_prefix=MetricPrefix.LIFETIME, + value=compute_ne_positive( + cast(torch.Tensor, self.cross_entropy_positive_sum), + cast(torch.Tensor, self.weighted_num_samples), + cast(torch.Tensor, self.pos_labels), + cast(torch.Tensor, self.neg_labels), + self.eta, + self._allow_missing_label_with_zero_weight, + ), + ), + MetricComputationReport( + name=MetricName.NE_POSITIVE, + metric_prefix=MetricPrefix.WINDOW, + value=compute_ne_positive( + self.get_window_state("cross_entropy_positive_sum"), + self.get_window_state("weighted_num_samples"), + self.get_window_state("pos_labels"), + self.get_window_state("neg_labels"), + self.eta, + self._allow_missing_label_with_zero_weight, + ), + ), + ] + return reports + + +class NEPositiveMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.NE_POSITIVE + _computation_class: Type[RecMetricComputation] = NEPositiveMetricComputation diff --git a/torchrec/metrics/tests/test_ne_positive.py b/torchrec/metrics/tests/test_ne_positive.py new file mode 100644 index 000000000..427d73124 --- /dev/null +++ b/torchrec/metrics/tests/test_ne_positive.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import Dict + +import torch +from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.ne_positive import NEPositiveMetric + + +WORLD_SIZE = 4 +BATCH_SIZE = 10 + + +def generate_model_output() -> Dict[str, torch._tensor.Tensor]: + return { + "predictions": torch.tensor([[0.8, 0.2, 0.3, 0.6, 0.5]]), + "labels": torch.tensor([[1, 0, 0, 1, 1]]), + "weights": torch.tensor([[1, 2, 1, 2, 1]]), + "expected_ne_positive": torch.tensor([0.4054]), + } + + +class NEPositiveValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of AUC in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. + """ + + def setUp(self) -> None: + self.ne_positive = NEPositiveMetric( + world_size=WORLD_SIZE, + my_rank=0, + batch_size=BATCH_SIZE, + tasks=[DefaultTaskInfo], + ) + + def test_ne_positive(self) -> None: + model_output = generate_model_output() + self.ne_positive.update( + predictions={DefaultTaskInfo.name: model_output["predictions"][0]}, + labels={DefaultTaskInfo.name: model_output["labels"][0]}, + weights={DefaultTaskInfo.name: model_output["weights"][0]}, + ) + metric = self.ne_positive.compute() + print(metric) + actual_metric = metric[ + f"ne_positive-{DefaultTaskInfo.name}|lifetime_ne_positive" + ] + expected_metric = model_output["expected_ne_positive"] + + torch.testing.assert_close( + actual_metric, + expected_metric, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + equal_nan=True, + msg=f"Actual: {actual_metric}, Expected: {expected_metric}", + )