diff --git a/runtime/prompty/prompty/core.py b/runtime/prompty/prompty/core.py index 2704661..584e752 100644 --- a/runtime/prompty/prompty/core.py +++ b/runtime/prompty/prompty/core.py @@ -8,7 +8,7 @@ from pathlib import Path from .tracer import Tracer, trace, to_dict from pydantic import BaseModel, Field, FilePath -from typing import Iterator, List, Literal, Dict, Callable, Set +from typing import AsyncIterator, Iterator, List, Literal, Dict, Callable, Set class PropertySettings(BaseModel): @@ -479,3 +479,33 @@ def __next__(self): trace("items", [to_dict(s) for s in self.items]) raise StopIteration + + +class AsyncPromptyStream(AsyncIterator): + """AsyncPromptyStream class to iterate over LLM stream. + Necessary for Prompty to handle streaming data when tracing.""" + + def __init__(self, name: str, iterator: AsyncIterator): + self.name = name + self.iterator = iterator + self.items: List[any] = [] + self.__name__ = "AsyncPromptyStream" + + def __aiter__(self): + return self + + async def __anext__(self): + try: + # enumerate but add to list + o = await self.iterator.__anext__() + self.items.append(o) + return o + + except StopIteration: + # StopIteration is raised + # contents are exhausted + if len(self.items) > 0: + with Tracer.start(f"{self.name}.AsyncPromptyStream") as trace: + trace("items", [to_dict(s) for s in self.items]) + + raise StopIteration diff --git a/runtime/prompty/prompty/tracer.py b/runtime/prompty/prompty/tracer.py index 7bf0e25..d78244e 100644 --- a/runtime/prompty/prompty/tracer.py +++ b/runtime/prompty/prompty/tracer.py @@ -46,6 +46,8 @@ def to_dict(obj: Any) -> Dict[str, Any]: # safe PromptyStream obj serialization elif type(obj).__name__ == "PromptyStream": return "PromptyStream" + elif type(obj).__name__ == "AsyncPromptyStream": + return "AsyncPromptyStream" # pydantic models have their own json serialization elif isinstance(obj, BaseModel): return obj.model_dump()