add method to context manager to add span to session
axiomofjoy committed Nov 1, 2024
1 parent 204c063 commit b464529
Showing 2 changed files with 96 additions and 88 deletions.
156 changes: 71 additions & 85 deletions src/phoenix/server/api/helpers/playground_spans.py
@@ -7,6 +7,7 @@
from traceback import format_exc
from types import TracebackType
from typing import (
TYPE_CHECKING,
Any,
Iterable,
Iterator,
@@ -26,32 +27,30 @@
)
from opentelemetry.sdk.trace.id_generator import RandomIdGenerator as DefaultOTelIDGenerator
from opentelemetry.trace import StatusCode
from sqlalchemy import insert, select
from sqlalchemy.ext.asyncio import AsyncSession
from strawberry.scalars import JSON as JSONScalarType
from typing_extensions import Self, TypeAlias, assert_never

from phoenix.datetime_utils import local_now, normalize_datetime
from phoenix.db import models
from phoenix.server.api.input_types.ChatCompletionInput import ChatCompletionInput
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
FinishedChatCompletion,
TextChunk,
ToolCallChunk,
)
from phoenix.server.api.types.Span import to_gql_span
from phoenix.server.types import DbSessionFactory
from phoenix.trace.attributes import unflatten
from phoenix.trace.schemas import (
SpanEvent,
SpanException,
)
from phoenix.utilities.json import jsonify

PLAYGROUND_PROJECT_NAME = "playground"
if TYPE_CHECKING:
from phoenix.server.api.input_types.ChatCompletionInput import ChatCompletionInput
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
TextChunk,
ToolCallChunk,
)


ChatCompletionMessage: TypeAlias = tuple[
ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
"ChatCompletionMessageRole", str, Optional[str], Optional[list[str]]
]
ToolCallID: TypeAlias = str

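The imports moved under TYPE_CHECKING are only needed for annotations, so the diff defers them to type-checking time and quotes the annotations as forward references — presumably to break a runtime import cycle between this helper module and the GraphQL type modules. A minimal sketch of the pattern, using one module path from the diff (the describe function is illustrative, not part of the commit):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated by static type checkers only, never at runtime,
        # so an import cycle through this module cannot occur.
        from phoenix.server.api.input_types.ChatCompletionInput import ChatCompletionInput

    def describe(input: "ChatCompletionInput") -> str:
        # The quoted annotation is a forward reference; at runtime it is
        # just a string, so the class need not be importable here.
        return type(input).__name__
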
@@ -64,10 +63,9 @@ class streaming_llm_span:
def __init__(
self,
*,
input: ChatCompletionInput,
input: "ChatCompletionInput",
messages: list[ChatCompletionMessage],
invocation_parameters: dict[str, Any],
db: DbSessionFactory,
attributes: Optional[dict[str, Any]] = None,
) -> None:
self._input = input
@@ -82,14 +80,14 @@ def __init__(
_input_value_and_mime_type(input),
)
)
self._db = db
self._events: list[SpanEvent] = []
self._start_time: datetime
self._response_chunks: list[Union[TextChunk, ToolCallChunk]] = []
self._text_chunks: list[TextChunk] = []
self._tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]] = defaultdict(list)
self._finished_chat_completion: FinishedChatCompletion
self._project_id: int
self._end_time: datetime
self._response_chunks: list[Union["TextChunk", "ToolCallChunk"]] = []
self._text_chunks: list["TextChunk"] = []
self._tool_call_chunks: defaultdict[ToolCallID, list["ToolCallChunk"]] = defaultdict(list)
self._status_code: StatusCode
self._status_message: str

async def __aenter__(self) -> Self:
self._start_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
@@ -101,16 +99,16 @@ async def __aexit__(
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> bool:
end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
status_code = StatusCode.OK
status_message = ""
self._end_time = cast(datetime, normalize_datetime(dt=local_now(), tz=timezone.utc))
self._status_code = StatusCode.OK
self._status_message = ""
if exc_type is not None:
status_code = StatusCode.ERROR
status_message = str(exc_value)
self._status_code = StatusCode.ERROR
self._status_message = str(exc_value)
self._events.append(
SpanException(
timestamp=end_time,
message=status_message,
timestamp=self._end_time,
message=self._status_message,
exception_type=type(exc_value).__name__,
exception_escaped=False,
exception_stacktrace=format_exc(),
@@ -123,63 +121,55 @@
_llm_output_messages(self._text_chunks, self._tool_call_chunks),
)
)
return True

def add_to_session(
self,
session: AsyncSession,
project_id: int,
) -> models.Span:
prompt_tokens = self._attributes.get(LLM_TOKEN_COUNT_PROMPT, 0)
completion_tokens = self._attributes.get(LLM_TOKEN_COUNT_COMPLETION, 0)
trace_id = _generate_trace_id()
span_id = _generate_span_id()
async with self._db() as session:
if (
project_id := await session.scalar(
select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
)
) is None:
project_id = await session.scalar(
insert(models.Project)
.returning(models.Project.id)
.values(
name=PLAYGROUND_PROJECT_NAME,
description="Traces from prompt playground",
)
)
trace = models.Trace(
project_rowid=project_id,
trace_id=trace_id,
start_time=self._start_time,
end_time=end_time,
)
span = models.Span(
trace_rowid=trace.id,
span_id=span_id,
parent_id=None,
name="ChatCompletion",
span_kind=LLM,
start_time=self._start_time,
end_time=end_time,
attributes=unflatten(self._attributes.items()),
events=[_serialize_event(event) for event in self._events],
status_code=status_code.name,
status_message=status_message,
cumulative_error_count=int(status_code is StatusCode.ERROR),
cumulative_llm_token_count_prompt=prompt_tokens,
cumulative_llm_token_count_completion=completion_tokens,
llm_token_count_prompt=prompt_tokens,
llm_token_count_completion=completion_tokens,
trace=trace,
)
session.add(trace)
session.add(span)
await session.flush()
self._project_id = project_id
self._finished_chat_completion = FinishedChatCompletion(
span=to_gql_span(span),
error_message=status_message if status_code is StatusCode.ERROR else None,
trace = models.Trace(
project_rowid=project_id,
trace_id=trace_id,
start_time=self._start_time,
end_time=self._end_time,
)
return True
span = models.Span(
trace_rowid=trace.id,
span_id=span_id,
parent_id=None,
name="ChatCompletion",
span_kind=LLM,
start_time=self._start_time,
end_time=self._end_time,
attributes=unflatten(self._attributes.items()),
events=[_serialize_event(event) for event in self._events],
status_code=self._status_code.name,
status_message=self._status_message,
cumulative_error_count=int(self._status_code is StatusCode.ERROR),
cumulative_llm_token_count_prompt=prompt_tokens,
cumulative_llm_token_count_completion=completion_tokens,
llm_token_count_prompt=prompt_tokens,
llm_token_count_completion=completion_tokens,
trace=trace,
)
session.add(trace)
session.add(span)
return span

def set_attributes(self, attributes: Mapping[str, Any]) -> None:
self._attributes.update(attributes)

def add_response_chunk(self, chunk: Union[TextChunk, ToolCallChunk]) -> None:
def add_response_chunk(self, chunk: Union["TextChunk", "ToolCallChunk"]) -> None:
from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
TextChunk,
ToolCallChunk,
)

self._response_chunks.append(chunk)
if isinstance(chunk, TextChunk):
self._text_chunks.append(chunk)
@@ -189,12 +179,8 @@ def add_response_chunk(self, chunk: Union[TextChunk, ToolCallChunk]) -> None:
assert_never(chunk)

@property
def finished_chat_completion(self) -> FinishedChatCompletion:
return self._finished_chat_completion

@property
def project_id(self) -> int:
return self._project_id
def error_message(self) -> Optional[str]:
return self._status_message if self._status_code is StatusCode.ERROR else None


def _llm_span_kind() -> Iterator[tuple[str, Any]]:
@@ -214,7 +200,7 @@ def _llm_tools(tools: list[JSONScalarType]) -> Iterator[tuple[str, Any]]:
yield f"{LLM_TOOLS}.{tool_index}.{TOOL_JSON_SCHEMA}", json.dumps(tool)


def _input_value_and_mime_type(input: ChatCompletionInput) -> Iterator[tuple[str, Any]]:
def _input_value_and_mime_type(input: "ChatCompletionInput") -> Iterator[tuple[str, Any]]:
assert (api_key := "api_key") in (input_data := jsonify(input))
disallowed_keys = {"api_key", "invocation_parameters"}
input_data = {k: v for k, v in input_data.items() if k not in disallowed_keys}
@@ -230,7 +216,7 @@ def _output_value_and_mime_type(output: Any) -> Iterator[tuple[str, Any]]:

def _llm_input_messages(
messages: Iterable[
tuple[ChatCompletionMessageRole, str, Optional[str], Optional[list[JSONScalarType]]]
tuple["ChatCompletionMessageRole", str, Optional[str], Optional[list[JSONScalarType]]]
],
) -> Iterator[tuple[str, Any]]:
for i, (role, content, _tool_call_id, tool_calls) in enumerate(messages):
@@ -250,8 +236,8 @@ def _llm_input_messages(


def _llm_output_messages(
text_chunks: list[TextChunk],
tool_call_chunks: defaultdict[ToolCallID, list[ToolCallChunk]],
text_chunks: list["TextChunk"],
tool_call_chunks: defaultdict[ToolCallID, list["ToolCallChunk"]],
) -> Iterator[tuple[str, Any]]:
yield f"{LLM_OUTPUT_MESSAGES}.0.{MESSAGE_ROLE}", "assistant"
if content := "".join(chunk.content for chunk in text_chunks):
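Net effect of the changes to this file: streaming_llm_span no longer persists anything itself. The db factory argument is gone, __aexit__ only records timing, status, and events (still returning True, which suppresses the in-flight exception so an errored completion can be recorded rather than propagated), and the new add_to_session method hands the finished span to whatever AsyncSession and project id the caller supplies. A rough sketch of the resulting call pattern, condensed from the subscription change below (names follow the diff; error handling and surrounding plumbing are abbreviated):

    async with streaming_llm_span(
        input=input,
        messages=messages,
        invocation_parameters=invocation_parameters,
        attributes=attributes,  # note: no `db` argument any more
    ) as span:
        async for chunk in llm_client.chat_completion_create(
            messages=messages, tools=input.tools or [], **invocation_parameters
        ):
            span.add_response_chunk(chunk)

    # Persistence now happens in the caller, outside the context manager:
    async with db() as session:  # `db` here is the caller's DbSessionFactory
        db_span = span.add_to_session(session, project_id)
        await session.flush()  # flush so the new rows receive primary keys
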
28 changes: 25 additions & 3 deletions src/phoenix/server/api/subscriptions.py
@@ -5,9 +5,11 @@
)

import strawberry
from sqlalchemy import insert, select
from strawberry.types import Info
from typing_extensions import TypeAlias, assert_never

from phoenix.db import models
from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import BadRequest
from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
@@ -20,7 +22,9 @@
from phoenix.server.api.types.ChatCompletionMessageRole import ChatCompletionMessageRole
from phoenix.server.api.types.ChatCompletionSubscriptionPayload import (
ChatCompletionSubscriptionPayload,
FinishedChatCompletion,
)
from phoenix.server.api.types.Span import to_gql_span
from phoenix.server.api.types.TemplateLanguage import TemplateLanguage
from phoenix.server.dml_event import SpanInsertEvent
from phoenix.utilities.template_formatters import (
@@ -34,6 +38,7 @@
ChatCompletionMessage: TypeAlias = tuple[
ChatCompletionMessageRole, str, Optional[str], Optional[list[str]]
]
PLAYGROUND_PROJECT_NAME = "playground"


@strawberry.type
@@ -71,16 +76,33 @@ async def chat_completion(
input=input,
messages=messages,
invocation_parameters=invocation_parameters,
db=info.context.db,
attributes=attributes,
) as span:
async for chunk in llm_client.chat_completion_create(
messages=messages, tools=input.tools or [], **invocation_parameters
):
span.add_response_chunk(chunk)
yield chunk
yield span.finished_chat_completion
info.context.event_queue.put(SpanInsertEvent(ids=(span.project_id,)))
async with info.context.db() as session:
if (
playground_project_id := await session.scalar(
select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
)
) is None:
playground_project_id = await session.scalar(
insert(models.Project)
.returning(models.Project.id)
.values(
name=PLAYGROUND_PROJECT_NAME,
description="Traces from prompt playground",
)
)
db_span = span.add_to_session(session, playground_project_id)
await session.flush()
yield FinishedChatCompletion(
span=to_gql_span(db_span), error_message=span.error_message
)
info.context.event_queue.put(SpanInsertEvent(ids=(playground_project_id,)))


def _formatted_messages(
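The get-or-create lookup for the "playground" project, which previously ran inside __aexit__, now lives in the subscription. Factored out on its own it looks roughly like this (the helper name is hypothetical; the query and inserted values are taken verbatim from the diff, and a race between two concurrent first requests is not handled):

    from sqlalchemy import insert, select
    from sqlalchemy.ext.asyncio import AsyncSession

    from phoenix.db import models

    PLAYGROUND_PROJECT_NAME = "playground"

    async def get_or_create_playground_project_id(session: AsyncSession) -> int:
        # Try the well-known project name first.
        project_id = await session.scalar(
            select(models.Project.id).where(models.Project.name == PLAYGROUND_PROJECT_NAME)
        )
        if project_id is None:
            # First playground trace: create the project and read back its id.
            project_id = await session.scalar(
                insert(models.Project)
                .returning(models.Project.id)
                .values(
                    name=PLAYGROUND_PROJECT_NAME,
                    description="Traces from prompt playground",
                )
            )
        return project_id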
