Skip to content

Commit

Permalink
Merge pull request #3 from HumanCompatibleAI/arm-names
Browse files Browse the repository at this point in the history
Rename cohorts to what we are using in the router config
  • Loading branch information
JACProjec authored Aug 2, 2024
2 parents e225a69 + e26f595 commit 9fe5a17
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
20 changes: 10 additions & 10 deletions perspective_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
allow_headers=["*"],
)

arm_1 = [
perspective_baseline = [
"CONSTRUCTIVE_EXPERIMENTAL",
"PERSONAL_STORY_EXPERIMENTAL",
"AFFINITY_EXPERIMENTAL",
Expand All @@ -36,7 +36,7 @@
"CURIOSITY_EXPERIMENTAL",
]

arm_2 = [
perspective_outrage = [
"CONSTRUCTIVE_EXPERIMENTAL",
"PERSONAL_STORY_EXPERIMENTAL",
"AFFINITY_EXPERIMENTAL",
Expand All @@ -50,7 +50,7 @@
"ALIENATION_EXPERIMENTAL",
]

arm_3 = [
perspective_toxicity = [
"CONSTRUCTIVE_EXPERIMENTAL",
"PERSONAL_STORY_EXPERIMENTAL",
"AFFINITY_EXPERIMENTAL",
Expand All @@ -63,7 +63,7 @@
"THREAT",
]

arms = [arm_1, arm_2, arm_3]
arms = [perspective_baseline, perspective_outrage, perspective_toxicity]


class PerspectiveRanker:
Expand All @@ -77,12 +77,12 @@ def __init__(self):
# Selects arm based on cohort index
def arm_selection(self, ranking_request):
cohort = ranking_request.session.cohort
if cohort == "arm1":
return arm_1
elif cohort == "arm2":
return arm_2
elif cohort == "arm3":
return arm_3
if cohort == "perspective_baseline":
return perspective_baseline
elif cohort == "perspective_outrage":
return perspective_outrage
elif cohort == "perspective_toxicity":
return perspective_toxicity
else:
raise ValueError(f"Unknown cohort: {cohort}")

Expand Down
10 changes: 6 additions & 4 deletions perspective_ranker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ def mock_perspective_build(attributes):

def test_rank(client):
comments = fake_request(n_posts=1, n_comments=2)
comments.session.cohort = "arm1"
comments.session.cohort = "perspective_baseline"

with patch("perspective_ranker.discovery") as mock_discovery:
mock_discovery.build = mock_perspective_build(perspective_ranker.arm_1)
mock_discovery.build = mock_perspective_build(
perspective_ranker.perspective_baseline
)

response = client.post("/rank", json=jsonable_encoder(comments))
# Check if the request was successful (status code 200)
Expand All @@ -59,10 +61,10 @@ def test_rank(client):
def test_arm_selection():
rank = perspective_ranker.PerspectiveRanker()
comments = fake_request(n_posts=1, n_comments=2)
comments.session.cohort = "arm1"
comments.session.cohort = "perspective_baseline"
result = rank.arm_selection(comments)

assert result == perspective_ranker.arm_1
assert result == perspective_ranker.perspective_baseline


def test_sync_score():
Expand Down

0 comments on commit 9fe5a17

Please sign in to comment.