Skip to content

Commit

Permalink
merge parallelize
Browse files Browse the repository at this point in the history
  • Loading branch information
hiftikha committed Aug 25, 2024
2 parents 81f2eaf + 798c8b5 commit 5db6053
Showing 1 changed file with 30 additions and 29 deletions.
59 changes: 30 additions & 29 deletions perspective_ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import logging
import time

import httpx
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
Expand All @@ -18,7 +17,16 @@
)
from prometheus_client import Counter, Histogram

# each post requires a single request, so see if we can do them all at once
KEEPALIVE_CONNECTIONS = 50

# keep connections a long time to save on tcp connection startup latency
KEEPALIVE_EXPIRY = 60 * 10

dotenv.load_dotenv()
PERSPECTIVE_HOST = os.getenv(
"PERSPECTIVE_HOST", "https://commentanalyzer.googleapis.com"
)

# Create a registry
registry = CollectorRegistry()
Expand Down Expand Up @@ -72,10 +80,6 @@
logger.setLevel(numeric_level)
logger.info("Starting up")

PERSPECTIVE_HOST = os.getenv(
"PERSPECTIVE_HOST", "https://commentanalyzer.googleapis.com"
)


class EndpointFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
Expand Down Expand Up @@ -111,13 +115,14 @@ async def catch_all_error_handler_middleware(request: Request, call_next):
allow_headers=["*"],
)

# -- Ranking weights --
perspective_baseline = {
"CONSTRUCTIVE_EXPERIMENTAL": 1 / 6,
"PERSONAL_STORY_EXPERIMENTAL": 1 / 6,
"AFFINITY_EXPERIMENTAL": 1 / 6,
"COMPASSION_EXPERIMENTAL": 1 / 6,
"RESPECT_EXPERIMENTAL": 1 / 6,
"CURIOSITY_EXPERIMENTAL": 1 / 6,
"CURIOSITY_EXPERIMENTAL": 1 / 6,
}

perspective_outrage = {
Expand Down Expand Up @@ -167,14 +172,20 @@ async def catch_all_error_handler_middleware(request: Request, call_next):

arms = [perspective_baseline, perspective_outrage, perspective_toxicity]


# -- Main ranker --
class PerspectiveRanker:
ScoredStatement = namedtuple(
"ScoredStatement", "statement attr_scores statement_id scorable"
)

def __init__(self):
self.api_key = os.environ["PERSPECTIVE_API_KEY"]
limits = httpx.Limits(
max_keepalive_connections=KEEPALIVE_CONNECTIONS,
max_connections=None,
keepalive_expiry=KEEPALIVE_EXPIRY,
)
self.httpx_client = httpx.AsyncClient(limits=limits)

# Selects arm based on cohort index
def arm_selection(self, ranking_request):
Expand All @@ -197,31 +208,32 @@ async def score(self, attributes, statement, statement_id):
"languages": ["en"],
"requestedAttributes": {attr: {} for attr in attributes},
}


# don't try to score empty text
if not statement.strip():
return self.ScoredStatement(statement, [], statement_id, False)

logger.info(f"Sending request to Perspective API for statement_id: {statement_id}")
logger.debug(f"Request payload: {data}")
# logger.debug(f"Request payload: {data}") don't log text, it's sensitive

try:
async with httpx.AsyncClient() as client:
response = await client.post(
f"{PERSPECTIVE_HOST}/v1alpha1/comments:analyze?key={self.api_key}",
json=data,
headers=headers,
)

response = await self.httpx_client.post(
f"{PERSPECTIVE_HOST}/v1alpha1/comments:analyze?key={self.api_key}",
json=data,
headers=headers,
)

response.raise_for_status()
response_json = response.json()

logger.debug(f"Response for statement_id {statement_id}: {response_json}")

results = []
scorable = True
for attr in attributes:
try:
score = response_json["attributeScores"][attr]["summaryScore"]["value"]

except KeyError:
logger.warning(f"Failed to get score for attribute {attr} in statement_id {statement_id}")
score = 0
Expand All @@ -241,19 +253,7 @@ async def score(self, attributes, statement, statement_id):
logger.error(f"Unexpected error occurred for statement_id {statement_id}: {e}")
raise

async def ranker(self, ranking_request: RankingRequest):
arm_weights = self.arm_selection(ranking_request)
async with asyncio.TaskGroup() as tg:
tasks = [
tg.create_task(self.score(arm_weights, item.text, item.id))
for item in ranking_request.items
]

scored_statements = await asyncio.gather(*tasks)
return self.arm_sort(arm_weights, scored_statements)

def arm_sort(self, arm_weightings, scored_statements):

reordered_statements = []

last_score = 0
Expand Down Expand Up @@ -328,6 +328,7 @@ async def main(ranking_request: RankingRequest) -> RankingResponse:
ranking_latency.observe(latency)

return RankingResponse(ranked_ids=results["ranked_ids"])

except Exception as e:
exceptions_count.inc()
logger.error("Error in rank endpoint:", exc_info=True)
Expand Down

0 comments on commit 5db6053

Please sign in to comment.