Merge pull request #11 from HumanCompatibleAI/add_middleware
add custom latency metric to jigsaw
jstray authored Aug 25, 2024
2 parents 798c8b5 + 5db6053 commit 747e5b9
Showing 1 changed file with 54 additions and 3 deletions.
perspective_ranker.py (57 changes: 54 additions & 3 deletions)
@@ -15,7 +15,7 @@
     expose_metrics,
     CollectorRegistry,
 )
-from prometheus_client import Counter
+from prometheus_client import Counter, Histogram

 # each post requires a single request, so see if we can do them all at once
 KEEPALIVE_CONNECTIONS = 50
@@ -31,13 +31,41 @@
 # Create a registry
 registry = CollectorRegistry()

-# -- Logging --
+# -- Metrics --
 rank_calls = Counter(
     "rank_calls", "Number of calls to the rank endpoint", registry=registry
 )
 exceptions_count = Counter(
     "exceptions_count", "Number of unhandled exceptions", registry=registry
 )
+ranking_latency = Histogram(
+    "ranking_latency_seconds",
+    "Latency of ranking operations in seconds",
+    registry=registry
+)
+score_distribution = Histogram(
+    "score_distribution",
+    "Distribution of scores for ranked items",
+    buckets=(0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1),
+    registry=registry
+)
+unscorable_items = Counter(
+    "unscorable_items",
+    "Number of items that could not be scored",
+    registry=registry
+)
+items_per_request = Histogram(
+    "items_per_request",
+    "Number of items per ranking request",
+    buckets=(1, 5, 10, 20, 50, 100, 200, 500),
+    registry=registry
+)
+cohort_distribution = Counter(
+    "cohort_distribution",
+    "Distribution of requests across different cohorts",
+    ["cohort"],
+    registry=registry
+)

 logging.basicConfig(
     level=logging.INFO,
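For reviewers less familiar with prometheus_client, here is a minimal standalone sketch of how the two metric types added above behave. The cohort name and score value below are made-up examples, not values from the service:

from prometheus_client import CollectorRegistry, Counter, Histogram

registry = CollectorRegistry()

# Labeled counter: .labels(...) creates one time series per distinct cohort value.
cohort_distribution = Counter(
    "cohort_distribution",
    "Distribution of requests across different cohorts",
    ["cohort"],
    registry=registry,
)
cohort_distribution.labels("example_cohort").inc()  # "example_cohort" is a made-up label value

# Histogram with explicit buckets: each observation increments every bucket whose
# upper bound is >= the observed value, plus the _sum and _count series.
score_distribution = Histogram(
    "score_distribution",
    "Distribution of scores for ranked items",
    buckets=(0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1),
    registry=registry,
)
score_distribution.observe(0.42)  # counted in the 0.5 bucket and all larger ones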
@@ -248,17 +276,36 @@ def arm_sort(self, arm_weightings, scored_statements):

         result = {
             "ranked_ids": filtered,
+            "ranked_ids_with_scores": reordered_statements,  # Include scores
         }
         return result

     async def rank(self, ranking_request: RankingRequest):
         arm_weights = self.arm_selection(ranking_request)
+
+        # Record cohort distribution
+        cohort_distribution.labels(ranking_request.session.cohort).inc()
+
+        # Record number of items per request
+        items_per_request.observe(len(ranking_request.items))
+
         tasks = [
             self.score(arm_weights, item.text, item.id)
             for item in ranking_request.items
         ]
         scored_statements = await asyncio.gather(*tasks)
-        return self.arm_sort(arm_weights, scored_statements)
+
+        # Count unscorable items
+        unscorable_count = sum(1 for statement in scored_statements if not statement.scorable)
+        unscorable_items.inc(unscorable_count)
+
+        result = self.arm_sort(arm_weights, scored_statements)
+
+        # Record score distribution
+        for _, score in result['ranked_ids_with_scores']:
+            score_distribution.observe(score)
+
+        return result


 # Global singleton, so that all calls share the same httpx client
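One way to sanity-check the counters and histograms recorded in rank() is to read them back from the registry in a test. The sketch below assumes this module is importable as perspective_ranker; the sample-name suffixes (_total for counters, _count for histograms) are prometheus_client conventions:

# Test-side sketch; assumes the module above is importable as perspective_ranker.
from perspective_ranker import registry

def check_rank_metrics():
    # Plain counters are exported with a _total suffix and exist at 0.0 before any traffic.
    assert registry.get_sample_value("unscorable_items_total") is not None

    # Histograms export _count, _sum, and cumulative _bucket samples.
    assert registry.get_sample_value("items_per_request_count") is not None

    # Labeled counters need the label set; returns None for cohorts never seen.
    print(registry.get_sample_value("cohort_distribution_total", {"cohort": "example_cohort"}))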
@@ -275,7 +322,11 @@ async def main(ranking_request: RankingRequest) -> RankingResponse:
         latency = time.time() - start_time
         logger.debug(f"ranking results: {results}")
         logger.debug(f"ranking took time: {latency}")
+
+        # Record metrics
         rank_calls.inc()
+        ranking_latency.observe(latency)
+
         return RankingResponse(ranked_ids=results["ranked_ids"])

     except Exception as e:
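Side note on the latency instrumentation: prometheus_client's Histogram also provides a time() helper that records elapsed seconds automatically. The explicit time.time() / observe() pair in this PR is equivalent; a context-manager version would look roughly like the sketch below, where the ranker object and call are hypothetical:

# Alternative to the manual time.time()/observe() pair (illustrative only, not what the PR does).
async def rank_with_timing(ranker, ranking_request):
    # Histogram.time() works as a context manager and observes the elapsed seconds on exit.
    with ranking_latency.time():
        return await ranker.rank(ranking_request)  # hypothetical ranker object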
