Merge pull request #10 from HumanCompatibleAI/parallelize
Use a shared httpx connection pool. Avg latency went from ~2.5s to ~0.9s
jstray authored Aug 25, 2024
2 parents b310ee0 + 390b946 commit c76fc95
Showing 2 changed files with 109 additions and 40 deletions.
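The core change replaces a per-request httpx.AsyncClient with one long-lived client whose keepalive pool is shared by all concurrent Perspective API calls, so TCP/TLS handshakes are paid once rather than per item. Below is a minimal standalone sketch of that before/after pattern, with placeholder url, payload, and function names (the actual change lives in PerspectiveRanker in the diff that follows):

import asyncio
import httpx

# Before: a fresh client (and a fresh TCP/TLS handshake) for every call.
async def score_per_request(url: str, payload: dict) -> dict:
    async with httpx.AsyncClient() as client:
        response = await client.post(url, json=payload)
        response.raise_for_status()
        return response.json()

# After: one shared client; concurrent calls reuse warm connections.
limits = httpx.Limits(
    max_keepalive_connections=50,  # roughly one connection per post in a request
    max_connections=None,
    keepalive_expiry=60 * 10,      # keep sockets open to skip handshake latency
)
shared_client = httpx.AsyncClient(limits=limits)

async def score_shared(url: str, payload: dict) -> dict:
    response = await shared_client.post(url, json=payload)
    response.raise_for_status()
    return response.json()

async def score_all(url: str, payloads: list[dict]) -> list[dict]:
    # Fan out every item at once; the pool handles connection reuse.
    return await asyncio.gather(*(score_shared(url, p) for p in payloads))

Keeping the client at module scope is what the diff's global `ranker = PerspectiveRanker()` accomplishes; constructing the ranker per request would discard the pool and its warm connections.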
135 changes: 103 additions & 32 deletions perspective_ranker.py
@@ -2,7 +2,7 @@
from collections import namedtuple
import os
import logging

import time
import httpx
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
@@ -15,19 +15,57 @@
expose_metrics,
CollectorRegistry,
)
from prometheus_client import Counter
from prometheus_client import Counter, Histogram

# each post requires a single request, so see if we can do them all at once
KEEPALIVE_CONNECTIONS = 50

# keep connections a long time to save on tcp connection startup latency
KEEPALIVE_EXPIRY = 60 * 10

dotenv.load_dotenv()
PERSPECTIVE_HOST = os.getenv(
"PERSPECTIVE_HOST", "https://commentanalyzer.googleapis.com"
)

# Create a registry
registry = CollectorRegistry()

# -- Metrics --
rank_calls = Counter(
"rank_calls", "Number of calls to the rank endpoint", registry=registry
)
exceptions_count = Counter(
"exceptions_count", "Number of unhandled exceptions", registry=registry
)
ranking_latency = Histogram(
"ranking_latency_seconds",
"Latency of ranking operations in seconds",
registry=registry
)
score_distribution = Histogram(
"score_distribution",
"Distribution of scores for ranked items",
buckets=(0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1),
registry=registry
)
unscorable_items = Counter(
"unscorable_items",
"Number of items that could not be scored",
registry=registry
)
items_per_request = Histogram(
"items_per_request",
"Number of items per ranking request",
buckets=(1, 5, 10, 20, 50, 100, 200, 500),
registry=registry
)
cohort_distribution = Counter(
"cohort_distribution",
"Distribution of requests across different cohorts",
["cohort"],
registry=registry
)

logging.basicConfig(
level=logging.INFO,
@@ -42,10 +80,6 @@
logger.setLevel(numeric_level)
logger.info("Starting up")

PERSPECTIVE_HOST = os.getenv(
"PERSPECTIVE_HOST", "https://commentanalyzer.googleapis.com"
)


class EndpointFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
@@ -81,13 +115,14 @@ async def catch_all_error_handler_middleware(request: Request, call_next):
allow_headers=["*"],
)

# -- Ranking weights --
perspective_baseline = {
"CONSTRUCTIVE_EXPERIMENTAL": 1 / 6,
"PERSONAL_STORY_EXPERIMENTAL": 1 / 6,
"AFFINITY_EXPERIMENTAL": 1 / 6,
"COMPASSION_EXPERIMENTAL": 1 / 6,
"RESPECT_EXPERIMENTAL": 1 / 6,
"CURIOSITY_EXPERIMENTAL": 1 / 6,
"CURIOSITY_EXPERIMENTAL": 1 / 6,
}

perspective_outrage = {
@@ -137,14 +172,20 @@ async def catch_all_error_handler_middleware(request: Request, call_next):

arms = [perspective_baseline, perspective_outrage, perspective_toxicity]


# -- Main ranker --
class PerspectiveRanker:
ScoredStatement = namedtuple(
"ScoredStatement", "statement attr_scores statement_id scorable"
)

def __init__(self):
self.api_key = os.environ["PERSPECTIVE_API_KEY"]
limits = httpx.Limits(
max_keepalive_connections=KEEPALIVE_CONNECTIONS,
max_connections=None,
keepalive_expiry=KEEPALIVE_EXPIRY,
)
self.httpx_client = httpx.AsyncClient(limits=limits)

# Selects arm based on cohort index
def arm_selection(self, ranking_request):
@@ -167,31 +208,32 @@ async def score(self, attributes, statement, statement_id):
"languages": ["en"],
"requestedAttributes": {attr: {} for attr in attributes},
}


# don't try to score empty text
if not statement.strip():
return self.ScoredStatement(statement, [], statement_id, False)

logger.info(f"Sending request to Perspective API for statement_id: {statement_id}")
logger.debug(f"Request payload: {data}")
# logger.debug(f"Request payload: {data}") don't log text, it's sensitive

try:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{PERSPECTIVE_HOST}/v1alpha1/comments:analyze?key={self.api_key}",
json=data,
headers=headers,
)

response = await self.httpx_client.post(
f"{PERSPECTIVE_HOST}/v1alpha1/comments:analyze?key={self.api_key}",
json=data,
headers=headers,
)

response.raise_for_status()
response_json = response.json()

logger.debug(f"Response for statement_id {statement_id}: {response_json}")

results = []
scorable = True
for attr in attributes:
try:
score = response_json["attributeScores"][attr]["summaryScore"]["value"]

except KeyError:
logger.warning(f"Failed to get score for attribute {attr} in statement_id {statement_id}")
score = 0
@@ -211,19 +253,7 @@ async def score(self, attributes, statement, statement_id):
logger.error(f"Unexpected error occurred for statement_id {statement_id}: {e}")
raise

async def ranker(self, ranking_request: RankingRequest):
arm_weights = self.arm_selection(ranking_request)
async with asyncio.TaskGroup() as tg:
tasks = [
tg.create_task(self.score(arm_weights, item.text, item.id))
for item in ranking_request.items
]

scored_statements = await asyncio.gather(*tasks)
return self.arm_sort(arm_weights, scored_statements)

def arm_sort(self, arm_weightings, scored_statements):

reordered_statements = []

last_score = 0
@@ -246,18 +276,59 @@ def arm_sort(self, arm_weightings, scored_statements):

result = {
"ranked_ids": filtered,
"ranked_ids_with_scores": reordered_statements, # Include scores
}
return result

async def rank(self, ranking_request: RankingRequest):
arm_weights = self.arm_selection(ranking_request)

# Record cohort distribution
cohort_distribution.labels(ranking_request.session.cohort).inc()

# Record number of items per request
items_per_request.observe(len(ranking_request.items))

tasks = [
self.score(arm_weights, item.text, item.id)
for item in ranking_request.items
]
scored_statements = await asyncio.gather(*tasks)

# Count unscorable items
unscorable_count = sum(1 for statement in scored_statements if not statement.scorable)
unscorable_items.inc(unscorable_count)

result = self.arm_sort(arm_weights, scored_statements)

# Record score distribution
for _, score in result['ranked_ids_with_scores']:
score_distribution.observe(score)

return result


# Global singleton, so that all calls share the same httpx client
ranker = PerspectiveRanker()


@app.post("/rank")
async def main(ranking_request: RankingRequest) -> RankingResponse:
try:
ranker = PerspectiveRanker()
results = await ranker.ranker(ranking_request)
start_time = time.time()

results = await ranker.rank(ranking_request)

latency = time.time() - start_time
logger.debug(f"ranking results: {results}")
logger.debug(f"ranking took time: {latency}")

# Record metrics
rank_calls.inc()
ranking_latency.observe(latency)

return RankingResponse(ranked_ids=results["ranked_ids"])

except Exception as e:
exceptions_count.inc()
logger.error("Error in rank endpoint:", exc_info=True)
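For context, here is a hedged sketch of how the /rank endpoint might be exercised locally. The field names are inferred from what the ranker reads (item.id, item.text, session.cohort); the real RankingRequest schema is imported from elsewhere in the project and likely requires additional fields, and the cohort value and port are placeholders:

import httpx

payload = {
    "session": {"cohort": "perspective_baseline"},  # placeholder cohort name
    "items": [
        {"id": "a", "text": "Thanks for explaining, that actually helps."},
        {"id": "b", "text": "This is the worst take I have ever read."},
    ],
}

response = httpx.post("http://localhost:8000/rank", json=payload, timeout=10.0)
response.raise_for_status()
print(response.json()["ranked_ids"])  # item ids reordered by the selected arm's weights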
14 changes: 6 additions & 8 deletions perspective_ranker_test.py
@@ -127,11 +127,9 @@ def test_arm_sort():

result = rank.arm_sort(perspective_ranker.perspective_toxicity, scored_statements)

assert result == {
"ranked_ids": [
"test_statement_id_1",
"test_statement_id_unscorable",
"test_statement_id_2",
"test_statement_id_3",
]
}
assert result["ranked_ids"] == [
"test_statement_id_1",
"test_statement_id_unscorable",
"test_statement_id_2",
"test_statement_id_3",
]
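
The shared-client design also makes the scoring path easy to unit test by stubbing the client on the instance. A hypothetical sketch (not part of this commit), assuming PERSPECTIVE_API_KEY is already set in the test environment as the existing tests require:

import asyncio
from unittest.mock import AsyncMock, MagicMock

import perspective_ranker


def test_score_with_stubbed_api():
    rank = perspective_ranker.PerspectiveRanker()

    # Fake a Perspective API response carrying a single attribute score.
    fake_response = MagicMock()
    fake_response.raise_for_status.return_value = None
    fake_response.json.return_value = {
        "attributeScores": {"TOXICITY": {"summaryScore": {"value": 0.42}}}
    }
    rank.httpx_client.post = AsyncMock(return_value=fake_response)

    scored = asyncio.run(rank.score(["TOXICITY"], "some text", "id_1"))

    assert scored.statement_id == "id_1"
    assert scored.scorable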
