Skip to content

Commit

Permalink
Merge pull request #1072 from parea-ai/add-retry-backoff
Browse files Browse the repository at this point in the history
add broader retry on http client
  • Loading branch information
jalexanderII authored Aug 22, 2024
2 parents 4fd8e40 + 6b6a0ba commit 44ecbf2
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@
"\n",
"\n",
"dataset = to_simple_dictionary(comments_df)\n",
"dataset[0]"
"dataset"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from dotenv import load_dotenv

from parea import Parea, get_current_trace_id, trace
from parea.schemas import Completion, CompletionResponse, LLMInputs, Message, Role
from parea.schemas import Completion, CompletionResponse, FeedbackRequest, LLMInputs, Message, Role

load_dotenv()

Expand Down Expand Up @@ -101,10 +101,10 @@ def deployed_argument_chain_tags_metadata(query: str, additional_description: st
additional_description="Provide a concise, few sentence argument on why coffee is good for you.",
)
print(json.dumps(asdict(result2), indent=2))
# p.record_feedback(
# FeedbackRequest(
# trace_id=trace_id,
# score=0.7, # 0.0 (bad) to 1.0 (good)
# target="Coffee is wonderful. End of story.",
# )
# )
p.record_feedback(
FeedbackRequest(
trace_id=trace_id,
score=0.7, # 0.0 (bad) to 1.0 (good)
target="Coffee is wonderful. End of story.",
)
)
135 changes: 81 additions & 54 deletions parea/api_client.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,17 @@
from typing import Any, AsyncIterable, Callable, Dict, List, Optional
from typing import Any, AsyncIterable, Dict, List, Optional

import asyncio
import json
import logging
import os
import time
from functools import wraps
from importlib import metadata as importlib_metadata

import httpx
from dotenv import load_dotenv
from tenacity import retry, stop_after_attempt, wait_exponential

load_dotenv()

MAX_RETRIES = 8
BACKOFF_FACTOR = 0.5


def retry_on_502(func: Callable[..., Any]) -> Callable[..., Any]:
"""
A decorator to retry a function or coroutine on encountering a 502 error.
Parameters:
- func: The function or coroutine to be decorated.
Returns:
- A wrapper function that incorporates retry logic.
"""

@wraps(func)
async def async_wrapper(*args, **kwargs):
for retry in range(MAX_RETRIES):
try:
return await func(*args, **kwargs)
except httpx.HTTPError as e:
if not _should_retry(e, retry):
raise
await asyncio.sleep(BACKOFF_FACTOR * (2**retry))

@wraps(func)
def sync_wrapper(*args, **kwargs):
for retry in range(MAX_RETRIES):
try:
return func(*args, **kwargs)
except httpx.HTTPError as e:
if not _should_retry(e, retry):
raise
time.sleep(BACKOFF_FACTOR * (2**retry))

def _should_retry(error, current_retry):
"""Determines if the function should retry on error."""
is_502_error = isinstance(error, httpx.HTTPStatusError) and error.response.status_code == 502
is_last_retry = current_retry == MAX_RETRIES - 1
return not is_last_retry and (isinstance(error, (httpx.ConnectError, httpx.ReadError, httpx.RemoteProtocolError)) or is_502_error)

if asyncio.iscoroutinefunction(func):
return async_wrapper
else:
return sync_wrapper
logger = logging.getLogger()


class HTTPClient:
Expand Down Expand Up @@ -87,7 +44,7 @@ def _get_headers(self, api_key: Optional[str] = None) -> Dict[str, str]:
headers["x-sdk-integrations"] = ",".join(self.integrations)
return headers

@retry_on_502
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=4, max=10))
def request(
self,
method: str,
Expand All @@ -108,9 +65,25 @@ def request(
if e.response.status_code == 422:
# update the error message to include the validation errors
e.args = (f"{e.args[0]}: {e.response.json()}",)
logger.error(
f"HTTP error {e.response.status_code} for {e.request.method} with: {e.args}",
extra={"request_data": data, "request_params": params},
)
raise
except httpx.TimeoutException as e:
logger.error(
f"Timeout error for {e.request.method} {e.request.url}",
extra={"request_data": data, "request_params": params},
)
raise
except httpx.RequestError as e:
logger.error(
f"Request error for {e.request.method} {e.request.url}: {str(e)}",
extra={"request_data": data, "request_params": params},
)
raise

@retry_on_502
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=4, max=10))
async def request_async(
self,
method: str,
Expand All @@ -128,10 +101,28 @@ async def request_async(
response.raise_for_status()
return response
except httpx.HTTPStatusError as e:
print(f"HTTP Error {e.response.status_code} for {e.request.url}: {e.response.text}")
if e.response.status_code == 422:
# update the error message to include the validation errors
e.args = (f"{e.args[0]}: {e.response.json()}",)
logger.error(
f"HTTP error {e.response.status_code} for {e.request.method} with: {e.args}",
extra={"request_data": data, "request_params": params},
)
raise
except httpx.TimeoutException as e:
logger.error(
f"Timeout error for {e.request.method} {e.request.url}",
extra={"request_data": data, "request_params": params},
)
raise
except httpx.RequestError as e:
logger.error(
f"Request error for {e.request.method} {e.request.url}: {str(e)}",
extra={"request_data": data, "request_params": params},
)
raise

@retry_on_502
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=4, max=10))
def stream_request(
self,
method: str,
Expand All @@ -151,10 +142,28 @@ def stream_request(
for chunk in response.iter_bytes(chunk_size):
yield parse_event_data(chunk)
except httpx.HTTPStatusError as e:
print(f"HTTP Error {e.response.status_code} for {e.request.url}: {e.response.text}")
if e.response.status_code == 422:
# update the error message to include the validation errors
e.args = (f"{e.args[0]}: {e.response.json()}",)
logger.error(
f"HTTP error {e.response.status_code} for {e.request.method} with: {e.args}",
extra={"request_data": data, "request_params": params},
)
raise
except httpx.TimeoutException as e:
logger.error(
f"Timeout error for {e.request.method} {e.request.url}",
extra={"request_data": data, "request_params": params},
)
raise
except httpx.RequestError as e:
logger.error(
f"Request error for {e.request.method} {e.request.url}: {str(e)}",
extra={"request_data": data, "request_params": params},
)
raise

@retry_on_502
@retry(stop=stop_after_attempt(3), wait=wait_exponential(min=4, max=10))
async def stream_request_async(
self,
method: str,
Expand All @@ -174,7 +183,25 @@ async def stream_request_async(
async for chunk in response.aiter_bytes(chunk_size):
yield parse_event_data(chunk)
except httpx.HTTPStatusError as e:
print(f"HTTP Error {e.response.status_code} for {e.request.url}: {e.response.text}")
if e.response.status_code == 422:
# update the error message to include the validation errors
e.args = (f"{e.args[0]}: {e.response.json()}",)
logger.error(
f"HTTP error {e.response.status_code} for {e.request.method} with: {e.args}",
extra={"request_data": data, "request_params": params},
)
raise
except httpx.TimeoutException as e:
logger.error(
f"Timeout error for {e.request.method} {e.request.url}",
extra={"request_data": data, "request_params": params},
)
raise
except httpx.RequestError as e:
logger.error(
f"Request error for {e.request.method} {e.request.url}: {str(e)}",
extra={"request_data": data, "request_params": params},
)
raise

def close(self):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "parea-ai"
packages = [{ include = "parea" }]
version = "0.2.209"
version = "0.2.210"
description = "Parea python sdk"
readme = "README.md"
authors = ["joel-parea-ai <[email protected]>"]
Expand Down

0 comments on commit 44ecbf2

Please sign in to comment.