Skip to content

Commit

Permalink
added credential handling for serverless as well as proper async execution
Browse files Browse the repository at this point in the history
  • Loading branch information
sethjuarez committed Oct 29, 2024
1 parent 3f3ab2d commit 5c4a5b3
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 30 deletions.
95 changes: 91 additions & 4 deletions runtime/prompty/prompty/serverless/executor.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
import azure.identity
import importlib.metadata
from typing import Iterator
from azure.core.credentials import AzureKeyCredential
from azure.ai.inference import (
ChatCompletionsClient,
EmbeddingsClient,
)

from azure.ai.inference.aio import (
ChatCompletionsClient as AsyncChatCompletionsClient,
EmbeddingsClient as AsyncEmbeddingsClient,
)
from azure.ai.inference.models import (
StreamingChatCompletions,
AsyncStreamingChatCompletions,
Expand All @@ -24,10 +30,18 @@ class ServerlessExecutor(Invoker):
def __init__(self, prompty: Prompty) -> None:
super().__init__(prompty)

# serverless configuration
self.endpoint = self.prompty.model.configuration["endpoint"]
self.model = self.prompty.model.configuration["model"]
self.key = self.prompty.model.configuration["key"]

# no key, use default credentials
if "key" not in self.kwargs:
self.credential = azure.identity.DefaultAzureCredential(
exclude_shared_token_cache_credential=True
)
else:
self.credential = AzureKeyCredential(
self.prompty.model.configuration["key"]
)

# api type
self.api = self.prompty.model.api
Expand Down Expand Up @@ -64,7 +78,7 @@ def invoke(self, data: any) -> any:

cargs = {
"endpoint": self.endpoint,
"credential": AzureKeyCredential(self.key),
"credential": self.credential,
}

if self.api == "chat":
async def invoke_async(self, data: str) -> str:
    """Invoke the serverless inference endpoint asynchronously.

    Parameters
    ----------
    data : str
        The data to send to the model (a message, or a list of messages
        for chat; input text(s) for embeddings)

    Returns
    -------
    str
        The parsed model response

    Raises
    ------
    NotImplementedError
        If ``self.api`` is "completion", "image", or an unrecognized value.
    """
    # common client constructor arguments; self.credential is either an
    # AzureKeyCredential or a DefaultAzureCredential, resolved in __init__
    cargs = {
        "endpoint": self.endpoint,
        "credential": self.credential,
    }

    if self.api == "chat":
        with Tracer.start("ChatCompletionsClient") as trace:
            trace("type", "LLM")
            trace("signature", "azure.ai.inference.aio.ChatCompletionsClient.ctor")
            trace(
                "description",
                "Azure Unified Inference SDK Async Chat Completions Client",
            )
            trace("inputs", cargs)
            client = AsyncChatCompletionsClient(
                user_agent=f"prompty/{VERSION}",
                **cargs,
            )
            trace("result", client)

        with Tracer.start("complete") as trace:
            trace("type", "LLM")
            # fix: trace the async client's signature, not the sync one
            trace(
                "signature", "azure.ai.inference.aio.ChatCompletionsClient.complete"
            )
            trace(
                "description",
                "Azure Unified Inference SDK Async Chat Completions Client",
            )
            eargs = {
                "model": self.model,
                # the client expects a list of messages
                "messages": data if isinstance(data, list) else [data],
                **self.prompty.model.parameters,
            }
            trace("inputs", eargs)
            r = await client.complete(**eargs)
            trace("result", r)

        response = self._response(r)

    elif self.api == "completion":
        raise NotImplementedError(
            "Serverless Completions API is not implemented yet"
        )

    elif self.api == "embedding":
        with Tracer.start("EmbeddingsClient") as trace:
            trace("type", "LLM")
            trace("signature", "azure.ai.inference.aio.EmbeddingsClient.ctor")
            trace("description", "Azure Unified Inference SDK Async Embeddings Client")
            trace("inputs", cargs)
            client = AsyncEmbeddingsClient(
                user_agent=f"prompty/{VERSION}",
                **cargs,
            )
            trace("result", client)

        with Tracer.start("embed") as trace:
            trace("type", "LLM")
            # fix: EmbeddingsClient exposes embed(), not complete(); the
            # previous trace metadata was copy-pasted from the chat path
            trace("signature", "azure.ai.inference.aio.EmbeddingsClient.embed")
            trace(
                "description",
                "Azure Unified Inference SDK Async Embeddings Client",
            )
            eargs = {
                "model": self.model,
                "input": data if isinstance(data, list) else [data],
                **self.prompty.model.parameters,
            }
            trace("inputs", eargs)
            r = await client.embed(**eargs)
            trace("result", r)

        response = self._response(r)

    elif self.api == "image":
        raise NotImplementedError("Azure OpenAI Image API is not implemented yet")

    else:
        # fix: an unrecognized api previously fell through and raised
        # UnboundLocalError on the return below
        raise NotImplementedError(f"Serverless API '{self.api}' is not supported")

    return response
41 changes: 38 additions & 3 deletions runtime/prompty/prompty/serverless/processor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Iterator
from typing import AsyncIterator, Iterator
from ..invoker import Invoker, InvokerFactory
from ..core import Prompty, PromptyStream, ToolCall
from ..core import AsyncPromptyStream, Prompty, PromptyStream, ToolCall

from azure.ai.inference.models import ChatCompletions, EmbeddingsResult

async def invoke_async(self, data: str) -> str:
    """Parse a serverless inference response asynchronously.

    Parameters
    ----------
    data : str
        The raw response to parse: a ChatCompletions result, an
        EmbeddingsResult, or an async stream of chat chunks

    Returns
    -------
    str
        The parsed data: message content, a list of ToolCall objects,
        embedding vector(s), or an AsyncPromptyStream of content chunks

    Raises
    ------
    ValueError
        If an EmbeddingsResult contains no data items.
    """
    if isinstance(data, ChatCompletions):
        response = data.choices[0].message
        # tool calls available in response
        if response.tool_calls:
            return [
                ToolCall(
                    id=tool_call.id,
                    name=tool_call.function.name,
                    arguments=tool_call.function.arguments,
                )
                for tool_call in response.tool_calls
            ]
        else:
            return response.content

    elif isinstance(data, EmbeddingsResult):
        if len(data.data) == 0:
            raise ValueError("Invalid data")
        elif len(data.data) == 1:
            # single input: return the bare embedding vector
            return data.data[0].embedding
        else:
            return [item.embedding for item in data.data]
    elif isinstance(data, AsyncIterator):

        async def generator():
            # re-yield only chunks that carry content; fix: compare with
            # `is not None` (PEP 8 E711) rather than `!= None`
            async for chunk in data:
                if (
                    len(chunk.choices) == 1
                    and chunk.choices[0].delta.content is not None
                ):
                    yield chunk.choices[0].delta.content

        return AsyncPromptyStream("ServerlessProcessor", generator())
    else:
        # unknown payloads pass through unchanged
        return data
36 changes: 16 additions & 20 deletions runtime/prompty/prompty/tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ def _name(func: Callable, args):
if core_invoker:
name = type(args[0]).__name__
if signature.endswith("async"):
signature = f"{args[0].__module__}.{args[0].__class__.__name__}.invoke_async"
signature = (
f"{args[0].__module__}.{args[0].__class__.__name__}.invoke_async"
)
else:
signature = f"{args[0].__module__}.{args[0].__class__.__name__}.invoke"
else:
Expand All @@ -116,20 +118,19 @@ def _results(result: Any) -> dict:


def _trace_sync(
func: Callable = None, *, description: str = None, itemtype: str = None
func: Callable = None, **okwargs: Any
) -> Callable:
description = description or ""

@wraps(func)
def wrapper(*args, **kwargs):
name, signature = _name(func, args)
with Tracer.start(name) as trace:
trace("signature", signature)
if description and description != "":
trace("description", description)

if itemtype and itemtype != "":
trace("type", itemtype)
# support arbitrary keyword
# arguments for trace decorator
for k, v in okwargs.items():
trace(k, to_dict(v))

inputs = _inputs(func, args, kwargs)
trace("inputs", inputs)
Expand Down Expand Up @@ -161,20 +162,19 @@ def wrapper(*args, **kwargs):


def _trace_async(
func: Callable = None, *, description: str = None, itemtype: str = None
func: Callable = None, **okwargs: Any
) -> Callable:
description = description or ""

@wraps(func)
async def wrapper(*args, **kwargs):
name, signature = _name(func, args)
with Tracer.start(name) as trace:
trace("signature", signature)
if description and description != "":
trace("description", description)

if itemtype and itemtype != "":
trace("type", itemtype)
# support arbitrary keyword
# arguments for trace decorator
for k, v in okwargs.items():
trace(k, to_dict(v))

inputs = _inputs(func, args, kwargs)
trace("inputs", inputs)
Expand Down Expand Up @@ -204,15 +204,11 @@ async def wrapper(*args, **kwargs):
return wrapper


def trace(func: Callable = None, **kwargs: Any) -> Callable:
    """Decorator that records a function call with the active tracer.

    Supports both bare use (``@trace``) and parameterized use
    (``@trace(description="...", itemtype="...")``); any keyword
    arguments are forwarded to the tracer as extra trace attributes.

    Parameters
    ----------
    func : Callable, optional
        The function to wrap; None when used with decorator arguments.

    Returns
    -------
    Callable
        The traced wrapper (async-aware), or a partially-applied
        decorator when called with only keyword arguments.
    """
    if func is None:
        # called as @trace(...): return a decorator awaiting the function
        return partial(trace, **kwargs)

    # pick the wrapper matching the function's sync/async nature
    wrapped_method = _trace_async if inspect.iscoroutinefunction(func) else _trace_sync
    return wrapped_method(func, **kwargs)


class PromptyTracer:
Expand Down
2 changes: 1 addition & 1 deletion runtime/prompty/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ dependencies = [
[project.optional-dependencies]
azure = ["azure-identity>=1.17.1","openai>=1.35.10"]
openai = ["openai>=1.35.10"]
serverless = ["azure-ai-inference>=1.0.0b3"]
serverless = ["azure-identity>=1.17.1","azure-ai-inference>=1.0.0b3"]


[tool.pdm]
Expand Down
4 changes: 2 additions & 2 deletions runtime/prompty/tests/test_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ async def test_function_calling_async():
# need to add trace attribute to
# materialize stream into the function
# trace decorator
@trace
@trace(streaming=True, other="test")
def test_streaming():
result = prompty.execute(
"prompts/streaming.prompty",
Expand All @@ -254,7 +254,7 @@ def test_streaming():


@pytest.mark.asyncio
@trace
@trace(streaming=True)
async def test_streaming_async():
result = await prompty.execute_async(
"prompts/streaming.prompty",
Expand Down

0 comments on commit 5c4a5b3

Please sign in to comment.