Skip to content

Commit

Permalink
Merge pull request #1051 from parea-ai/PAI-1464-make-trace-decorator-…
Browse files Browse the repository at this point in the history
…work-with-iterator-responses

Pai 1464 make trace decorator work with iterator responses
  • Loading branch information
joschkabraun authored Aug 13, 2024
2 parents 8151f6a + 9d46118 commit 0517b1f
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 58 deletions.
74 changes: 46 additions & 28 deletions cookbook/instructor/instructor_streaming.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,81 @@
import os

import anthropic
import instructor
from dotenv import load_dotenv
from openai import AsyncOpenAI
from pydantic import BaseModel

from parea import Parea
from parea import Parea, trace

load_dotenv()

client = AsyncOpenAI()
oai_aclient = AsyncOpenAI()
ant_client = anthropic.AsyncClient()

p = Parea(api_key=os.getenv("PAREA_API_KEY"))
p.wrap_openai_client(client, "instructor")

client = instructor.from_openai(client)

p.wrap_openai_client(oai_aclient, "instructor")
p.wrap_anthropic_client(ant_client)

from pydantic import BaseModel
oai_aclient = instructor.from_openai(oai_aclient)
ant_client = instructor.from_anthropic(ant_client)


class UserDetail(BaseModel):
name: str
age: int
age: str


async def main():
user = client.completions.create_partial(
model="gpt-3.5-turbo",
@trace
async def ainner_main():
user = oai_aclient.completions.create_partial(
model="gpt-4o-mini",
max_tokens=1024,
max_retries=3,
messages=[
{
"role": "user",
"content": "Please crea a user",
"content": "Please create a user",
}
],
response_model=UserDetail,
)
# print(user)
async for u in user:
return user


async def amain():
    """Await the traced OpenAI example and print every item it streams back."""
    async for partial_user in await ainner_main():
        print(partial_user)

# user2 = client.completions.create_partial(
# model="gpt-3.5-turbo",
# max_tokens=1024,
# max_retries=3,
# messages=[
# {
# "role": "user",
# "content": "Please crea a user",
# }
# ],
# response_model=UserDetail,
# )
# async for u in user2:
# print(u)

@trace
def inner_main():
    """Ask the Anthropic-backed instructor client for a ``UserDetail``.

    Returns whatever ``create_partial`` returns, unconsumed, so that the
    caller drives the iteration.
    """
    prompt = [
        {
            "role": "user",
            "content": "Please create a user",
        }
    ]
    return ant_client.completions.create_partial(
        model="claude-3-5-sonnet-20240620",
        max_tokens=1024,
        max_retries=3,
        messages=prompt,
        response_model=UserDetail,
    )


def main():
    """Run the traced Anthropic example and print every item it yields."""
    for partial_user in inner_main():
        print(partial_user)


if __name__ == "__main__":
import asyncio

asyncio.run(main())
asyncio.run(amain())

main()
8 changes: 7 additions & 1 deletion cookbook/openai/tracing_open_ai_streams.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,15 @@


@trace
def call_openai_stream(data: dict):
def _call_openai_stream(data: dict):
data["stream"] = True
stream = client.chat.completions.create(**data)
for chunk in stream:
yield chunk


def call_openai_stream(data: dict):
stream = _call_openai_stream(data)
for chunk in stream:
print(chunk.choices[0].delta or "")

Expand Down
34 changes: 19 additions & 15 deletions parea/utils/trace_integrations/instructor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Callable, Mapping, Tuple
from typing import Any, Callable, List, Mapping, Tuple

import contextvars
import logging
from json import JSONDecodeError

from instructor.retry import InstructorRetryException
Expand All @@ -12,6 +13,9 @@
from parea.schemas import EvaluationResult, UpdateLog
from parea.utils.trace_integrations.wrapt_utils import CopyableFunctionWrapper
from parea.utils.trace_utils import logger_update_record, trace_data, trace_insert
from parea.utils.universal_encoder import json_dumps

logger = logging.getLogger()

instructor_trace_id = contextvars.ContextVar("instructor_trace_id", default="")
instructor_val_err_count = contextvars.ContextVar("instructor_val_err_count", default=0)
Expand Down Expand Up @@ -50,14 +54,11 @@ def report_instructor_validation_errors() -> None:
score=instructor_val_err_count.get(),
reason=reason,
)
last_child_trace_id = trace_data.get()[instructor_trace_id.get()].children[-1]
trace_insert(
{
"scores": [instructor_score],
"configuration": trace_data.get()[last_child_trace_id].configuration,
},
instructor_trace_id.get(),
)
trace_update_dict = {"scores": [instructor_score]}
if children := trace_data.get()[instructor_trace_id.get()].children:
last_child_trace_id = children[-1]
trace_update_dict["configuration"] = trace_data.get()[last_child_trace_id].configuration
trace_insert(trace_update_dict, instructor_trace_id.get())
instructor_trace_id.set("")
instructor_val_err_count.set(0)
instructor_val_errs.set([])
Expand All @@ -82,12 +83,15 @@ def __call__(
trace_name = "instructor"
if "response_model" in kwargs and kwargs["response_model"] and hasattr(kwargs["response_model"], "__name__"):
trace_name = kwargs["response_model"].__name__
return trace(
name=trace_name,
overwrite_trace_id=trace_id,
overwrite_inputs=inputs,
metadata=metadata,
)(

def fn_transform_generator_outputs(items: List) -> str:
    """Turn the items collected from a traced generator into the trace output.

    Only the last collected item is serialized (via ``json_dumps``).

    Args:
        items: Items yielded by the wrapped generator, in order.

    Returns:
        The JSON string of the last item, or ``""`` when nothing was yielded
        or the last item cannot be serialized.
    """
    if not items:
        # Nothing was yielded (e.g. the consumer closed the stream before the
        # first item). That is not a serialization failure, so skip the
        # IndexError round-trip and the warning + traceback it would log.
        return ""
    try:
        return json_dumps(items[-1])
    except Exception as e:
        # Best effort: a non-serializable item must not break the traced call.
        logger.warning(f"Failed to serialize generator output: {e}", exc_info=e)
        return ""

return trace(name=trace_name, overwrite_trace_id=trace_id, overwrite_inputs=inputs, metadata=metadata, fn_transform_generator_outputs=fn_transform_generator_outputs)(
wrapped
)(*args, **kwargs)
except (InstructorRetryException, ValidationError, JSONDecodeError) as e:
Expand Down
51 changes: 47 additions & 4 deletions parea/utils/trace_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def trace(
overwrite_trace_id: Optional[str] = None,
overwrite_inputs: Optional[Dict[str, Any]] = None,
log_sample_rate: Optional[float] = 1.0,
fn_transform_generator_outputs: Callable[[List[Any]], str] = None,
):
def init_trace(func_name, _parea_target_field, args, kwargs, func) -> Tuple[str, datetime, contextvars.Token]:
start_time = timezone_aware_now()
Expand Down Expand Up @@ -258,24 +259,60 @@ def cleanup_trace(trace_id: str, start_time: datetime, context_token: contextvar
thread_eval_funcs_then_log(trace_id, eval_funcs)
trace_context.reset(context_token)

def _handle_iterator_cleanup(items, trace_id, start_time, context_token):
    """Store the output of a finished iterator on its trace, then close the trace.

    The output is chosen as follows:
      * if ``fn_transform_generator_outputs`` was given to ``trace``, its
        return value for ``items``;
      * otherwise, if every item is a ``str``, their concatenation;
      * otherwise the empty string.

    Args:
        items: Everything the wrapped iterator yielded before it stopped.
        trace_id: ID of the trace to update.
        start_time: Passed through to ``cleanup_trace``.
        context_token: Passed through to ``cleanup_trace``.
    """
    try:
        if fn_transform_generator_outputs:
            result = fn_transform_generator_outputs(items)
        elif all(isinstance(item, str) for item in items):
            result = "".join(items)
        else:
            result = ""
        if not is_logging_disabled() and not log_omit_outputs:
            fill_trace_data(trace_id, {"result": result}, UpdateTraceScenario.RESULT)
    finally:
        # Always close the trace: a failing user-supplied transform (or a
        # failing trace update) must not skip cleanup_trace, otherwise the
        # trace context token is never reset. The exception still propagates.
        cleanup_trace(trace_id, start_time, context_token)

async def _wrap_async_iterator(iterator, trace_id, start_time, context_token):
    """Re-yield each item of an async iterator while keeping a copy for the trace.

    When iteration stops (exhaustion, an exception, or the generator being
    closed), the collected items are handed to ``_handle_iterator_cleanup``
    together with the trace bookkeeping arguments.
    """
    seen = []
    try:
        async for element in iterator:
            seen.append(element)
            yield element
    finally:
        _handle_iterator_cleanup(seen, trace_id, start_time, context_token)

def _wrap_sync_iterator(iterator, trace_id, start_time, context_token):
    """Re-yield each item of an iterator while keeping a copy for the trace.

    When iteration stops (exhaustion, an exception, or the generator being
    closed), the collected items are handed to ``_handle_iterator_cleanup``
    together with the trace bookkeeping arguments.
    """
    seen = []
    try:
        for element in iterator:
            seen.append(element)
            yield element
    finally:
        _handle_iterator_cleanup(seen, trace_id, start_time, context_token)

def decorator(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
_parea_target_field = kwargs.pop("_parea_target_field", None)
trace_id, start_time, context_token = init_trace(func.__name__, _parea_target_field, args, kwargs, func)
output_as_list = check_multiple_return_values(func)
result = None
try:
result = await func(*args, **kwargs)
if not is_logging_disabled() and not log_omit_outputs:
fill_trace_data(trace_id, {"result": result, "output_as_list": output_as_list, "eval_funcs_names": eval_funcs_names}, UpdateTraceScenario.RESULT)
return result
except Exception as e:
logger.error(f"Error occurred in function {func.__name__}, {e}")
fill_trace_data(trace_id, {"error": traceback.format_exc()}, UpdateTraceScenario.ERROR)
raise e
finally:
try:
cleanup_trace(trace_id, start_time, context_token)
if inspect.isasyncgen(result):
return _wrap_async_iterator(result, trace_id, start_time, context_token)
else:
cleanup_trace(trace_id, start_time, context_token)
# to not swallow any exceptions
if result is not None:
return result
except Exception as e:
logger.debug(f"Error occurred cleaning up trace for function {func.__name__}, {e}", exc_info=e)

Expand All @@ -284,18 +321,24 @@ def wrapper(*args, **kwargs):
_parea_target_field = kwargs.pop("_parea_target_field", None)
trace_id, start_time, context_token = init_trace(func.__name__, _parea_target_field, args, kwargs, func)
output_as_list = check_multiple_return_values(func)
result = None
try:
result = func(*args, **kwargs)
if not is_logging_disabled() and not log_omit_outputs:
fill_trace_data(trace_id, {"result": result, "output_as_list": output_as_list, "eval_funcs_names": eval_funcs_names}, UpdateTraceScenario.RESULT)
return result
except Exception as e:
logger.error(f"Error occurred in function {func.__name__}, {e}")
fill_trace_data(trace_id, {"error": traceback.format_exc()}, UpdateTraceScenario.ERROR)
raise e
finally:
try:
cleanup_trace(trace_id, start_time, context_token)
if inspect.isgenerator(result):
return _wrap_sync_iterator(result, trace_id, start_time, context_token)
else:
cleanup_trace(trace_id, start_time, context_token)
# to not swallow any exceptions
if result is not None:
return result
except Exception as e:
logger.debug(f"Error occurred cleaning up trace for function {func.__name__}, {e}", exc_info=e)

Expand Down
9 changes: 5 additions & 4 deletions parea/wrapper/anthropic/anthropic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from datetime import datetime

from anthropic import AsyncMessageStreamManager, AsyncStream, Client, MessageStreamManager, Stream
from anthropic.types import ContentBlockDeltaEvent, Message, MessageDeltaEvent, MessageStartEvent, TextBlock
from anthropic.types import ContentBlockDeltaEvent, InputJSONDelta, Message, MessageDeltaEvent, MessageStartEvent, TextBlock, ToolUseBlock

from parea.cache.cache import Cache
from parea.helpers import timezone_aware_now
Expand Down Expand Up @@ -43,8 +43,6 @@ def init(self, log: Callable, cache: Cache, client: Client):
def resolver(trace_id: str, _args: Sequence[Any], kwargs: Dict[str, Any], response: Optional[Message]) -> Optional[Any]:
if response:
if len(response.content) > 1:
from anthropic.types.beta.tools import ToolUseBlock

output_list = []
for content in response.content:
if isinstance(content, TextBlock):
Expand Down Expand Up @@ -185,7 +183,10 @@ def _update_accumulator_streaming(accumulator, info_from_response, chunk):
if isinstance(chunk, MessageStartEvent):
info_from_response["input_tokens"] = chunk.message.usage.input_tokens
elif isinstance(chunk, ContentBlockDeltaEvent):
accumulator["content"].append(chunk.delta.text)
if isinstance(chunk.delta, InputJSONDelta):
accumulator["content"].append(chunk.delta.partial_json)
else:
accumulator["content"].append(chunk.delta.text)
if not info_from_response.get("first_token_timestamp"):
info_from_response["first_token_timestamp"] = timezone_aware_now()
elif isinstance(chunk, MessageDeltaEvent):
Expand Down
10 changes: 5 additions & 5 deletions parea/wrapper/anthropic/stream_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from types import TracebackType
from typing import Callable

from anthropic import AsyncMessageStreamManager, MessageStreamManager, Stream
from anthropic import AsyncMessageStreamManager, AsyncStream, MessageStreamManager, Stream
from anthropic.types import Message


Expand All @@ -16,8 +16,8 @@ def __init__(self, stream: Stream, accumulator, info_from_response, update_accum
self._info_from_response = info_from_response

def __getattr__(self, attr):
# delegate attribute access to the original async_stream
return getattr(self._async_stream, attr)
# delegate attribute access to the original stream
return getattr(self._stream, attr) if hasattr(self._stream, attr) else None

def __iter__(self):
for chunk in self._stream:
Expand All @@ -28,7 +28,7 @@ def __iter__(self):


class AnthropicAsyncStreamWrapper:
def __init__(self, stream: Stream, accumulator, info_from_response, update_accumulator_streaming, final_processing_and_logging):
def __init__(self, stream: AsyncStream, accumulator, info_from_response, update_accumulator_streaming, final_processing_and_logging):
self._stream = stream
self._final_processing_and_logging = final_processing_and_logging
self._update_accumulator_streaming = update_accumulator_streaming
Expand All @@ -37,7 +37,7 @@ def __init__(self, stream: Stream, accumulator, info_from_response, update_accum

def __getattr__(self, attr):
# delegate attribute access to the original stream
return getattr(self._async_stream, attr)
return getattr(self._stream, attr) if hasattr(self._stream, attr) else None

async def __aiter__(self):
async for chunk in self._stream:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
[tool.poetry]
name = "parea-ai"
packages = [{ include = "parea" }]
version = "0.2.201"
version = "0.2.202"
description = "Parea python sdk"
readme = "README.md"
authors = ["joel-parea-ai <[email protected]>"]
Expand Down

0 comments on commit 0517b1f

Please sign in to comment.