diff --git a/torchrec/metrics/tests/test_accuracy.py b/torchrec/metrics/tests/test_accuracy.py index 5f9e47416..84d487db1 100644 --- a/torchrec/metrics/tests/test_accuracy.py +++ b/torchrec/metrics/tests/test_accuracy.py @@ -17,8 +17,10 @@ from torchrec.metrics.rec_metric import RecComputeMode, RecMetric from torchrec.metrics.test_utils import ( metric_test_helper, + rec_metric_gpu_sync_test_launcher, rec_metric_value_test_launcher, RecTaskInfo, + sync_test_helper, TestMetric, ) @@ -251,3 +253,24 @@ def test_accuracy(self) -> None: except AssertionError: print("Assertion error caught with data set ", inputs) raise + + +class AccuracyGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = AccuracyMetric + task_name: str = "accuracy" + + def test_sync_ne(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=AccuracyMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestAccuracyMetric, + metric_name=AccuracyGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_auprc.py b/torchrec/metrics/tests/test_auprc.py index dacdbab47..95256aa7c 100644 --- a/torchrec/metrics/tests/test_auprc.py +++ b/torchrec/metrics/tests/test_auprc.py @@ -23,7 +23,9 @@ ) from torchrec.metrics.test_utils import ( metric_test_helper, + rec_metric_gpu_sync_test_launcher, rec_metric_value_test_launcher, + sync_test_helper, TestMetric, ) @@ -346,3 +348,24 @@ def test_required_input_for_grouped_auprc(self) -> None: ) self.assertIn("grouping_keys", auprc.get_required_inputs()) + + +class AUPRCGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = AUPRCMetric + task_name: str = "auprc" + + def test_sync_auprc(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=AUPRCMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestAUPRCMetric, + metric_name=AUPRCGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_calibration.py b/torchrec/metrics/tests/test_calibration.py index fb6f109f7..2ea49026e 100644 --- a/torchrec/metrics/tests/test_calibration.py +++ b/torchrec/metrics/tests/test_calibration.py @@ -15,7 +15,9 @@ from torchrec.metrics.rec_metric import RecComputeMode, RecMetric from torchrec.metrics.test_utils import ( metric_test_helper, + rec_metric_gpu_sync_test_launcher, rec_metric_value_test_launcher, + sync_test_helper, TestMetric, ) @@ -77,3 +79,24 @@ def test_fused_calibration(self) -> None: world_size=WORLD_SIZE, entry_point=metric_test_helper, ) + + +class CalibrationGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = CalibrationMetric + task_name: str = "calibration" + + def test_sync_ne(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=CalibrationMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCalibrationMetric, + metric_name=CalibrationGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_ctr.py b/torchrec/metrics/tests/test_ctr.py index b3397dcc1..61ca6081e 100644 --- a/torchrec/metrics/tests/test_ctr.py +++ b/torchrec/metrics/tests/test_ctr.py @@ -15,7 +15,9 @@ from torchrec.metrics.rec_metric import RecComputeMode, RecMetric from torchrec.metrics.test_utils import ( metric_test_helper, + rec_metric_gpu_sync_test_launcher, rec_metric_value_test_launcher, + sync_test_helper, TestMetric, ) @@ -71,3 +73,24 @@ def test_fused_ctr(self) -> None: world_size=WORLD_SIZE, entry_point=metric_test_helper, ) + + +class CTRGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = CTRMetric + task_name: str = "ctr" + + def test_sync_ne(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=CTRMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestCTRMetric, + metric_name=CTRGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_mae.py b/torchrec/metrics/tests/test_mae.py index 8aaa68af0..cff8bd3f7 100644 --- a/torchrec/metrics/tests/test_mae.py +++ b/torchrec/metrics/tests/test_mae.py @@ -15,7 +15,9 @@ from torchrec.metrics.rec_metric import RecComputeMode, RecMetric from torchrec.metrics.test_utils import ( metric_test_helper, + rec_metric_gpu_sync_test_launcher, rec_metric_value_test_launcher, + sync_test_helper, TestMetric, ) @@ -74,3 +76,24 @@ def test_fused_mae(self) -> None: world_size=WORLD_SIZE, entry_point=metric_test_helper, ) + + +class MAEGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = MAEMetric + task_name: str = "mae" + + def test_sync_ne(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=MAEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestMAEMetric, + metric_name=MAEGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) diff --git a/torchrec/metrics/tests/test_ne.py b/torchrec/metrics/tests/test_ne.py index 2bf94cdc9..6f1ac55c9 100644 --- a/torchrec/metrics/tests/test_ne.py +++ b/torchrec/metrics/tests/test_ne.py @@ -21,7 +21,9 @@ from torchrec.metrics.rec_metric import RecComputeMode, RecMetric from torchrec.metrics.test_utils import ( metric_test_helper, + rec_metric_gpu_sync_test_launcher, rec_metric_value_test_launcher, + sync_test_helper, TestMetric, ) @@ -258,3 +260,24 @@ def test_logloss_update_fused(self) -> None: entry_point=self._logloss_metric_test_helper, batch_window_size=10, ) + + +class NEGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = NEMetric + task_name: str = "ne" + + def test_sync_ne(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=NEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestNEMetric, + metric_name=NEGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + )