Skip to content

Commit

Permalink
Clean up types
Browse files Browse the repository at this point in the history
  • Loading branch information
anticorrelator committed Oct 31, 2024
1 parent 782e222 commit 6967e35
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions src/phoenix/server/api/helpers/playground_registry.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Callable, Optional, Type, TypeAlias, Union

from phoenix.server.api.types.GenerativeProvider import GenerativeProviderKey

if TYPE_CHECKING:
from phoenix.server.api.subscriptions import PlaygroundStreamingClient

ModelKey = Tuple[GenerativeProviderKey, str | None]
ModelName: TypeAlias = Union[str, None]
ModelKey = tuple[GenerativeProviderKey, ModelName]

PROVIDER_DEFAULT = None


class SingletonMeta(type):
_instances: Dict[Any, Any] = dict()
_instances: dict[Any, Any] = dict()

def __call__(cls, *args: Any, **kwargs: Any) -> Any:
if cls not in cls._instances:
Expand All @@ -21,26 +22,26 @@ def __call__(cls, *args: Any, **kwargs: Any) -> Any:

class PlaygroundClientRegistry(metaclass=SingletonMeta):
def __init__(self) -> None:
self._registry: Dict[
GenerativeProviderKey, Dict[str | None, Type["PlaygroundStreamingClient"]]
self._registry: dict[
GenerativeProviderKey, dict[ModelName, Optional[Type["PlaygroundStreamingClient"]]]
] = {}

def get_client(
self,
provider_key: GenerativeProviderKey,
model_name: str | None,
model_name: ModelName,
) -> Optional[Type["PlaygroundStreamingClient"]]:
provider_registry = self._registry.get(provider_key, {})
client_class = provider_registry.get(model_name)
if client_class is None and None in provider_registry:
client_class = provider_registry[PROVIDER_DEFAULT] # Fallback to provider default
return client_class

def list_models(self, provider_key: GenerativeProviderKey) -> List[str]:
def list_models(self, provider_key: GenerativeProviderKey) -> list[str]:
provider_registry = self._registry.get(provider_key, {})
return [model_name for model_name in provider_registry.keys() if model_name is not None]

def list_all_models(self) -> List[ModelKey]:
def list_all_models(self) -> list[ModelKey]:
return [
(provider_key, model_name)
for provider_key, provider_registry in self._registry.items()
Expand All @@ -53,7 +54,7 @@ def list_all_models(self) -> List[ModelKey]:

def register_llm_client(
provider_key: GenerativeProviderKey,
model_names: List[str | None],
model_names: list[ModelName],
) -> Callable[[Type["PlaygroundStreamingClient"]], Type["PlaygroundStreamingClient"]]:
def decorator(cls: Type["PlaygroundStreamingClient"]) -> Type["PlaygroundStreamingClient"]:
provider_registry = PLAYGROUND_CLIENT_REGISTRY._registry.setdefault(provider_key, {})
Expand Down

0 comments on commit 6967e35

Please sign in to comment.