Commit 0aa6b2a

Merge pull request #117 from microsoft/python

serverless + tracer

sethjuarez authored Oct 29, 2024
2 parents 3f3ab2d + 5c4a5b3 commit 0aa6b2a
Showing 5 changed files with 148 additions and 30 deletions.
95 changes: 91 additions & 4 deletions runtime/prompty/prompty/serverless/executor.py
@@ -1,10 +1,16 @@
+import azure.identity
 import importlib.metadata
 from typing import Iterator
 from azure.core.credentials import AzureKeyCredential
 from azure.ai.inference import (
     ChatCompletionsClient,
     EmbeddingsClient,
 )
+
+from azure.ai.inference.aio import (
+    ChatCompletionsClient as AsyncChatCompletionsClient,
+    EmbeddingsClient as AsyncEmbeddingsClient,
+)
 from azure.ai.inference.models import (
     StreamingChatCompletions,
     AsyncStreamingChatCompletions,
@@ -24,10 +30,18 @@ class ServerlessExecutor(Invoker):
     def __init__(self, prompty: Prompty) -> None:
         super().__init__(prompty)

         # serverless configuration
         self.endpoint = self.prompty.model.configuration["endpoint"]
         self.model = self.prompty.model.configuration["model"]
-        self.key = self.prompty.model.configuration["key"]
+
+        # no key, use default credentials
+        if "key" not in self.kwargs:
+            self.credential = azure.identity.DefaultAzureCredential(
+                exclude_shared_token_cache_credential=True
+            )
+        else:
+            self.credential = AzureKeyCredential(
+                self.prompty.model.configuration["key"]
+            )

         # api type
         self.api = self.prompty.model.api
@@ -64,7 +78,7 @@ def invoke(self, data: any) -> any:

         cargs = {
             "endpoint": self.endpoint,
-            "credential": AzureKeyCredential(self.key),
+            "credential": self.credential,
         }

         if self.api == "chat":
@@ -150,4 +164,77 @@ async def invoke_async(self, data: str) -> str:
         str
             The parsed data
         """
-        return self.invoke(data)
+        cargs = {
+            "endpoint": self.endpoint,
+            "credential": self.credential,
+        }
+
+        if self.api == "chat":
+            with Tracer.start("ChatCompletionsClient") as trace:
+                trace("type", "LLM")
+                trace("signature", "azure.ai.inference.aio.ChatCompletionsClient.ctor")
+                trace(
+                    "description", "Azure Unified Inference SDK Async Chat Completions Client"
+                )
+                trace("inputs", cargs)
+                client = AsyncChatCompletionsClient(
+                    user_agent=f"prompty/{VERSION}",
+                    **cargs,
+                )
+                trace("result", client)
+
+            with Tracer.start("complete") as trace:
+                trace("type", "LLM")
+                trace("signature", "azure.ai.inference.ChatCompletionsClient.complete")
+                trace(
+                    "description", "Azure Unified Inference SDK Async Chat Completions Client"
+                )
+                eargs = {
+                    "model": self.model,
+                    "messages": data if isinstance(data, list) else [data],
+                    **self.prompty.model.parameters,
+                }
+                trace("inputs", eargs)
+                r = await client.complete(**eargs)
+                trace("result", r)
+
+            response = self._response(r)
+
+        elif self.api == "completion":
+            raise NotImplementedError(
+                "Serverless Completions API is not implemented yet"
+            )
+
+        elif self.api == "embedding":
+            with Tracer.start("EmbeddingsClient") as trace:
+                trace("type", "LLM")
+                trace("signature", "azure.ai.inference.aio.EmbeddingsClient.ctor")
+                trace("description", "Azure Unified Inference SDK Async Embeddings Client")
+                trace("inputs", cargs)
+                client = AsyncEmbeddingsClient(
+                    user_agent=f"prompty/{VERSION}",
+                    **cargs,
+                )
+                trace("result", client)
+
+            with Tracer.start("complete") as trace:
+                trace("type", "LLM")
+                trace("signature", "azure.ai.inference.ChatCompletionsClient.complete")
+                trace(
+                    "description", "Azure Unified Inference SDK Chat Completions Client"
+                )
+                eargs = {
+                    "model": self.model,
+                    "input": data if isinstance(data, list) else [data],
+                    **self.prompty.model.parameters,
+                }
+                trace("inputs", eargs)
+                r = await client.complete(**eargs)
+                trace("result", r)
+
+            response = self._response(r)
+
+        elif self.api == "image":
+            raise NotImplementedError("Azure OpenAI Image API is not implemented yet")
+
+        return response
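
Note on the change above: the executor now selects its credential once in __init__, so with no "key" in the invoker's kwargs it signs requests with Entra ID instead of an API key. A minimal standalone sketch of that fallback, assuming plain dicts in place of the invoker's kwargs and the prompty model configuration (pick_credential is illustrative, not prompty API):

import azure.identity
from azure.core.credentials import AzureKeyCredential

def pick_credential(kwargs: dict, configuration: dict):
    # No explicit key: fall back to DefaultAzureCredential (az login,
    # managed identity, etc.), excluding the shared token cache as above.
    if "key" not in kwargs:
        return azure.identity.DefaultAzureCredential(
            exclude_shared_token_cache_credential=True
        )
    # Key supplied: wrap it for the azure.ai.inference clients.
    return AzureKeyCredential(configuration["key"])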
41 changes: 38 additions & 3 deletions runtime/prompty/prompty/serverless/processor.py
@@ -1,6 +1,6 @@
-from typing import Iterator
+from typing import AsyncIterator, Iterator
 from ..invoker import Invoker, InvokerFactory
-from ..core import Prompty, PromptyStream, ToolCall
+from ..core import AsyncPromptyStream, Prompty, PromptyStream, ToolCall

 from azure.ai.inference.models import ChatCompletions, EmbeddingsResult

@@ -75,4 +75,39 @@ async def invoke_async(self, data: str) -> str:
         str
             The parsed data
         """
-        return self.invoke(data)
+        if isinstance(data, ChatCompletions):
+            response = data.choices[0].message
+            # tool calls available in response
+            if response.tool_calls:
+                return [
+                    ToolCall(
+                        id=tool_call.id,
+                        name=tool_call.function.name,
+                        arguments=tool_call.function.arguments,
+                    )
+                    for tool_call in response.tool_calls
+                ]
+            else:
+                return response.content
+
+        elif isinstance(data, EmbeddingsResult):
+            if len(data.data) == 0:
+                raise ValueError("Invalid data")
+            elif len(data.data) == 1:
+                return data.data[0].embedding
+            else:
+                return [item.embedding for item in data.data]
+        elif isinstance(data, AsyncIterator):
+
+            async def generator():
+                async for chunk in data:
+                    if (
+                        len(chunk.choices) == 1
+                        and chunk.choices[0].delta.content != None
+                    ):
+                        content = chunk.choices[0].delta.content
+                        yield content
+
+            return AsyncPromptyStream("ServerlessProcessor", generator())
+        else:
+            return data
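
With these branches, invoke_async can return a plain string, a list of ToolCall objects, an embedding vector, or an AsyncPromptyStream. A hedged consumption sketch (the .prompty path is illustrative; prompty.execute_async is the same entry point the updated tests use):

import asyncio
import prompty

async def main():
    # For a streaming model configuration the processor wraps the chunk
    # generator in AsyncPromptyStream, which iterates asynchronously
    # over the text chunks.
    result = await prompty.execute_async("prompts/streaming.prompty")
    async for chunk in result:
        print(chunk, end="")

asyncio.run(main())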
36 changes: 16 additions & 20 deletions runtime/prompty/prompty/tracer.py
@@ -93,7 +93,9 @@ def _name(func: Callable, args):
     if core_invoker:
         name = type(args[0]).__name__
         if signature.endswith("async"):
-            signature = f"{args[0].__module__}.{args[0].__class__.__name__}.invoke_async"
+            signature = (
+                f"{args[0].__module__}.{args[0].__class__.__name__}.invoke_async"
+            )
         else:
             signature = f"{args[0].__module__}.{args[0].__class__.__name__}.invoke"
     else:
@@ -116,20 +118,19 @@ def _results(result: Any) -> dict:


 def _trace_sync(
-    func: Callable = None, *, description: str = None, itemtype: str = None
+    func: Callable = None, **okwargs: Any
 ) -> Callable:
-    description = description or ""

     @wraps(func)
     def wrapper(*args, **kwargs):
         name, signature = _name(func, args)
         with Tracer.start(name) as trace:
             trace("signature", signature)
-            if description and description != "":
-                trace("description", description)

-            if itemtype and itemtype != "":
-                trace("type", itemtype)
+            # support arbitrary keyword
+            # arguments for trace decorator
+            for k, v in okwargs.items():
+                trace(k, to_dict(v))

             inputs = _inputs(func, args, kwargs)
             trace("inputs", inputs)
@@ -161,20 +162,19 @@ def wrapper(*args, **kwargs):


 def _trace_async(
-    func: Callable = None, *, description: str = None, itemtype: str = None
+    func: Callable = None, **okwargs: Any
 ) -> Callable:
-    description = description or ""

     @wraps(func)
     async def wrapper(*args, **kwargs):
         name, signature = _name(func, args)
         with Tracer.start(name) as trace:
             trace("signature", signature)
-            if description and description != "":
-                trace("description", description)

-            if itemtype and itemtype != "":
-                trace("type", itemtype)
+            # support arbitrary keyword
+            # arguments for trace decorator
+            for k, v in okwargs.items():
+                trace(k, to_dict(v))

             inputs = _inputs(func, args, kwargs)
             trace("inputs", inputs)
@@ -204,15 +204,11 @@ async def wrapper(*args, **kwargs):
     return wrapper


-def trace(
-    func: Callable = None, *, description: str = None, itemtype: str = None
-) -> Callable:
+def trace(func: Callable = None, **kwargs: Any) -> Callable:
     if func is None:
-        return partial(trace, description=description, itemtype=itemtype)
-
+        return partial(trace, **kwargs)
     wrapped_method = _trace_async if inspect.iscoroutinefunction(func) else _trace_sync
-
-    return wrapped_method(func, description=description, itemtype=itemtype)
+    return wrapped_method(func, **kwargs)


 class PromptyTracer:
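
The decorator change above means any keyword passed to @trace is recorded on the trace frame via trace(k, to_dict(v)); description and itemtype stop being special-cased parameters. A small sketch of the new call shape (import path inferred from runtime/prompty/prompty/tracer.py; the attribute names are free-form examples):

from prompty.tracer import trace

@trace(description="summarize a document", type="helper", run="demo")
def summarize(text: str) -> str:
    # Each keyword above lands in the trace output as its own attribute.
    return text[:100]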
2 changes: 1 addition & 1 deletion runtime/prompty/pyproject.toml
@@ -19,7 +19,7 @@ dependencies = [
 [project.optional-dependencies]
 azure = ["azure-identity>=1.17.1","openai>=1.35.10"]
 openai = ["openai>=1.35.10"]
-serverless = ["azure-ai-inference>=1.0.0b3"]
+serverless = ["azure-identity>=1.17.1","azure-ai-inference>=1.0.0b3"]


 [tool.pdm]
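
This keeps the serverless extra self-contained: executor.py now imports azure.identity unconditionally, so pip install "prompty[serverless]" must pull azure-identity even for users who never install the azure extra.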
4 changes: 2 additions & 2 deletions runtime/prompty/tests/test_tracing.py
@@ -241,7 +241,7 @@ async def test_function_calling_async():
 # need to add trace attribute to
 # materialize stream into the function
 # trace decorator
-@trace
+@trace(streaming=True, other="test")
 def test_streaming():
     result = prompty.execute(
         "prompts/streaming.prompty",
@@ -254,7 +254,7 @@ def test_streaming():


 @pytest.mark.asyncio
-@trace
+@trace(streaming=True)
 async def test_streaming_async():
     result = await prompty.execute_async(
         "prompts/streaming.prompty",
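
The updated decorators double as coverage for the new keyword path: streaming=True (and other="test") should surface as attributes on the emitted trace, while the trace decorator still materializes the stream before the test returns.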
