Skip to content

Commit

Permalink
2024-10-05 nightly release (6ffc71a)
Browse files Browse the repository at this point in the history
  • Loading branch information
pytorchbot committed Oct 5, 2024
1 parent 99ac883 commit bb00000
Show file tree
Hide file tree
Showing 7 changed files with 265 additions and 2 deletions.
5 changes: 4 additions & 1 deletion torchrec/distributed/fp_embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,10 @@ def __init__(
) -> None:
super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
self._ebc_sharder: EmbeddingBagCollectionSharder = (
ebc_sharder or EmbeddingBagCollectionSharder(self.qcomm_codecs_registry)
ebc_sharder
or EmbeddingBagCollectionSharder(
qcomm_codecs_registry=self.qcomm_codecs_registry
)
)

def shard(
Expand Down
5 changes: 4 additions & 1 deletion torchrec/distributed/itep_embeddingbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,10 @@ def __init__(
) -> None:
super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
self._ebc_sharder: EmbeddingBagCollectionSharder = (
ebc_sharder or EmbeddingBagCollectionSharder(self.qcomm_codecs_registry)
ebc_sharder
or EmbeddingBagCollectionSharder(
qcomm_codecs_registry=self.qcomm_codecs_registry
)
)

def shard(
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/metric_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from torchrec.metrics.multiclass_recall import MulticlassRecallMetric
from torchrec.metrics.ndcg import NDCGMetric
from torchrec.metrics.ne import NEMetric
from torchrec.metrics.ne_positive import NEPositiveMetric
from torchrec.metrics.output import OutputMetric
from torchrec.metrics.precision import PrecisionMetric
from torchrec.metrics.rauc import RAUCMetric
Expand All @@ -64,6 +65,7 @@

REC_METRICS_MAPPING: Dict[RecMetricEnumBase, Type[RecMetric]] = {
RecMetricEnum.NE: NEMetric,
RecMetricEnum.NE_POSITIVE: NEPositiveMetric,
RecMetricEnum.SEGMENTED_NE: SegmentedNEMetric,
RecMetricEnum.CTR: CTRMetric,
RecMetricEnum.CALIBRATION: CalibrationMetric,
Expand Down
1 change: 1 addition & 0 deletions torchrec/metrics/metrics_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class RecMetricEnumBase(StrValueMixin, Enum):

class RecMetricEnum(RecMetricEnumBase):
NE = "ne"
NE_POSITIVE = "ne_positive"
SEGMENTED_NE = "segmented_ne"
LOG_LOSS = "log_loss"
CTR = "ctr"
Expand Down
2 changes: 2 additions & 0 deletions torchrec/metrics/metrics_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class MetricName(MetricNameBase):
DEFAULT = ""

NE = "ne"
NE_POSITIVE = "ne_positive"
SEGMENTED_NE = "segmented_ne"
LOG_LOSS = "logloss"
THROUGHPUT = "throughput"
Expand Down Expand Up @@ -83,6 +84,7 @@ class MetricNamespace(MetricNamespaceBase):
DEFAULT = ""

NE = "ne"
NE_POSITIVE = "ne_positive"
SEGMENTED_NE = "segmented_ne"
THROUGHPUT = "throughput"
CTR = "ctr"
Expand Down
185 changes: 185 additions & 0 deletions torchrec/metrics/ne_positive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, cast, Dict, List, Optional, Type

import torch
from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.rec_metric import (
MetricComputationReport,
RecMetric,
RecMetricComputation,
RecMetricException,
)


def compute_cross_entropy_positive(
labels: torch.Tensor,
predictions: torch.Tensor,
weights: torch.Tensor,
eta: float,
) -> torch.Tensor:
predictions = predictions.double()
predictions.clamp_(min=eta, max=1 - eta)
cross_entropy_positive = -weights * labels * torch.log2(predictions)
return cross_entropy_positive


def _compute_cross_entropy_norm(
mean_label: torch.Tensor,
pos_labels: torch.Tensor,
neg_labels: torch.Tensor,
eta: float,
) -> torch.Tensor:
mean_label = mean_label.double()
mean_label.clamp_(min=eta, max=1 - eta)
return -pos_labels * torch.log2(mean_label) - neg_labels * torch.log2(
1.0 - mean_label
)


@torch.fx.wrap
def compute_ne_positive(
ce_positive_sum: torch.Tensor,
weighted_num_samples: torch.Tensor,
pos_labels: torch.Tensor,
neg_labels: torch.Tensor,
eta: float,
allow_missing_label_with_zero_weight: bool = False,
) -> torch.Tensor:
if allow_missing_label_with_zero_weight and not weighted_num_samples.all():
# If nan were to occur, return a dummy value instead of nan if
# allow_missing_label_with_zero_weight is True
return torch.tensor([eta])

# Goes into this block if all elements in weighted_num_samples > 0
weighted_num_samples = weighted_num_samples.double().clamp(min=eta)
mean_label = pos_labels / weighted_num_samples
ce_norm = _compute_cross_entropy_norm(mean_label, pos_labels, neg_labels, eta)
return ce_positive_sum / ce_norm


def get_ne_positive_states(
labels: torch.Tensor, predictions: torch.Tensor, weights: torch.Tensor, eta: float
) -> Dict[str, torch.Tensor]:
cross_entropy_positive = compute_cross_entropy_positive(
labels,
predictions,
weights,
eta,
)
return {
"cross_entropy_positive_sum": torch.sum(cross_entropy_positive, dim=-1),
"weighted_num_samples": torch.sum(weights, dim=-1),
"pos_labels": torch.sum(weights * labels, dim=-1),
"neg_labels": torch.sum(weights * (1.0 - labels), dim=-1),
}


class NEPositiveMetricComputation(RecMetricComputation):
r"""
This class implements the RecMetricComputation for NE positive, i.e. Normalized Entropy where label = 1
The constructor arguments are defined in RecMetricComputation.
See the docstring of RecMetricComputation for more detail.
"""

def __init__(
self,
*args: Any,
allow_missing_label_with_zero_weight: bool = False,
**kwargs: Any,
) -> None:
self._allow_missing_label_with_zero_weight: bool = (
allow_missing_label_with_zero_weight
)
super().__init__(*args, **kwargs)
self._add_state(
"cross_entropy_positive_sum",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
"weighted_num_samples",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
"pos_labels",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self._add_state(
"neg_labels",
torch.zeros(self._n_tasks, dtype=torch.double),
add_window_state=True,
dist_reduce_fx="sum",
persistent=True,
)
self.eta = 1e-12

def update(
self,
*,
predictions: Optional[torch.Tensor],
labels: torch.Tensor,
weights: Optional[torch.Tensor],
**kwargs: Dict[str, Any],
) -> None:
if predictions is None or weights is None:
raise RecMetricException(
"Inputs 'predictions' and 'weights' should not be None for NEMetricComputation update"
)
states = get_ne_positive_states(labels, predictions, weights, self.eta)
num_samples = predictions.shape[-1]

for state_name, state_value in states.items():
state = getattr(self, state_name)
state += state_value
self._aggregate_window_state(state_name, state_value, num_samples)

def _compute(self) -> List[MetricComputationReport]:
reports = [
MetricComputationReport(
name=MetricName.NE_POSITIVE,
metric_prefix=MetricPrefix.LIFETIME,
value=compute_ne_positive(
cast(torch.Tensor, self.cross_entropy_positive_sum),
cast(torch.Tensor, self.weighted_num_samples),
cast(torch.Tensor, self.pos_labels),
cast(torch.Tensor, self.neg_labels),
self.eta,
self._allow_missing_label_with_zero_weight,
),
),
MetricComputationReport(
name=MetricName.NE_POSITIVE,
metric_prefix=MetricPrefix.WINDOW,
value=compute_ne_positive(
self.get_window_state("cross_entropy_positive_sum"),
self.get_window_state("weighted_num_samples"),
self.get_window_state("pos_labels"),
self.get_window_state("neg_labels"),
self.eta,
self._allow_missing_label_with_zero_weight,
),
),
]
return reports


class NEPositiveMetric(RecMetric):
_namespace: MetricNamespace = MetricNamespace.NE_POSITIVE
_computation_class: Type[RecMetricComputation] = NEPositiveMetricComputation
67 changes: 67 additions & 0 deletions torchrec/metrics/tests/test_ne_positive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import unittest
from typing import Dict

import torch
from torchrec.metrics.metrics_config import DefaultTaskInfo
from torchrec.metrics.ne_positive import NEPositiveMetric


WORLD_SIZE = 4
BATCH_SIZE = 10


def generate_model_output() -> Dict[str, torch._tensor.Tensor]:
return {
"predictions": torch.tensor([[0.8, 0.2, 0.3, 0.6, 0.5]]),
"labels": torch.tensor([[1, 0, 0, 1, 1]]),
"weights": torch.tensor([[1, 2, 1, 2, 1]]),
"expected_ne_positive": torch.tensor([0.4054]),
}


class NEPositiveValueTest(unittest.TestCase):
r"""This set of tests verify the computation logic of AUC in several
corner cases that we know the computation results. The goal is to
provide some confidence of the correctness of the math formula.
"""

def setUp(self) -> None:
self.ne_positive = NEPositiveMetric(
world_size=WORLD_SIZE,
my_rank=0,
batch_size=BATCH_SIZE,
tasks=[DefaultTaskInfo],
)

def test_ne_positive(self) -> None:
model_output = generate_model_output()
self.ne_positive.update(
predictions={DefaultTaskInfo.name: model_output["predictions"][0]},
labels={DefaultTaskInfo.name: model_output["labels"][0]},
weights={DefaultTaskInfo.name: model_output["weights"][0]},
)
metric = self.ne_positive.compute()
print(metric)
actual_metric = metric[
f"ne_positive-{DefaultTaskInfo.name}|lifetime_ne_positive"
]
expected_metric = model_output["expected_ne_positive"]

torch.testing.assert_close(
actual_metric,
expected_metric,
atol=1e-4,
rtol=1e-4,
check_dtype=False,
equal_nan=True,
msg=f"Actual: {actual_metric}, Expected: {expected_metric}",
)

0 comments on commit bb00000

Please sign in to comment.