Skip to content

Commit

Permalink
Merge pull request #7 from HumanCompatibleAI/unscorable-statements
Browse files Browse the repository at this point in the history
gracefully handle un-scorable statements, like those in non-english languages
  • Loading branch information
raindrift authored Aug 15, 2024
2 parents 9789581 + e4daec0 commit 1df78db
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 42 deletions.
48 changes: 33 additions & 15 deletions perspective_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ async def log_exceptions_middleware(request: Request, call_next):

class PerspectiveRanker:
ScoredStatement = namedtuple(
"ScoredStatement", "statement attr_scores statement_id"
"ScoredStatement", "statement attr_scores statement_id scorable"
)

def __init__(self):
Expand All @@ -116,17 +116,27 @@ async def score(self, attributes, statement, statement_id):
"languages": ["en"],
"requestedAttributes": {attr: {} for attr in attributes},
}

response = httpx.post(
f"{PERSPECTIVE_HOST}/v1alpha1/comments:analyze?key={self.api_key}",
json=data,
headers=headers,
).json()
results = [
(attr, response["attributeScores"][attr]["summaryScore"]["value"])
for attr in attributes
]

result = self.ScoredStatement(statement, results, statement_id)
results = []
scorable = True
for attr in attributes:
try:
score = response["attributeScores"][attr]["summaryScore"]["value"]
except KeyError:
score = 0 # for now, set the score to 0 if it wasn't possible to get a score
scorable = False

results.append(
(attr, score)
)

result = self.ScoredStatement(statement, results, statement_id, scorable)

return result

Expand All @@ -138,10 +148,10 @@ async def ranker(self, ranking_request: RankingRequest):
for item in ranking_request.items
]

results = await asyncio.gather(*tasks)
return self.arm_sort(results)
scored_statements = await asyncio.gather(*tasks)
return self.arm_sort(scored_statements)

def arm_sort(self, results):
def arm_sort(self, scored_statements):
weightings = {
"CONSTRUCTIVE_EXPERIMENTAL": 1 / 6,
"PERSONAL_STORY_EXPERIMENTAL": 1 / 6,
Expand All @@ -162,12 +172,20 @@ def arm_sort(self, results):

reordered_statements = []

for named_tuple in results:
combined_score = 0
for group in named_tuple.attr_scores:
attribute, score = group
combined_score += weightings[attribute] * score
reordered_statements.append((named_tuple.statement_id, combined_score))
last_score = 0

for statement in scored_statements:
if statement.scorable:
combined_score = 0
for group in statement.attr_scores:
attribute, score = group
combined_score += weightings[attribute] * score
else:
# if a statement is not scorable, keep it with its neighbor. this prevents us from collecting
# all unscorable statements at one end of the ranking.
combined_score = last_score
reordered_statements.append((statement.statement_id, combined_score))
last_score = combined_score

reordered_statements.sort(key=lambda x: x[1], reverse=True)
filtered = [x[0] for x in reordered_statements]
Expand Down
77 changes: 51 additions & 26 deletions perspective_ranker_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import os
import pytest
from unittest.mock import patch, Mock

from fastapi.encoders import jsonable_encoder
from fastapi.testclient import TestClient
import respx

import perspective_ranker
from ranking_challenge.fake import fake_request
Expand All @@ -19,7 +21,7 @@ def client(app):
return TestClient(app)


def mock_perspective_build(attributes):
def api_response(attributes):
api_response = {"attributeScores": {}}

for attr in attributes:
Expand All @@ -29,33 +31,43 @@ def mock_perspective_build(attributes):
}
}

config = {
"comments.return_value.analyze.return_value.execute.return_value": api_response
}
mock_client = Mock()
mock_client.configure_mock(**config)
mock_build = Mock()
mock_build.return_value = mock_client
return api_response


return mock_build
# Full Perspective API endpoint used by the mocked HTTP tests.
# NOTE: the API key subscript uses single quotes — reusing double quotes
# inside a double-quoted f-string is a SyntaxError before Python 3.12 (PEP 701).
PERSPECTIVE_URL = f"{perspective_ranker.PERSPECTIVE_HOST}/v1alpha1/comments:analyze?key={os.environ['PERSPECTIVE_API_KEY']}"


@respx.mock
def test_rank(client):
comments = fake_request(n_posts=1, n_comments=2)
comments.session.cohort = "perspective_baseline"

with patch("perspective_ranker.discovery") as mock_discovery:
mock_discovery.build = mock_perspective_build(
perspective_ranker.perspective_baseline
)
respx.post(PERSPECTIVE_URL).respond(
json=api_response(perspective_ranker.perspective_baseline)
)

response = client.post("/rank", json=jsonable_encoder(comments))
# Check if the request was successful (status code 200)
if response.status_code != 200:
assert False, f"Request failed with status code: {response.status_code}"

response = client.post("/rank", json=jsonable_encoder(comments))
# Check if the request was successful (status code 200)
if response.status_code != 200:
assert False, f"Request failed with status code: {response.status_code}"
result = response.json()
assert len(result["ranked_ids"]) == 3

@respx.mock
def test_rank_no_score(client):
comments = fake_request(n_posts=1, n_comments=2)
comments.session.cohort = "perspective_baseline"

result = response.json()
assert len(result["ranked_ids"]) == 3
respx.post(PERSPECTIVE_URL).respond(json={})

response = client.post("/rank", json=jsonable_encoder(comments))
# Check if the request was successful (status code 200)
if response.status_code != 200:
assert False, f"Request failed with status code: {response.status_code}"

result = response.json()
assert len(result["ranked_ids"]) == 3


def test_arm_selection():
Expand All @@ -67,17 +79,20 @@ def test_arm_selection():
assert result == perspective_ranker.perspective_baseline


def test_sync_score():
@respx.mock
@pytest.mark.asyncio
async def test_score():
rank = perspective_ranker.PerspectiveRanker()

with patch("perspective_ranker.discovery") as mock_discovery:
mock_discovery.build = mock_perspective_build(["TOXICITY"])
respx.post(PERSPECTIVE_URL).respond(
json=api_response(["TOXICITY"])
)

result = rank.sync_score(["TOXICITY"], "Test statement", "test_statement_id")
result = await rank.score(["TOXICITY"], "Test statement", "test_statement_id")

assert result.attr_scores == [("TOXICITY", 0.5)]
assert result.statement == "Test statement"
assert result.statement_id == "test_statement_id"
assert result.attr_scores == [("TOXICITY", 0.5)]
assert result.statement == "Test statement"
assert result.statement_id == "test_statement_id"


def test_arm_sort():
Expand All @@ -88,16 +103,25 @@ def test_arm_sort():
"Test statement 2",
[("TOXICITY", 0.6), ("CONSTRUCTIVE_EXPERIMENTAL", 0.2)],
"test_statement_id_2",
True,
),
rank.ScoredStatement(
"Test statement",
[("TOXICITY", 0.1), ("CONSTRUCTIVE_EXPERIMENTAL", 0.1)],
"test_statement_id_1",
True,
),
rank.ScoredStatement(
"Test statement",
[("TOXICITY", 0), ("CONSTRUCTIVE_EXPERIMENTAL", 0)],
"test_statement_id_unscorable",
False,
),
rank.ScoredStatement(
"Test statement 3",
[("TOXICITY", 0.9), ("CONSTRUCTIVE_EXPERIMENTAL", 0.3)],
"test_statement_id_3",
True,
),
]

Expand All @@ -106,6 +130,7 @@ def test_arm_sort():
assert result == {
"ranked_ids": [
"test_statement_id_1",
"test_statement_id_unscorable",
"test_statement_id_2",
"test_statement_id_3",
]
Expand Down
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ faker = "^27.0.0"

[tool.poetry.dev-dependencies]
pytest = "^8.3"
respx = "*"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down

0 comments on commit 1df78db

Please sign in to comment.