
Commit

Use a shared httpx connection pool. Avg latency went from ~2.5s to ~0.9s
jstray committed Aug 25, 2024
1 parent b310ee0 commit 798c8b5
Showing 1 changed file with 51 additions and 31 deletions.
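The core change: instead of opening a new httpx.AsyncClient for every Perspective API call, the ranker shares one long-lived client whose keepalive connection pool is sized for a whole batch of posts. A minimal sketch of that pattern, with pool values mirroring the diff below (the module-level client name is illustrative, not the actual code):

import httpx

# Pool sized so every post in a batch can be scored concurrently,
# and idle connections stay open between requests.
limits = httpx.Limits(
    max_keepalive_connections=50,
    max_connections=None,   # no hard cap on concurrent connections
    keepalive_expiry=600,   # seconds; 10 minutes
)

# Created once and reused for every request, so later calls skip the
# TCP/TLS handshake that dominated per-request latency.
client = httpx.AsyncClient(limits=limits)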
82 changes: 51 additions & 31 deletions perspective_ranker.py
@@ -2,7 +2,7 @@
from collections import namedtuple
import os
import logging

import time
import httpx
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
@@ -17,11 +17,21 @@
)
from prometheus_client import Counter

# Each post requires a single request, so allow enough connections to send them all at once
KEEPALIVE_CONNECTIONS = 50

# Keep connections alive for a long time to avoid repeated TCP connection setup latency
KEEPALIVE_EXPIRY = 60 * 10

dotenv.load_dotenv()
PERSPECTIVE_HOST = os.getenv(
"PERSPECTIVE_HOST", "https://commentanalyzer.googleapis.com"
)

# Create a registry
registry = CollectorRegistry()

# -- Logging --
rank_calls = Counter(
"rank_calls", "Number of calls to the rank endpoint", registry=registry
)
@@ -42,10 +52,6 @@
logger.setLevel(numeric_level)
logger.info("Starting up")

PERSPECTIVE_HOST = os.getenv(
"PERSPECTIVE_HOST", "https://commentanalyzer.googleapis.com"
)


class EndpointFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
@@ -81,13 +87,14 @@ async def catch_all_error_handler_middleware(request: Request, call_next):
allow_headers=["*"],
)

# -- Ranking weights --
perspective_baseline = {
"CONSTRUCTIVE_EXPERIMENTAL": 1 / 6,
"PERSONAL_STORY_EXPERIMENTAL": 1 / 6,
"AFFINITY_EXPERIMENTAL": 1 / 6,
"COMPASSION_EXPERIMENTAL": 1 / 6,
"RESPECT_EXPERIMENTAL": 1 / 6,
"CURIOSITY_EXPERIMENTAL": 1 / 6,
}

perspective_outrage = {
@@ -137,14 +144,20 @@ async def catch_all_error_handler_middleware(request: Request, call_next):

arms = [perspective_baseline, perspective_outrage, perspective_toxicity]


# -- Main ranker --
class PerspectiveRanker:
ScoredStatement = namedtuple(
"ScoredStatement", "statement attr_scores statement_id scorable"
)

def __init__(self):
self.api_key = os.environ["PERSPECTIVE_API_KEY"]
limits = httpx.Limits(
max_keepalive_connections=KEEPALIVE_CONNECTIONS,
max_connections=None,
keepalive_expiry=KEEPALIVE_EXPIRY,
)
self.httpx_client = httpx.AsyncClient(limits=limits)

# Selects arm based on cohort index
def arm_selection(self, ranking_request):
@@ -167,31 +180,32 @@ async def score(self, attributes, statement, statement_id):
"languages": ["en"],
"requestedAttributes": {attr: {} for attr in attributes},
}


# don't try to score empty text
if not statement.strip():
return self.ScoredStatement(statement, [], statement_id, False)

logger.info(f"Sending request to Perspective API for statement_id: {statement_id}")
logger.debug(f"Request payload: {data}")
# logger.debug(f"Request payload: {data}") don't log text, it's sensitive

try:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{PERSPECTIVE_HOST}/v1alpha1/comments:analyze?key={self.api_key}",
json=data,
headers=headers,
)

response = await self.httpx_client.post(
f"{PERSPECTIVE_HOST}/v1alpha1/comments:analyze?key={self.api_key}",
json=data,
headers=headers,
)

response.raise_for_status()
response_json = response.json()

logger.debug(f"Response for statement_id {statement_id}: {response_json}")

results = []
scorable = True
for attr in attributes:
try:
score = response_json["attributeScores"][attr]["summaryScore"]["value"]

except KeyError:
logger.warning(f"Failed to get score for attribute {attr} in statement_id {statement_id}")
score = 0
@@ -211,19 +225,7 @@ async def score(self, attributes, statement, statement_id):
logger.error(f"Unexpected error occurred for statement_id {statement_id}: {e}")
raise

async def ranker(self, ranking_request: RankingRequest):
arm_weights = self.arm_selection(ranking_request)
async with asyncio.TaskGroup() as tg:
tasks = [
tg.create_task(self.score(arm_weights, item.text, item.id))
for item in ranking_request.items
]

scored_statements = await asyncio.gather(*tasks)
return self.arm_sort(arm_weights, scored_statements)

def arm_sort(self, arm_weightings, scored_statements):

reordered_statements = []

last_score = 0
@@ -249,15 +251,33 @@ def arm_sort(self, arm_weightings, scored_statements):
}
return result

async def rank(self, ranking_request: RankingRequest):
arm_weights = self.arm_selection(ranking_request)
tasks = [
self.score(arm_weights, item.text, item.id)
for item in ranking_request.items
]
scored_statements = await asyncio.gather(*tasks)
return self.arm_sort(arm_weights, scored_statements)


# Global singleton, so that all calls share the same httpx client
ranker = PerspectiveRanker()


@app.post("/rank")
async def main(ranking_request: RankingRequest) -> RankingResponse:
try:
ranker = PerspectiveRanker()
results = await ranker.ranker(ranking_request)
start_time = time.time()

results = await ranker.rank(ranking_request)

latency = time.time() - start_time
logger.debug(f"ranking results: {results}")
logger.debug(f"ranking took time: {latency}")
rank_calls.inc()
return RankingResponse(ranked_ids=results["ranked_ids"])

except Exception as e:
exceptions_count.inc()
logger.error("Error in rank endpoint:", exc_info=True)
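For context on the latency numbers in the commit message: the old score() opened a fresh httpx.AsyncClient per request, so every call paid a new TCP/TLS handshake, while the new rank() fans all statements out concurrently over the shared client. A rough, self-contained sketch of that fan-out (score_stub and the echo URL are placeholders, not the real comments:analyze call):

import asyncio
import time

import httpx

# Shared client, as in the diff above.
client = httpx.AsyncClient(
    limits=httpx.Limits(max_keepalive_connections=50, keepalive_expiry=600)
)

async def score_stub(text: str) -> dict:
    # Placeholder request; the real code POSTs to the Perspective API.
    response = await client.post("https://httpbin.org/post", json={"text": text})
    response.raise_for_status()
    return response.json()

async def rank_all(texts: list[str]) -> list[dict]:
    start = time.time()
    # All statements are scored concurrently; reused keepalive connections
    # mean only the first request pays the connection setup cost.
    results = await asyncio.gather(*(score_stub(t) for t in texts))
    print(f"scored {len(results)} statements in {time.time() - start:.2f}s")
    return results

if __name__ == "__main__":
    asyncio.run(rank_all(["first post", "second post", "third post"]))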
