Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Only show providers if dependencies are installed #5251

Merged
merged 6 commits into from
Nov 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,8 @@ input GenerativeModelInput {
type GenerativeProvider {
name: String!
key: GenerativeProviderKey!
dependencies: [String!]!
dependenciesInstalled: Boolean!
}

enum GenerativeProviderKey {
Expand Down
27 changes: 27 additions & 0 deletions src/phoenix/server/api/helpers/playground_clients.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib.util
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator, Callable, Iterator
from typing import (
Expand Down Expand Up @@ -46,6 +47,7 @@
ChatCompletionMessageToolCallParam,
)

DependencyName: TypeAlias = str
SetSpanAttributesFn: TypeAlias = Callable[[Mapping[str, Any]], None]
ChatCompletionChunk: TypeAlias = Union[TextChunk, ToolCallChunk]

Expand All @@ -59,6 +61,12 @@ def __init__(
) -> None:
self._set_span_attributes = set_span_attributes

@classmethod
@abstractmethod
def dependencies(cls) -> list[DependencyName]:
# A list of dependency names this client needs to run
...

@classmethod
@abstractmethod
def supported_invocation_parameters(cls) -> list[InvocationParameter]: ...
Expand Down Expand Up @@ -97,6 +105,17 @@ def construct_invocation_parameters(
validate_invocation_parameters(supported_params, formatted_invocation_parameters)
return formatted_invocation_parameters

@classmethod
def dependencies_are_installed(cls) -> bool:
try:
for dependency in cls.dependencies():
if importlib.util.find_spec(dependency) is None:
return False
return True
except ValueError:
# happens in some cases if the spec is None
return False


@register_llm_client(
provider_key=GenerativeProviderKey.OPENAI,
Expand Down Expand Up @@ -134,6 +153,10 @@ def __init__(
self.client = AsyncOpenAI(api_key=api_key)
self.model_name = model.name

@classmethod
def dependencies(cls) -> list[DependencyName]:
return ["openai"]

@classmethod
def supported_invocation_parameters(cls) -> list[InvocationParameter]:
return [
Expand Down Expand Up @@ -522,6 +545,10 @@ def __init__(
self.client = anthropic.AsyncAnthropic(api_key=api_key)
self.model_name = model.name

@classmethod
def dependencies(cls) -> list[DependencyName]:
return ["anthropic"]

@classmethod
def supported_invocation_parameters(cls) -> list[InvocationParameter]:
return [
Expand Down
5 changes: 5 additions & 0 deletions src/phoenix/server/api/helpers/playground_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,11 @@ def get_client(
client_class = provider_registry[PROVIDER_DEFAULT] # Fallback to provider default
return client_class

def list_all_providers(
self,
) -> list[GenerativeProviderKey]:
return [provider_key for provider_key in self._registry]

def list_models(self, provider_key: GenerativeProviderKey) -> list[str]:
provider_registry = self._registry.get(provider_key, {})
return [model_name for model_name in provider_registry.keys() if model_name is not None]
Expand Down
19 changes: 8 additions & 11 deletions src/phoenix/server/api/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from phoenix.server.api.context import Context
from phoenix.server.api.exceptions import NotFound, Unauthorized
from phoenix.server.api.helpers import ensure_list
from phoenix.server.api.helpers.playground_clients import initialize_playground_clients
from phoenix.server.api.helpers.playground_registry import PLAYGROUND_CLIENT_REGISTRY
from phoenix.server.api.input_types.ClusterInput import ClusterInput
from phoenix.server.api.input_types.Coordinates import (
Expand Down Expand Up @@ -84,6 +85,8 @@
from phoenix.server.api.types.UserApiKey import UserApiKey, to_gql_api_key
from phoenix.server.api.types.UserRole import UserRole

initialize_playground_clients()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add this once inside create_app?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good question, I think probably yes



@strawberry.input
class ModelsInput:
Expand All @@ -95,19 +98,13 @@ class ModelsInput:
class Query:
@strawberry.field
async def model_providers(self) -> list[GenerativeProvider]:
available_providers = PLAYGROUND_CLIENT_REGISTRY.list_all_providers()
return [
GenerativeProvider(
name="OpenAI",
key=GenerativeProviderKey.OPENAI,
),
GenerativeProvider(
name="Azure OpenAI",
key=GenerativeProviderKey.AZURE_OPENAI,
),
GenerativeProvider(
name="Anthropic",
key=GenerativeProviderKey.ANTHROPIC,
),
name=provider_key.value,
key=provider_key,
)
for provider_key in available_providers
]

@strawberry.field
Expand Down
30 changes: 27 additions & 3 deletions src/phoenix/server/api/types/GenerativeProvider.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,36 @@

@strawberry.enum
class GenerativeProviderKey(Enum):
OPENAI = "OPENAI"
ANTHROPIC = "ANTHROPIC"
AZURE_OPENAI = "AZURE_OPENAI"
OPENAI = "OpenAI"
ANTHROPIC = "Anthropic"
AZURE_OPENAI = "Azure OpenAI"


@strawberry.type
class GenerativeProvider:
name: str
key: GenerativeProviderKey

@strawberry.field
async def dependencies(self) -> list[str]:
from phoenix.server.api.helpers.playground_registry import (
PLAYGROUND_CLIENT_REGISTRY,
PROVIDER_DEFAULT,
)

default_client = PLAYGROUND_CLIENT_REGISTRY.get_client(self.key, PROVIDER_DEFAULT)
if default_client:
return default_client.dependencies()
return []

@strawberry.field
async def dependencies_installed(self) -> bool:
from phoenix.server.api.helpers.playground_registry import (
PLAYGROUND_CLIENT_REGISTRY,
PROVIDER_DEFAULT,
)

default_client = PLAYGROUND_CLIENT_REGISTRY.get_client(self.key, PROVIDER_DEFAULT)
if default_client:
return default_client.dependencies_are_installed()
return False
Loading