Skip to content

Commit

Permalink
feat(groq): refactor groq, add groq tool call support (#1133)
Browse files Browse the repository at this point in the history
Co-authored-by: Roger Yang <[email protected]>
  • Loading branch information
cjunkin and RogerHYang authored Dec 20, 2024
1 parent f99d71a commit 6057418
Show file tree
Hide file tree
Showing 9 changed files with 697 additions and 127 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import asyncio
import os

from groq import AsyncGroq, Groq
from groq.types.chat import ChatCompletionToolMessageParam
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk import trace as trace_sdk
from opentelemetry.sdk.trace.export import SimpleSpanProcessor

from openinference.instrumentation.groq import GroqInstrumentor


def test():
client = Groq(
api_key=os.environ.get("GROQ_API_KEY"),
)

weather_function = {
"type": "function",
"function": {
"name": "get_weather",
"description": "finds the weather for a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for, e.g. 'London'",
}
},
"required": ["city"],
},
},
}

sys_prompt = "Respond to the user's query using the correct tool."
user_msg = "What's the weather like in San Francisco?"

messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_msg}]
response = client.chat.completions.create(
model="mixtral-8x7b-32768",
messages=messages,
temperature=0.0,
tools=[weather_function],
tool_choice="required",
)

message = response.choices[0].message
assert (tool_calls := message.tool_calls)
tool_call_id = tool_calls[0].id
messages.append(message)
messages.append(
ChatCompletionToolMessageParam(content="sunny", role="tool", tool_call_id=tool_call_id),
)
final_response = client.chat.completions.create(
model="mixtral-8x7b-32768",
messages=messages,
)
return final_response


async def async_test():
client = AsyncGroq(
api_key=os.environ.get("GROQ_API_KEY"),
)

weather_function = {
"type": "function",
"function": {
"name": "get_weather",
"description": "finds the weather for a given city",
"parameters": {
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The city to find the weather for, e.g. 'London'",
}
},
"required": ["city"],
},
},
}

sys_prompt = "Respond to the user's query using the correct tool."
user_msg = "What's the weather like in San Francisco?"

messages = [{"role": "system", "content": sys_prompt}, {"role": "user", "content": user_msg}]
response = await client.chat.completions.create(
model="mixtral-8x7b-32768",
messages=messages,
temperature=0.0,
tools=[weather_function],
tool_choice="required",
)

message = response.choices[0].message
assert (tool_calls := message.tool_calls)
tool_call_id = tool_calls[0].id
messages.append(message)
messages.append(
ChatCompletionToolMessageParam(content="sunny", role="tool", tool_call_id=tool_call_id),
)
final_response = await client.chat.completions.create(
model="mixtral-8x7b-32768",
messages=messages,
)
return final_response


if __name__ == "__main__":
endpoint = "http://0.0.0.0:6006/v1/traces"
tracer_provider = trace_sdk.TracerProvider()
tracer_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))

GroqInstrumentor().instrument(tracer_provider=tracer_provider)

response = test()
print("Response\n--------")
print(response)

async_response = asyncio.run(async_test())
print("\nAsync Response\n--------")
print(async_response)
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from typing import Any, Collection

from opentelemetry import trace as trace_api
from opentelemetry.instrumentation.instrumentor import ( # type: ignore[attr-defined]
BaseInstrumentor,
)
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor # type: ignore
from wrapt import wrap_function_wrapper

from groq.resources.chat.completions import AsyncCompletions, Completions
Expand All @@ -17,6 +15,7 @@
from openinference.instrumentation.groq.version import __version__

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

_instruments = ("groq >= 0.9.0",)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,6 @@
import logging
from typing import (
Any,
Iterable,
Iterator,
Mapping,
Tuple,
TypeVar,
)
from enum import Enum
from typing import Any, Iterable, Iterator, Mapping, Tuple, TypeVar

from opentelemetry.util.types import AttributeValue

Expand All @@ -16,6 +10,7 @@
MessageAttributes,
OpenInferenceSpanKindValues,
SpanAttributes,
ToolCallAttributes,
)

__all__ = ("_RequestAttributesExtractor",)
Expand Down Expand Up @@ -49,31 +44,88 @@ def get_extra_attributes_from_request(
if not isinstance(request_parameters, Mapping):
return
invocation_params = dict(request_parameters)
invocation_params.pop("messages", None)
invocation_params.pop("messages", None) # Remove LLM input messages
invocation_params.pop("functions", None)
invocation_params.pop("tools", None)

if isinstance((tools := invocation_params.pop("tools", None)), Iterable):
for i, tool in enumerate(tools):
yield f"llm.tools.{i}.tool.json_schema", safe_json_dumps(tool)

yield SpanAttributes.LLM_INVOCATION_PARAMETERS, safe_json_dumps(invocation_params)

if (input_messages := request_parameters.get("messages")) and isinstance(
input_messages, Iterable
):
# Use reversed() to get the last message first. This is because OTEL has a default
# limit of 128 attributes per span, and flattening increases the number of
# attributes very quickly.
for index, input_message in reversed(list(enumerate(input_messages))):
if role := input_message.get("role"):
yield (
f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{MessageAttributes.MESSAGE_ROLE}",
role,
)
if content := input_message.get("content"):
# Use reversed() to get the last message first. This is because OTEL has a default
# limit of 128 attributes per span, and flattening increases the number of
# attributes very quickly.
for key, value in self._get_attributes_from_message_param(input_message):
yield f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{key}", value

def _get_attributes_from_message_param(
self,
message: Mapping[str, Any],
) -> Iterator[Tuple[str, AttributeValue]]:
if role := get_attribute(message, "role"):
yield (
MessageAttributes.MESSAGE_ROLE,
role.value if isinstance(role, Enum) else role,
)
if content := get_attribute(message, "content"):
yield (
MessageAttributes.MESSAGE_CONTENT,
content,
)
if name := get_attribute(message, "name"):
yield MessageAttributes.MESSAGE_NAME, name

if tool_call_id := get_attribute(message, "tool_call_id"):
yield MessageAttributes.MESSAGE_TOOL_CALL_ID, tool_call_id

# Deprecated by Groq
if function_call := get_attribute(message, "function_call"):
if function_name := get_attribute(function_call, "name"):
yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, function_name
if function_arguments := get_attribute(function_call, "arguments"):
yield (
MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON,
function_arguments,
)

if (tool_calls := get_attribute(message, "tool_calls")) and isinstance(
tool_calls, Iterable
):
for index, tool_call in enumerate(tool_calls):
if (tool_call_id := get_attribute(tool_call, "id")) is not None:
yield (
f"{SpanAttributes.LLM_INPUT_MESSAGES}.{index}.{MessageAttributes.MESSAGE_CONTENT}",
content,
f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
f"{ToolCallAttributes.TOOL_CALL_ID}",
tool_call_id,
)
if function := get_attribute(tool_call, "function"):
if name := get_attribute(function, "name"):
yield (
f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}",
name,
)
if arguments := get_attribute(function, "arguments"):
yield (
f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
arguments,
)


T = TypeVar("T", bound=type)


def is_iterable_of(lst: Iterable[object], tp: T) -> bool:
return isinstance(lst, Iterable) and all(isinstance(x, tp) for x in lst)


def get_attribute(obj: Any, attr_name: str, default: Any = None) -> Any:
if isinstance(obj, dict):
return obj.get(attr_name, default)
return getattr(obj, attr_name, default)
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import logging
from typing import Any, Iterable, Iterator, Mapping, Tuple

from opentelemetry.util.types import AttributeValue

from openinference.instrumentation.groq._utils import _as_output_attributes, _io_value_and_type
from openinference.semconv.trace import MessageAttributes, SpanAttributes, ToolCallAttributes

__all__ = ("_ResponseAttributesExtractor",)

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


class _ResponseAttributesExtractor:
def get_attributes(self, response: Any) -> Iterator[Tuple[str, AttributeValue]]:
yield from _as_output_attributes(
_io_value_and_type(response),
)

def get_extra_attributes(
self,
response: Any,
request_parameters: Mapping[str, Any],
) -> Iterator[Tuple[str, AttributeValue]]:
yield from self._get_attributes_from_chat_completion(
completion=response,
request_parameters=request_parameters,
)

def _get_attributes_from_chat_completion(
self,
completion: Any,
request_parameters: Mapping[str, Any],
) -> Iterator[Tuple[str, AttributeValue]]:
if model := getattr(completion, "model", None):
yield SpanAttributes.LLM_MODEL_NAME, model
if usage := getattr(completion, "usage", None):
yield from self._get_attributes_from_completion_usage(usage)
if (choices := getattr(completion, "choices", None)) and isinstance(choices, Iterable):
for choice in choices:
if (index := getattr(choice, "index", None)) is None:
continue
if message := getattr(choice, "message", None):
for key, value in self._get_attributes_from_chat_completion_message(message):
yield f"{SpanAttributes.LLM_OUTPUT_MESSAGES}.{index}.{key}", value

def _get_attributes_from_chat_completion_message(
self,
message: object,
) -> Iterator[Tuple[str, AttributeValue]]:
if role := getattr(message, "role", None):
yield MessageAttributes.MESSAGE_ROLE, role
if content := getattr(message, "content", None):
yield MessageAttributes.MESSAGE_CONTENT, content
if function_call := getattr(message, "function_call", None):
if name := getattr(function_call, "name", None):
yield MessageAttributes.MESSAGE_FUNCTION_CALL_NAME, name
if arguments := getattr(function_call, "arguments", None):
yield MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON, arguments
if (tool_calls := getattr(message, "tool_calls", None)) and isinstance(
tool_calls, Iterable
):
for index, tool_call in enumerate(tool_calls):
if (tool_call_id := getattr(tool_call, "id", None)) is not None:
# https://github.com/groq/groq-python/blob/fa2e13b5747d18aeb478700f1e5426af2fd087a1/src/groq/types/chat/chat_completion_tool_message_param.py#L17 # noqa: E501
yield (
f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
f"{ToolCallAttributes.TOOL_CALL_ID}",
tool_call_id,
)
if function := getattr(tool_call, "function", None):
# https://github.com/groq/groq-python/blob/fa2e13b5747d18aeb478700f1e5426af2fd087a1/src/groq/types/chat/chat_completion_message_tool_call.py#L10 # noqa: E501
if name := getattr(function, "name", None):
yield (
(
f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
f"{ToolCallAttributes.TOOL_CALL_FUNCTION_NAME}"
),
name,
)
if arguments := getattr(function, "arguments", None):
yield (
f"{MessageAttributes.MESSAGE_TOOL_CALLS}.{index}."
f"{ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON}",
arguments,
)

def _get_attributes_from_completion_usage(
self,
usage: object,
) -> Iterator[Tuple[str, AttributeValue]]:
if (total_tokens := getattr(usage, "total_tokens", None)) is not None:
yield SpanAttributes.LLM_TOKEN_COUNT_TOTAL, total_tokens
if (prompt_tokens := getattr(usage, "prompt_tokens", None)) is not None:
yield SpanAttributes.LLM_TOKEN_COUNT_PROMPT, prompt_tokens
if (completion_tokens := getattr(usage, "completion_tokens", None)) is not None:
yield SpanAttributes.LLM_TOKEN_COUNT_COMPLETION, completion_tokens
Loading

0 comments on commit 6057418

Please sign in to comment.