From e26f5954db06bde39eda7de9c5740f161997d072 Mon Sep 17 00:00:00 2001 From: Ian Baker Date: Wed, 31 Jul 2024 23:17:45 -0700 Subject: [PATCH] Rename cohorts to what we are using in the router config --- perspective_ranker.py | 20 ++++++++++---------- perspective_ranker_test.py | 10 ++++++---- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/perspective_ranker.py b/perspective_ranker.py index 24a5f74..ef713be 100644 --- a/perspective_ranker.py +++ b/perspective_ranker.py @@ -27,7 +27,7 @@ allow_headers=["*"], ) -arm_1 = [ +perspective_baseline = [ "CONSTRUCTIVE_EXPERIMENTAL", "PERSONAL_STORY_EXPERIMENTAL", "AFFINITY_EXPERIMENTAL", @@ -36,7 +36,7 @@ "CURIOSITY_EXPERIMENTAL", ] -arm_2 = [ +perspective_outrage = [ "CONSTRUCTIVE_EXPERIMENTAL", "PERSONAL_STORY_EXPERIMENTAL", "AFFINITY_EXPERIMENTAL", @@ -50,7 +50,7 @@ "ALIENATION_EXPERIMENTAL", ] -arm_3 = [ +perspective_toxicity = [ "CONSTRUCTIVE_EXPERIMENTAL", "PERSONAL_STORY_EXPERIMENTAL", "AFFINITY_EXPERIMENTAL", @@ -63,7 +63,7 @@ "THREAT", ] -arms = [arm_1, arm_2, arm_3] +arms = [perspective_baseline, perspective_outrage, perspective_toxicity] class PerspectiveRanker: @@ -77,12 +77,12 @@ def __init__(self): # Selects arm based on cohort index def arm_selection(self, ranking_request): cohort = ranking_request.session.cohort - if cohort == "arm1": - return arm_1 - elif cohort == "arm2": - return arm_2 - elif cohort == "arm3": - return arm_3 + if cohort == "perspective_baseline": + return perspective_baseline + elif cohort == "perspective_outrage": + return perspective_outrage + elif cohort == "perspective_toxicity": + return perspective_toxicity else: raise ValueError(f"Unknown cohort: {cohort}") diff --git a/perspective_ranker_test.py b/perspective_ranker_test.py index 02ad581..12727ef 100644 --- a/perspective_ranker_test.py +++ b/perspective_ranker_test.py @@ -42,10 +42,12 @@ def mock_perspective_build(attributes): def test_rank(client): comments = fake_request(n_posts=1, n_comments=2) - comments.session.cohort = "arm1" + comments.session.cohort = "perspective_baseline" with patch("perspective_ranker.discovery") as mock_discovery: - mock_discovery.build = mock_perspective_build(perspective_ranker.arm_1) + mock_discovery.build = mock_perspective_build( + perspective_ranker.perspective_baseline + ) response = client.post("/rank", json=jsonable_encoder(comments)) # Check if the request was successful (status code 200) @@ -59,10 +61,10 @@ def test_rank(client): def test_arm_selection(): rank = perspective_ranker.PerspectiveRanker() comments = fake_request(n_posts=1, n_comments=2) - comments.session.cohort = "arm1" + comments.session.cohort = "perspective_baseline" result = rank.arm_selection(comments) - assert result == perspective_ranker.arm_1 + assert result == perspective_ranker.perspective_baseline def test_sync_score():