Skip to content

Commit

Permalink
[FuncAPIWrapper][multiple endpoints]
Browse files Browse the repository at this point in the history
  • Loading branch information
Kye committed Dec 18, 2023
1 parent ae7fe34 commit de6ad72
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 54 deletions.
97 changes: 54 additions & 43 deletions errors.txt
Original file line number Diff line number Diff line change
@@ -1,43 +1,54 @@
Initializing MLIR with module: _site_initialize_0
Registering dialects from initializer <module 'jaxlib.mlir._mlir_libs._site_initialize_0' from '/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/jaxlib/mlir/_mlir_libs/_site_initialize_0.so'>
etils.epath was not found. Using pathlib for file I/O.
[2023-12-18 11:15:00,058] torch.distributed.elastic.multiprocessing.redirects: [WARNING] NOTE: Redirects are currently not supported in Windows or MacOs.
Registering GET endpoint at /agent
Registering POST endpoint at /agent
Registering GET endpoint at /agent
Registering GET endpoint at /agent
Registering GET endpoint at /agent
Registering GET endpoint at /agent
Using selector: KqueueSelector
HTTP Request: GET http://testserver/test "HTTP/1.1 422 Unprocessable Entity"
Error in test_endpoint: Invalid method: invalid
Using selector: KqueueSelector
HTTP Request: POST http://testserver/test_post "HTTP/1.1 422 Unprocessable Entity"
Error in test_put_endpoint: Invalid method: put
Using selector: KqueueSelector
HTTP Request: PUT http://testserver/test_put "HTTP/1.1 404 Not Found"
Error in test_delete_endpoint: Invalid method: delete
Using selector: KqueueSelector
HTTP Request: DELETE http://testserver/test_delete "HTTP/1.1 404 Not Found"
Using selector: KqueueSelector
HTTP Request: GET http://testserver/test_error "HTTP/1.1 422 Unprocessable Entity"
Using selector: KqueueSelector
HTTP Request: GET http://testserver/test_rate_limit "HTTP/1.1 422 Unprocessable Entity"
Using selector: KqueueSelector
HTTP Request: GET http://testserver/test_rate_limit "HTTP/1.1 422 Unprocessable Entity"
Using selector: KqueueSelector
HTTP Request: GET http://testserver/test_rate_limit "HTTP/1.1 422 Unprocessable Entity"
Using selector: KqueueSelector
HTTP Request: GET http://testserver/test_rate_limit "HTTP/1.1 422 Unprocessable Entity"
Using selector: KqueueSelector
HTTP Request: GET http://testserver/test_rate_limit "HTTP/1.1 422 Unprocessable Entity"
Using selector: KqueueSelector
HTTP Request: GET http://testserver/test_rate_limit "HTTP/1.1 422 Unprocessable Entity"
Error in test_patch_endpoint: Invalid method: patch
Using selector: KqueueSelector
HTTP Request: PATCH http://testserver/test_patch "HTTP/1.1 404 Not Found"
Using selector: KqueueSelector
HTTP Request: POST http://testserver/test_data "HTTP/1.1 422 Unprocessable Entity"
Using selector: KqueueSelector
HTTP Request: GET http://testserver/test_params?key=value "HTTP/1.1 422 Unprocessable Entity"
Using selector: KqueueSelector
Traceback (most recent call last):
File "/Users/defalt/Desktop/Athena/research/swarms-cloud/tests/test_func_wrapper.py", line 8, in <module>
from swarms_cloud.func_api_wrapper import FuncAPIWrapper
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/swarms_cloud/__init__.py", line 1, in <module>
from swarms_cloud.main import agent_api_wrapper
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/swarms_cloud/main.py", line 5, in <module>
from swarms.structs.agent import Agent
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/swarms/__init__.py", line 8, in <module>
from swarms.models import * # noqa: E402, F403
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/swarms/models/__init__.py", line 4, in <module>
from swarms.models.petals import Petals # noqa: E402
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/swarms/models/petals.py", line 1, in <module>
from transformers import AutoTokenizer, AutoModelForCausalLM
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/__init__.py", line 26, in <module>
from . import dependency_versions_check
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/dependency_versions_check.py", line 16, in <module>
from .utils.versions import require_version, require_version_core
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/utils/__init__.py", line 31, in <module>
from .generic import (
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/utils/generic.py", line 29, in <module>
from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 124, in <module>
_scipy_available = _is_package_available("scipy")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 47, in _is_package_available
package_version = importlib.metadata.version(pkg_name)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/__init__.py", line 991, in version
return distribution(distribution_name).version
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/__init__.py", line 628, in version
return self.metadata['Version']
^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/__init__.py", line 613, in metadata
return _adapters.Message(email.message_from_string(text))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/_adapters.py", line 36, in __init__
self._headers = self._repair_headers()
^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/_adapters.py", line 49, in _repair_headers
headers = [(key, redent(value)) for key, value in vars(self)['_headers']]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/_adapters.py", line 49, in <listcomp>
headers = [(key, redent(value)) for key, value in vars(self)['_headers']]
^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/_adapters.py", line 47, in redent
return textwrap.dedent(' ' * 8 + value)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/textwrap.py", line 435, in dedent
text = _whitespace_only_re.sub('', text)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt
75 changes: 73 additions & 2 deletions swarms_cloud/func_api_wrapper.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
import inspect
import logging
from typing import Callable
from typing import Callable, TypeVar, get_type_hints

import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request, Response
from pydantic import create_model

# Logger initialization
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


# Genertic type for return type
T = TypeVar("T")


# Function API Wrapper for functions: [CLASS]
class FuncAPIWrapper:
"""Functional API Wrapper
Expand Down Expand Up @@ -43,6 +49,7 @@ def __init__(
self.host = host
self.port = port
self.app = FastAPI()
self.error_handlers = {}

def add(self, path: str, method: str = "post", *args, **kwargs):
"""Add an endpoint to the API
Expand Down Expand Up @@ -104,3 +111,67 @@ def __call__(self, *args, **kwargs):
self.run(*args, **kwargs)
except Exception as error:
logger.error(f"Error in {self.__class__.__name__}: {error}")

def add_endpoints(self, endpoints: list):
"""
Batch addition of multiple endpoints.
Args:
endpoints (list): A list of tuples, each containing path, method, and function.
Example: [("/path1", "get", function1), ("/path2", "post", function2)]
"""
for path, method, func in endpoints:
self.add(path, method)(func)

def _generate_request_model(self, func: Callable):
"""Generate requests model
Args:
func (Callable): function to generate the request model for
Returns:
: dynamically generated request model
"""
# Extract arguments from the function signature
signature = inspect.signature(func)
fields = {
name: (param.annotation, ...)
for name, param in signature.parameters.items()
}
return create_model(f"{func.__name__}Request", **fields)

def _generate_response_model(self, func: Callable):
"""Generate response model
Args:
func (Callable): function to generate the response model for
Returns:
: dynamically generated response model
"""
return_type = get_type_hints(func).get("return")
return create_model(f"{func.__name__}Response", result=(return_type, ...))

def add_error_handler(
self, exception_class: type, handler: Callable[[Request, Exception], Response]
):
"""Add an error handler
Args:
exception_class (type): exception class to handle
handler (Callable[[Request, Exception], Response]): handler function
"""
self.error_handlers[exception_class] = handler

async def _call_async_func(self, func, **kwargs):
"""Call an async function
Args:
func (callable): function to call
"""
if inspect.iscoroutinefunction(func):
return await func(**kwargs)
else:
return func(**kwargs)
18 changes: 9 additions & 9 deletions swarms_cloud/sky_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ class SkyInterface:
SkyInterface is a wrapper around the sky Python API. It provides a
simplified interface for launching, executing, stopping, starting, and
tearing down clusters.
Attributes:
clusters (dict): A dictionary of clusters that have been launched.
The keys are the names of the clusters and the values are the handles
to the clusters.
Methods:
launch: Launch a cluster
execute: Execute a task on a cluster
Expand All @@ -21,7 +21,7 @@ class SkyInterface:
down: Tear down a cluster
status: Get the status of a cluster
autostop: Set the autostop of a cluster
Example:
>>> sky_interface = SkyInterface()
>>> job_id = sky_interface.launch("task", "cluster_name")
Expand All @@ -42,7 +42,7 @@ def launch(self, task, cluster_name=None, **kwargs):
"""Launch a task on a cluster
Args:
task (_type_): _description_
task (str): code to execute on the cluster
cluster_name (_type_, optional): _description_. Defaults to None.
Returns:
Expand Down Expand Up @@ -80,7 +80,7 @@ def stop(self, cluster_name, **kwargs):
"""Stop a cluster
Args:
cluster_name (_type_): _description_
cluster_name (str): name of the cluster to stop
"""
try:
sky.stop(cluster_name, **kwargs)
Expand All @@ -91,7 +91,7 @@ def start(self, cluster_name, **kwargs):
"""start a cluster
Args:
cluster_name (_type_): _description_
cluster_name (str): name of the cluster to start
"""
try:
sky.start(cluster_name, **kwargs)
Expand All @@ -102,7 +102,7 @@ def down(self, cluster_name, **kwargs):
"""Down a cluster
Args:
cluster_name (_type_): _description_
cluster_name (str): name of the cluster to tear down
"""
try:
sky.down(cluster_name, **kwargs)
Expand All @@ -115,7 +115,7 @@ def status(self, **kwargs):
"""Save a cluster
Returns:
_type_: _description_
r: the status of the cluster
"""
try:
return sky.status(**kwargs)
Expand All @@ -126,7 +126,7 @@ def autostop(self, cluster_name, **kwargs):
"""Autostop a cluster
Args:
cluster_name (_type_): _description_
cluster_name (str): name of the cluster to autostop
"""
try:
sky.autostop(cluster_name, **kwargs)
Expand Down
36 changes: 36 additions & 0 deletions tests/test_func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,39 @@ def test_middleware_endpoint():
response = client.get("/test_middleware")
assert response.status_code == 200
assert response.headers["X-Custom-Header"] == "Test"



def test_add_endpoints(func_api_wrapper):
endpoints = [
("/test1", "get", lambda: {"message": "test1"}),
("/test2", "post", lambda: {"message": "test2"}),
]
func_api_wrapper.add_endpoints(endpoints)

client = TestClient(func_api_wrapper.app)
response = client.get("/test1")
assert response.status_code == 200
assert response.json() == {"message": "test1"}

response = client.post("/test2")
assert response.status_code == 200
assert response.json() == {"message": "test2"}

def test_add_endpoints_invalid_method(func_api_wrapper):
endpoints = [
("/test_invalid", "invalid", lambda: {"message": "test_invalid"}),
]
with pytest.raises(ValueError):
func_api_wrapper.add_endpoints(endpoints)

def test_add_endpoints_exception(func_api_wrapper):
endpoints = [
("/test_exception", "get", lambda: 1 / 0),
]
func_api_wrapper.add_endpoints(endpoints)

client = TestClient(func_api_wrapper.app)
response = client.get("/test_exception")
assert response.status_code == 500
assert "division by zero" in response.text

0 comments on commit de6ad72

Please sign in to comment.