diff --git a/runtime/prompty/prompty/serverless/executor.py b/runtime/prompty/prompty/serverless/executor.py index c912490..58a8dc6 100644 --- a/runtime/prompty/prompty/serverless/executor.py +++ b/runtime/prompty/prompty/serverless/executor.py @@ -1,3 +1,4 @@ +import azure.identity import importlib.metadata from typing import Iterator from azure.core.credentials import AzureKeyCredential @@ -5,6 +6,11 @@ ChatCompletionsClient, EmbeddingsClient, ) + +from azure.ai.inference.aio import ( + ChatCompletionsClient as AsyncChatCompletionsClient, + EmbeddingsClient as AsyncEmbeddingsClient, +) from azure.ai.inference.models import ( StreamingChatCompletions, AsyncStreamingChatCompletions, @@ -24,10 +30,18 @@ class ServerlessExecutor(Invoker): def __init__(self, prompty: Prompty) -> None: super().__init__(prompty) - # serverless configuration self.endpoint = self.prompty.model.configuration["endpoint"] self.model = self.prompty.model.configuration["model"] - self.key = self.prompty.model.configuration["key"] + + # no key, use default credentials + if "key" not in self.kwargs: + self.credential = azure.identity.DefaultAzureCredential( + exclude_shared_token_cache_credential=True + ) + else: + self.credential = AzureKeyCredential( + self.prompty.model.configuration["key"] + ) # api type self.api = self.prompty.model.api @@ -64,7 +78,7 @@ def invoke(self, data: any) -> any: cargs = { "endpoint": self.endpoint, - "credential": AzureKeyCredential(self.key), + "credential": self.credential, } if self.api == "chat": @@ -150,4 +164,77 @@ async def invoke_async(self, data: str) -> str: str The parsed data """ - return self.invoke(data) + cargs = { + "endpoint": self.endpoint, + "credential": self.credential, + } + + if self.api == "chat": + with Tracer.start("ChatCompletionsClient") as trace: + trace("type", "LLM") + trace("signature", "azure.ai.inference.aio.ChatCompletionsClient.ctor") + trace( + "description", "Azure Unified Inference SDK Async Chat Completions Client" + ) + trace("inputs", cargs) + client = AsyncChatCompletionsClient( + user_agent=f"prompty/{VERSION}", + **cargs, + ) + trace("result", client) + + with Tracer.start("complete") as trace: + trace("type", "LLM") + trace("signature", "azure.ai.inference.ChatCompletionsClient.complete") + trace( + "description", "Azure Unified Inference SDK Async Chat Completions Client" + ) + eargs = { + "model": self.model, + "messages": data if isinstance(data, list) else [data], + **self.prompty.model.parameters, + } + trace("inputs", eargs) + r = await client.complete(**eargs) + trace("result", r) + + response = self._response(r) + + elif self.api == "completion": + raise NotImplementedError( + "Serverless Completions API is not implemented yet" + ) + + elif self.api == "embedding": + with Tracer.start("EmbeddingsClient") as trace: + trace("type", "LLM") + trace("signature", "azure.ai.inference.aio.EmbeddingsClient.ctor") + trace("description", "Azure Unified Inference SDK Async Embeddings Client") + trace("inputs", cargs) + client = AsyncEmbeddingsClient( + user_agent=f"prompty/{VERSION}", + **cargs, + ) + trace("result", client) + + with Tracer.start("complete") as trace: + trace("type", "LLM") + trace("signature", "azure.ai.inference.ChatCompletionsClient.complete") + trace( + "description", "Azure Unified Inference SDK Chat Completions Client" + ) + eargs = { + "model": self.model, + "input": data if isinstance(data, list) else [data], + **self.prompty.model.parameters, + } + trace("inputs", eargs) + r = await client.complete(**eargs) + trace("result", r) + + response = self._response(r) + + elif self.api == "image": + raise NotImplementedError("Azure OpenAI Image API is not implemented yet") + + return response diff --git a/runtime/prompty/prompty/serverless/processor.py b/runtime/prompty/prompty/serverless/processor.py index 98e1070..32c0b31 100644 --- a/runtime/prompty/prompty/serverless/processor.py +++ b/runtime/prompty/prompty/serverless/processor.py @@ -1,6 +1,6 @@ -from typing import Iterator +from typing import AsyncIterator, Iterator from ..invoker import Invoker, InvokerFactory -from ..core import Prompty, PromptyStream, ToolCall +from ..core import AsyncPromptyStream, Prompty, PromptyStream, ToolCall from azure.ai.inference.models import ChatCompletions, EmbeddingsResult @@ -75,4 +75,39 @@ async def invoke_async(self, data: str) -> str: str The parsed data """ - return self.invoke(data) + if isinstance(data, ChatCompletions): + response = data.choices[0].message + # tool calls available in response + if response.tool_calls: + return [ + ToolCall( + id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ) + for tool_call in response.tool_calls + ] + else: + return response.content + + elif isinstance(data, EmbeddingsResult): + if len(data.data) == 0: + raise ValueError("Invalid data") + elif len(data.data) == 1: + return data.data[0].embedding + else: + return [item.embedding for item in data.data] + elif isinstance(data, AsyncIterator): + + async def generator(): + async for chunk in data: + if ( + len(chunk.choices) == 1 + and chunk.choices[0].delta.content != None + ): + content = chunk.choices[0].delta.content + yield content + + return AsyncPromptyStream("ServerlessProcessor", generator()) + else: + return data diff --git a/runtime/prompty/prompty/tracer.py b/runtime/prompty/prompty/tracer.py index 417cd24..8653d34 100644 --- a/runtime/prompty/prompty/tracer.py +++ b/runtime/prompty/prompty/tracer.py @@ -93,7 +93,9 @@ def _name(func: Callable, args): if core_invoker: name = type(args[0]).__name__ if signature.endswith("async"): - signature = f"{args[0].__module__}.{args[0].__class__.__name__}.invoke_async" + signature = ( + f"{args[0].__module__}.{args[0].__class__.__name__}.invoke_async" + ) else: signature = f"{args[0].__module__}.{args[0].__class__.__name__}.invoke" else: @@ -116,20 +118,19 @@ def _results(result: Any) -> dict: def _trace_sync( - func: Callable = None, *, description: str = None, itemtype: str = None + func: Callable = None, **okwargs: Any ) -> Callable: - description = description or "" @wraps(func) def wrapper(*args, **kwargs): name, signature = _name(func, args) with Tracer.start(name) as trace: trace("signature", signature) - if description and description != "": - trace("description", description) - if itemtype and itemtype != "": - trace("type", itemtype) + # support arbitrary keyword + # arguments for trace decorator + for k, v in okwargs.items(): + trace(k, to_dict(v)) inputs = _inputs(func, args, kwargs) trace("inputs", inputs) @@ -161,20 +162,19 @@ def wrapper(*args, **kwargs): def _trace_async( - func: Callable = None, *, description: str = None, itemtype: str = None + func: Callable = None, **okwargs: Any ) -> Callable: - description = description or "" @wraps(func) async def wrapper(*args, **kwargs): name, signature = _name(func, args) with Tracer.start(name) as trace: trace("signature", signature) - if description and description != "": - trace("description", description) - if itemtype and itemtype != "": - trace("type", itemtype) + # support arbitrary keyword + # arguments for trace decorator + for k, v in okwargs.items(): + trace(k, to_dict(v)) inputs = _inputs(func, args, kwargs) trace("inputs", inputs) @@ -204,15 +204,11 @@ async def wrapper(*args, **kwargs): return wrapper -def trace( - func: Callable = None, *, description: str = None, itemtype: str = None -) -> Callable: +def trace(func: Callable = None, **kwargs: Any) -> Callable: if func is None: - return partial(trace, description=description, itemtype=itemtype) - + return partial(trace, **kwargs) wrapped_method = _trace_async if inspect.iscoroutinefunction(func) else _trace_sync - - return wrapped_method(func, description=description, itemtype=itemtype) + return wrapped_method(func, **kwargs) class PromptyTracer: diff --git a/runtime/prompty/pyproject.toml b/runtime/prompty/pyproject.toml index 7736e64..a5a8529 100644 --- a/runtime/prompty/pyproject.toml +++ b/runtime/prompty/pyproject.toml @@ -19,7 +19,7 @@ dependencies = [ [project.optional-dependencies] azure = ["azure-identity>=1.17.1","openai>=1.35.10"] openai = ["openai>=1.35.10"] -serverless = ["azure-ai-inference>=1.0.0b3"] +serverless = ["azure-identity>=1.17.1","azure-ai-inference>=1.0.0b3"] [tool.pdm] diff --git a/runtime/prompty/tests/test_tracing.py b/runtime/prompty/tests/test_tracing.py index c89c596..8749568 100644 --- a/runtime/prompty/tests/test_tracing.py +++ b/runtime/prompty/tests/test_tracing.py @@ -241,7 +241,7 @@ async def test_function_calling_async(): # need to add trace attribute to # materialize stream into the function # trace decorator -@trace +@trace(streaming=True, other="test") def test_streaming(): result = prompty.execute( "prompts/streaming.prompty", @@ -254,7 +254,7 @@ def test_streaming(): @pytest.mark.asyncio -@trace +@trace(streaming=True) async def test_streaming_async(): result = await prompty.execute_async( "prompts/streaming.prompty",