Skip to content

Commit

Permalink
add GPU sync and numerical value tests (#2194)
Browse files Browse the repository at this point in the history
Summary:

Added GPU sync tests to simulate gathering metric states on to rank 0 and computing. Tests don't cover this case before, which has resulted in SEVs in the past as users aren't aware of how RecMetrics collects and computes metrics.

Added numerical value tests, most metrics are do not have this which can result in issues down the line if metrics need to be changed/accommodate future changes. Also we've found inconsistencies sometimes from other methods, so always good to check here. We compare each metric to a reference implementation from literature to ensure the values are as expected.

Fixed and cleaned up some tests too.

Reviewed By: henrylhtsang

Differential Revision: D59173140
  • Loading branch information
iamzainhuda authored and facebook-github-bot committed Oct 15, 2024
1 parent b6e784e commit b482bb5
Show file tree
Hide file tree
Showing 14 changed files with 496 additions and 181 deletions.
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
RecTaskInfo,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -251,3 +253,24 @@ def test_accuracy(self) -> None:
except AssertionError:
print("Assertion error caught with data set ", inputs)
raise


class AccuracyGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = AccuracyMetric
task_name: str = "accuracy"

def test_sync_ne(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=AccuracyMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestAccuracyMetric,
metric_name=AccuracyGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
)
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_auprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
)
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -346,3 +348,24 @@ def test_required_input_for_grouped_auprc(self) -> None:
)

self.assertIn("grouping_keys", auprc.get_required_inputs())


class AUPRCGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = AUPRCMetric
task_name: str = "auprc"

def test_sync_auprc(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=AUPRCMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestAUPRCMetric,
metric_name=AUPRCGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
)
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -77,3 +79,24 @@ def test_fused_calibration(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class CalibrationGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = CalibrationMetric
task_name: str = "calibration"

def test_sync_ne(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=CalibrationMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestCalibrationMetric,
metric_name=CalibrationGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
)
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_ctr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -71,3 +73,24 @@ def test_fused_ctr(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class CTRGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = CTRMetric
task_name: str = "ctr"

def test_sync_ne(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=CTRMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestCTRMetric,
metric_name=CTRGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
)
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -74,3 +76,24 @@ def test_fused_mae(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class MAEGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = MAEMetric
task_name: str = "mae"

def test_sync_ne(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=MAEMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestMAEMetric,
metric_name=MAEGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
)
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -123,3 +125,24 @@ def test_fused_rmse(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class MSEGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = MSEMetric
task_name: str = "mse"

def test_sync_ne(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=MSEMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestMSEMetric,
metric_name=MSEGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
)
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_multiclass_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -113,3 +115,24 @@ def test_multiclass_recall_update_fused(self) -> None:
batch_window_size=10,
n_classes=N_CLASSES,
)


class MulticlassRecallGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = MulticlassRecallMetric
task_name: str = "accuracy"

def test_sync_ne(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=MulticlassRecallMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestMulticlassRecallMetric,
metric_name=MulticlassRecallGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
)
Loading

0 comments on commit b482bb5

Please sign in to comment.