Skip to content

Commit

Permalink
add GPU sync and numerical value tests (pytorch#2194)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2194

Added GPU sync tests that simulate gathering metric states onto rank 0 and computing there. Tests didn't cover this case before, which has resulted in SEVs in the past because users weren't aware of how RecMetrics collects and computes metrics.

Added numerical value tests. Most metrics do not have these, which can cause issues down the line if metrics need to be changed or must accommodate future changes. We've also sometimes found inconsistencies with other methods, so it is always good to check here. We compare each metric to a reference implementation from the literature to ensure the values are as expected.

Reviewed By: henrylhtsang

Differential Revision: D59173140
  • Loading branch information
iamzainhuda authored and facebook-github-bot committed Sep 25, 2024
1 parent 0d0feb1 commit 7ecf930
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 0 deletions.
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
RecTaskInfo,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -251,3 +253,24 @@ def test_accuracy(self) -> None:
except AssertionError:
print("Assertion error caught with data set ", inputs)
raise


class AccuracyGPUSyncTest(unittest.TestCase):
    """GPU sync test for ``AccuracyMetric``.

    Launches ``sync_test_helper`` via ``rec_metric_gpu_sync_test_launcher``
    with ``world_size=2`` and ``compute_on_all_ranks=False``.
    """

    # Metric class under test and the metric name handed to the launcher.
    clazz: Type[RecMetric] = AccuracyMetric
    task_name: str = "accuracy"

    def test_sync_accuracy(self) -> None:
        """Run the GPU sync launcher for the accuracy metric on task ``t1``."""
        # NOTE(review): presumably needs two GPUs; whether the launcher skips
        # on smaller hosts is not visible from here - confirm.
        rec_metric_gpu_sync_test_launcher(
            target_clazz=self.clazz,
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            test_clazz=TestAccuracyMetric,
            metric_name=self.task_name,
            task_names=["t1"],
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            entry_point=sync_test_helper,
        )
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_auprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
)
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -346,3 +348,24 @@ def test_required_input_for_grouped_auprc(self) -> None:
)

self.assertIn("grouping_keys", auprc.get_required_inputs())


class AUPRCGPUSyncTest(unittest.TestCase):
    """GPU sync test for ``AUPRCMetric``.

    Launches ``sync_test_helper`` via ``rec_metric_gpu_sync_test_launcher``
    with ``world_size=2`` and ``compute_on_all_ranks=False``.
    """

    # Metric class under test and the metric name handed to the launcher.
    clazz: Type[RecMetric] = AUPRCMetric
    task_name: str = "auprc"

    def test_sync_auprc(self) -> None:
        """Run the GPU sync launcher for the AUPRC metric on task ``t1``."""
        rec_metric_gpu_sync_test_launcher(
            entry_point=sync_test_helper,
            # What is being tested.
            target_clazz=AUPRCMetric,
            test_clazz=TestAUPRCMetric,
            metric_name=self.task_name,
            task_names=["t1"],
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            # How the run is configured.
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
        )
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -77,3 +79,24 @@ def test_fused_calibration(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class CalibrationGPUSyncTest(unittest.TestCase):
    """GPU sync test for ``CalibrationMetric``.

    Launches ``sync_test_helper`` via ``rec_metric_gpu_sync_test_launcher``
    with ``world_size=2`` and ``compute_on_all_ranks=False``.
    """

    # Metric class under test and the metric name handed to the launcher.
    clazz: Type[RecMetric] = CalibrationMetric
    task_name: str = "calibration"

    def test_sync_calibration(self) -> None:
        """Run the GPU sync launcher for the calibration metric on task ``t1``."""
        # NOTE(review): presumably needs two GPUs; whether the launcher skips
        # on smaller hosts is not visible from here - confirm.
        rec_metric_gpu_sync_test_launcher(
            target_clazz=self.clazz,
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            test_clazz=TestCalibrationMetric,
            metric_name=self.task_name,
            task_names=["t1"],
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            entry_point=sync_test_helper,
        )
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_ctr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -71,3 +73,24 @@ def test_fused_ctr(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class CTRGPUSyncTest(unittest.TestCase):
    """GPU sync test for ``CTRMetric``.

    Launches ``sync_test_helper`` via ``rec_metric_gpu_sync_test_launcher``
    with ``world_size=2`` and ``compute_on_all_ranks=False``.
    """

    # Metric class under test and the metric name handed to the launcher.
    clazz: Type[RecMetric] = CTRMetric
    task_name: str = "ctr"

    def test_sync_ctr(self) -> None:
        """Run the GPU sync launcher for the CTR metric on task ``t1``."""
        # NOTE(review): presumably needs two GPUs; whether the launcher skips
        # on smaller hosts is not visible from here - confirm.
        rec_metric_gpu_sync_test_launcher(
            target_clazz=self.clazz,
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            test_clazz=TestCTRMetric,
            metric_name=self.task_name,
            task_names=["t1"],
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            entry_point=sync_test_helper,
        )
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -74,3 +76,24 @@ def test_fused_mae(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class MAEGPUSyncTest(unittest.TestCase):
    """GPU sync test for ``MAEMetric``.

    Launches ``sync_test_helper`` via ``rec_metric_gpu_sync_test_launcher``
    with ``world_size=2`` and ``compute_on_all_ranks=False``.
    """

    # Metric class under test and the metric name handed to the launcher.
    clazz: Type[RecMetric] = MAEMetric
    task_name: str = "mae"

    def test_sync_mae(self) -> None:
        """Run the GPU sync launcher for the MAE metric on task ``t1``."""
        # NOTE(review): presumably needs two GPUs; whether the launcher skips
        # on smaller hosts is not visible from here - confirm.
        rec_metric_gpu_sync_test_launcher(
            target_clazz=self.clazz,
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            test_clazz=TestMAEMetric,
            metric_name=self.task_name,
            task_names=["t1"],
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            entry_point=sync_test_helper,
        )
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_ne.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -258,3 +260,24 @@ def test_logloss_update_fused(self) -> None:
entry_point=self._logloss_metric_test_helper,
batch_window_size=10,
)


class NEGPUSyncTest(unittest.TestCase):
    """GPU sync test for ``NEMetric``.

    Launches ``sync_test_helper`` via ``rec_metric_gpu_sync_test_launcher``
    with ``world_size=2`` and ``compute_on_all_ranks=False``.
    """

    # Metric class under test and the metric name handed to the launcher.
    clazz: Type[RecMetric] = NEMetric
    task_name: str = "ne"

    def test_sync_ne(self) -> None:
        """Run the GPU sync launcher for the NE metric on task ``t1``."""
        rec_metric_gpu_sync_test_launcher(
            entry_point=sync_test_helper,
            # What is being tested.
            target_clazz=NEMetric,
            test_clazz=TestNEMetric,
            metric_name=self.task_name,
            task_names=["t1"],
            target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
            # How the run is configured.
            world_size=2,
            batch_size=5,
            batch_window_size=20,
            fused_update_limit=0,
            compute_on_all_ranks=False,
            should_validate_update=False,
        )

0 comments on commit 7ecf930

Please sign in to comment.