From 14be1cf941c459ab44a1bcf45e97d32996dcd775 Mon Sep 17 00:00:00 2001 From: Daniel Lenton Date: Wed, 6 Nov 2024 15:36:02 +0000 Subject: [PATCH] added test_routing_w_models. --- tests/test_routing/test_routing_syntax.py | 6 ++++++ unify/universal_api/clients/helpers.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/tests/test_routing/test_routing_syntax.py b/tests/test_routing/test_routing_syntax.py index f689388..9963604 100644 --- a/tests/test_routing/test_routing_syntax.py +++ b/tests/test_routing/test_routing_syntax.py @@ -85,5 +85,11 @@ def test_routing_skip_providers(): ).generate("Hello.") +def test_routing_w_models(): + unify.Unify( + "router@q:1|i:0.5|models:gpt-4o,o1-preview,claude-3-sonnet", + ).generate("Hello.") + + if __name__ == "__main__": pass diff --git a/unify/universal_api/clients/helpers.py b/unify/universal_api/clients/helpers.py index a1e72ba..7992dd8 100644 --- a/unify/universal_api/clients/helpers.py +++ b/unify/universal_api/clients/helpers.py @@ -49,6 +49,23 @@ def _is_meta_provider(provider: str, api_key: str = None): provider = "".join([chnk0, chnk2]) if provider[-1] == "|": provider = provider[:-1] + public_models = unify.list_models(api_key=api_key) + if "skip_models:" in provider: + skip_mods = provider.split("skip_models:")[-1].split("|")[0] + for md in skip_mods.split(","): + if md.strip() not in public_models: + return False + chnk0, chnk1 = provider.split("skip_models:") + chnk2 = "|".join(chnk1.split("|")[1:]) + provider = "".join([chnk0, chnk2]) + if "models:" in provider: + mods = provider.split("models:")[-1].split("|")[0] + for md in mods.split(","): + if md.strip() not in public_models: + return False + chnk0, chnk1 = provider.split("models:") + chnk2 = "|".join(chnk1.split("|")[1:]) + provider = "".join([chnk0, chnk2]) meta_providers = ( ( "highest-quality",