Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add GPU sync tests #2194

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 31 additions & 9 deletions torchrec/metrics/test_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ def rec_metric_gpu_sync_test_launcher(
entry_point: Callable[..., None],
batch_size: int = BATCH_SIZE,
batch_window_size: int = BATCH_WINDOW_SIZE,
**kwargs: Any,
**kwargs: Dict[str, Any],
) -> None:
with tempfile.TemporaryDirectory() as tmpdir:
lc = get_launch_config(
Expand All @@ -385,6 +385,7 @@ def rec_metric_gpu_sync_test_launcher(
should_validate_update,
batch_size,
batch_window_size,
kwargs.get("n_classes", None),
)


Expand All @@ -402,6 +403,7 @@ def sync_test_helper(
batch_window_size: int = BATCH_WINDOW_SIZE,
n_classes: Optional[int] = None,
zero_weights: bool = False,
**kwargs: Dict[str, Any],
) -> None:
rank = int(os.environ["RANK"])
world_size = int(os.environ["WORLD_SIZE"])
Expand All @@ -413,13 +415,19 @@ def sync_test_helper(

tasks = gen_test_tasks(task_names)

if n_classes:
# pyre-ignore[6]: Incompatible parameter type
kwargs["number_of_classes"] = n_classes

auc = target_clazz(
world_size=world_size,
batch_size=batch_size,
my_rank=rank,
compute_on_all_ranks=compute_on_all_ranks,
tasks=tasks,
window_size=batch_window_size * world_size,
# pyre-ignore[6]: Incompatible parameter type
**kwargs,
)

weight_value: Optional[torch.Tensor] = None
Expand Down Expand Up @@ -466,10 +474,17 @@ def sync_test_helper(
res = auc.compute()

if rank == 0:
assert torch.allclose(
test_metrics[1][task_names[0]],
res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
)
# Serving Calibration uses Calibration naming inconsistently
if metric_name == "serving_calibration":
assert torch.allclose(
test_metrics[1][task_names[0]],
res[f"{metric_name}-{task_names[0]}|window_calibration"],
)
else:
assert torch.allclose(
test_metrics[1][task_names[0]],
res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
)

# we also test the case where other rank has more tensors than rank 0
auc.reset()
Expand All @@ -489,10 +504,17 @@ def sync_test_helper(
res = auc.compute()

if rank == 0:
assert torch.allclose(
test_metrics[1][task_names[0]],
res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
)
# Serving Calibration uses Calibration naming inconsistently
if metric_name == "serving_calibration":
assert torch.allclose(
test_metrics[1][task_names[0]],
res[f"{metric_name}-{task_names[0]}|window_calibration"],
)
else:
assert torch.allclose(
test_metrics[1][task_names[0]],
res[f"{metric_name}-{task_names[0]}|window_{metric_name}"],
)

dist.destroy_process_group()

Expand Down
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
RecTaskInfo,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -251,3 +253,24 @@ def test_accuracy(self) -> None:
except AssertionError:
print("Assertion error caught with data set ", inputs)
raise


class AccuracyGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = AccuracyMetric
task_name: str = "accuracy"

def test_sync_accuracy(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=AccuracyMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestAccuracyMetric,
metric_name=AccuracyGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
)
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_auprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
)
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -346,3 +348,24 @@ def test_required_input_for_grouped_auprc(self) -> None:
)

self.assertIn("grouping_keys", auprc.get_required_inputs())


class AUPRCGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = AUPRCMetric
task_name: str = "auprc"

def test_sync_auprc(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=AUPRCMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestAUPRCMetric,
metric_name=AUPRCGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
)
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -77,3 +79,24 @@ def test_fused_calibration(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class CalibrationGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = CalibrationMetric
task_name: str = "calibration"

def test_sync_calibration(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=CalibrationMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestCalibrationMetric,
metric_name=CalibrationGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
)
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_ctr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -71,3 +73,24 @@ def test_fused_ctr(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class CTRGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = CTRMetric
task_name: str = "ctr"

def test_sync_ctr(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=CTRMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestCTRMetric,
metric_name=CTRGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
)
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_mae.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -74,3 +76,24 @@ def test_fused_mae(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class MAEGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = MAEMetric
task_name: str = "mae"

def test_sync_mae(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=MAEMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestMAEMetric,
metric_name=MAEGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
)
23 changes: 23 additions & 0 deletions torchrec/metrics/tests/test_mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -123,3 +125,24 @@ def test_fused_rmse(self) -> None:
world_size=WORLD_SIZE,
entry_point=metric_test_helper,
)


class MSEGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = MSEMetric
task_name: str = "mse"

def test_sync_mse(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=MSEMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestMSEMetric,
metric_name=MSEGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
)
25 changes: 25 additions & 0 deletions torchrec/metrics/tests/test_multiclass_recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from torchrec.metrics.rec_metric import RecComputeMode, RecMetric
from torchrec.metrics.test_utils import (
metric_test_helper,
rec_metric_gpu_sync_test_launcher,
rec_metric_value_test_launcher,
sync_test_helper,
TestMetric,
)

Expand Down Expand Up @@ -113,3 +115,26 @@ def test_multiclass_recall_update_fused(self) -> None:
batch_window_size=10,
n_classes=N_CLASSES,
)


class MulticlassRecallGPUSyncTest(unittest.TestCase):
clazz: Type[RecMetric] = MulticlassRecallMetric
task_name: str = "multiclass_recall"

def test_sync_multiclass_recall(self) -> None:
rec_metric_gpu_sync_test_launcher(
target_clazz=MulticlassRecallMetric,
target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION,
test_clazz=TestMulticlassRecallMetric,
metric_name=MulticlassRecallGPUSyncTest.task_name,
task_names=["t1"],
fused_update_limit=0,
compute_on_all_ranks=False,
should_validate_update=False,
world_size=2,
batch_size=5,
batch_window_size=20,
entry_point=sync_test_helper,
# pyre-ignore[6] Incompatible parameter type
n_classes=N_CLASSES,
)
Loading
Loading