diff --git a/errors.txt b/errors.txt index eb81eb5..4b29793 100644 --- a/errors.txt +++ b/errors.txt @@ -1,43 +1,54 @@ -Initializing MLIR with module: _site_initialize_0 -Registering dialects from initializer -etils.epath was not found. Using pathlib for file I/O. -[2023-12-18 11:15:00,058] torch.distributed.elastic.multiprocessing.redirects: [WARNING] NOTE: Redirects are currently not supported in Windows or MacOs. -Registering GET endpoint at /agent -Registering POST endpoint at /agent -Registering GET endpoint at /agent -Registering GET endpoint at /agent -Registering GET endpoint at /agent -Registering GET endpoint at /agent -Using selector: KqueueSelector -HTTP Request: GET http://testserver/test "HTTP/1.1 422 Unprocessable Entity" -Error in test_endpoint: Invalid method: invalid -Using selector: KqueueSelector -HTTP Request: POST http://testserver/test_post "HTTP/1.1 422 Unprocessable Entity" -Error in test_put_endpoint: Invalid method: put -Using selector: KqueueSelector -HTTP Request: PUT http://testserver/test_put "HTTP/1.1 404 Not Found" -Error in test_delete_endpoint: Invalid method: delete -Using selector: KqueueSelector -HTTP Request: DELETE http://testserver/test_delete "HTTP/1.1 404 Not Found" -Using selector: KqueueSelector -HTTP Request: GET http://testserver/test_error "HTTP/1.1 422 Unprocessable Entity" -Using selector: KqueueSelector -HTTP Request: GET http://testserver/test_rate_limit "HTTP/1.1 422 Unprocessable Entity" -Using selector: KqueueSelector -HTTP Request: GET http://testserver/test_rate_limit "HTTP/1.1 422 Unprocessable Entity" -Using selector: KqueueSelector -HTTP Request: GET http://testserver/test_rate_limit "HTTP/1.1 422 Unprocessable Entity" -Using selector: KqueueSelector -HTTP Request: GET http://testserver/test_rate_limit "HTTP/1.1 422 Unprocessable Entity" -Using selector: KqueueSelector -HTTP Request: GET http://testserver/test_rate_limit "HTTP/1.1 422 Unprocessable Entity" -Using selector: KqueueSelector -HTTP Request: GET http://testserver/test_rate_limit "HTTP/1.1 422 Unprocessable Entity" -Error in test_patch_endpoint: Invalid method: patch -Using selector: KqueueSelector -HTTP Request: PATCH http://testserver/test_patch "HTTP/1.1 404 Not Found" -Using selector: KqueueSelector -HTTP Request: POST http://testserver/test_data "HTTP/1.1 422 Unprocessable Entity" -Using selector: KqueueSelector -HTTP Request: GET http://testserver/test_params?key=value "HTTP/1.1 422 Unprocessable Entity" -Using selector: KqueueSelector +Traceback (most recent call last): + File "/Users/defalt/Desktop/Athena/research/swarms-cloud/tests/test_func_wrapper.py", line 8, in + from swarms_cloud.func_api_wrapper import FuncAPIWrapper + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/swarms_cloud/__init__.py", line 1, in + from swarms_cloud.main import agent_api_wrapper + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/swarms_cloud/main.py", line 5, in + from swarms.structs.agent import Agent + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/swarms/__init__.py", line 8, in + from swarms.models import * # noqa: E402, F403 + ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/swarms/models/__init__.py", line 4, in + from swarms.models.petals import Petals # noqa: E402 + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/swarms/models/petals.py", line 1, in + from transformers import AutoTokenizer, AutoModelForCausalLM + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/__init__.py", line 26, in + from . import dependency_versions_check + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/dependency_versions_check.py", line 16, in + from .utils.versions import require_version, require_version_core + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/utils/__init__.py", line 31, in + from .generic import ( + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/utils/generic.py", line 29, in + from .import_utils import is_flax_available, is_tf_available, is_torch_available, is_torch_fx_proxy + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 124, in + _scipy_available = _is_package_available("scipy") + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/transformers/utils/import_utils.py", line 47, in _is_package_available + package_version = importlib.metadata.version(pkg_name) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/__init__.py", line 991, in version + return distribution(distribution_name).version + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/__init__.py", line 628, in version + return self.metadata['Version'] + ^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/__init__.py", line 613, in metadata + return _adapters.Message(email.message_from_string(text)) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/_adapters.py", line 36, in __init__ + self._headers = self._repair_headers() + ^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/_adapters.py", line 49, in _repair_headers + headers = [(key, redent(value)) for key, value in vars(self)['_headers']] + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/_adapters.py", line 49, in + headers = [(key, redent(value)) for key, value in vars(self)['_headers']] + ^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/importlib/metadata/_adapters.py", line 47, in redent + return textwrap.dedent(' ' * 8 + value) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + File "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/textwrap.py", line 435, in dedent + text = _whitespace_only_re.sub('', text) + ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +KeyboardInterrupt diff --git a/swarms_cloud/func_api_wrapper.py b/swarms_cloud/func_api_wrapper.py index c828692..fcc68b6 100644 --- a/swarms_cloud/func_api_wrapper.py +++ b/swarms_cloud/func_api_wrapper.py @@ -1,14 +1,20 @@ +import inspect import logging -from typing import Callable +from typing import Callable, TypeVar, get_type_hints import uvicorn -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request, Response +from pydantic import create_model # Logger initialization logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +# Genertic type for return type +T = TypeVar("T") + + # Function API Wrapper for functions: [CLASS] class FuncAPIWrapper: """Functional API Wrapper @@ -43,6 +49,7 @@ def __init__( self.host = host self.port = port self.app = FastAPI() + self.error_handlers = {} def add(self, path: str, method: str = "post", *args, **kwargs): """Add an endpoint to the API @@ -104,3 +111,67 @@ def __call__(self, *args, **kwargs): self.run(*args, **kwargs) except Exception as error: logger.error(f"Error in {self.__class__.__name__}: {error}") + + def add_endpoints(self, endpoints: list): + """ + Batch addition of multiple endpoints. + + Args: + endpoints (list): A list of tuples, each containing path, method, and function. + Example: [("/path1", "get", function1), ("/path2", "post", function2)] + + """ + for path, method, func in endpoints: + self.add(path, method)(func) + + def _generate_request_model(self, func: Callable): + """Generate requests model + + Args: + func (Callable): function to generate the request model for + + Returns: + : dynamically generated request model + """ + # Extract arguments from the function signature + signature = inspect.signature(func) + fields = { + name: (param.annotation, ...) + for name, param in signature.parameters.items() + } + return create_model(f"{func.__name__}Request", **fields) + + def _generate_response_model(self, func: Callable): + """Generate response model + + Args: + func (Callable): function to generate the response model for + + Returns: + : dynamically generated response model + """ + return_type = get_type_hints(func).get("return") + return create_model(f"{func.__name__}Response", result=(return_type, ...)) + + def add_error_handler( + self, exception_class: type, handler: Callable[[Request, Exception], Response] + ): + """Add an error handler + + Args: + exception_class (type): exception class to handle + handler (Callable[[Request, Exception], Response]): handler function + """ + self.error_handlers[exception_class] = handler + + async def _call_async_func(self, func, **kwargs): + """Call an async function + + Args: + func (callable): function to call + + """ + if inspect.iscoroutinefunction(func): + return await func(**kwargs) + else: + return func(**kwargs) \ No newline at end of file diff --git a/swarms_cloud/sky_api.py b/swarms_cloud/sky_api.py index ce8eadb..9abd303 100644 --- a/swarms_cloud/sky_api.py +++ b/swarms_cloud/sky_api.py @@ -7,12 +7,12 @@ class SkyInterface: SkyInterface is a wrapper around the sky Python API. It provides a simplified interface for launching, executing, stopping, starting, and tearing down clusters. - + Attributes: clusters (dict): A dictionary of clusters that have been launched. The keys are the names of the clusters and the values are the handles to the clusters. - + Methods: launch: Launch a cluster execute: Execute a task on a cluster @@ -21,7 +21,7 @@ class SkyInterface: down: Tear down a cluster status: Get the status of a cluster autostop: Set the autostop of a cluster - + Example: >>> sky_interface = SkyInterface() >>> job_id = sky_interface.launch("task", "cluster_name") @@ -42,7 +42,7 @@ def launch(self, task, cluster_name=None, **kwargs): """Launch a task on a cluster Args: - task (_type_): _description_ + task (str): code to execute on the cluster cluster_name (_type_, optional): _description_. Defaults to None. Returns: @@ -80,7 +80,7 @@ def stop(self, cluster_name, **kwargs): """Stop a cluster Args: - cluster_name (_type_): _description_ + cluster_name (str): name of the cluster to stop """ try: sky.stop(cluster_name, **kwargs) @@ -91,7 +91,7 @@ def start(self, cluster_name, **kwargs): """start a cluster Args: - cluster_name (_type_): _description_ + cluster_name (str): name of the cluster to start """ try: sky.start(cluster_name, **kwargs) @@ -102,7 +102,7 @@ def down(self, cluster_name, **kwargs): """Down a cluster Args: - cluster_name (_type_): _description_ + cluster_name (str): name of the cluster to tear down """ try: sky.down(cluster_name, **kwargs) @@ -115,7 +115,7 @@ def status(self, **kwargs): """Save a cluster Returns: - _type_: _description_ + r: the status of the cluster """ try: return sky.status(**kwargs) @@ -126,7 +126,7 @@ def autostop(self, cluster_name, **kwargs): """Autostop a cluster Args: - cluster_name (_type_): _description_ + cluster_name (str): name of the cluster to autostop """ try: sky.autostop(cluster_name, **kwargs) diff --git a/tests/test_func_wrapper.py b/tests/test_func_wrapper.py index 2958af2..de843b7 100644 --- a/tests/test_func_wrapper.py +++ b/tests/test_func_wrapper.py @@ -169,3 +169,39 @@ def test_middleware_endpoint(): response = client.get("/test_middleware") assert response.status_code == 200 assert response.headers["X-Custom-Header"] == "Test" + + + +def test_add_endpoints(func_api_wrapper): + endpoints = [ + ("/test1", "get", lambda: {"message": "test1"}), + ("/test2", "post", lambda: {"message": "test2"}), + ] + func_api_wrapper.add_endpoints(endpoints) + + client = TestClient(func_api_wrapper.app) + response = client.get("/test1") + assert response.status_code == 200 + assert response.json() == {"message": "test1"} + + response = client.post("/test2") + assert response.status_code == 200 + assert response.json() == {"message": "test2"} + +def test_add_endpoints_invalid_method(func_api_wrapper): + endpoints = [ + ("/test_invalid", "invalid", lambda: {"message": "test_invalid"}), + ] + with pytest.raises(ValueError): + func_api_wrapper.add_endpoints(endpoints) + +def test_add_endpoints_exception(func_api_wrapper): + endpoints = [ + ("/test_exception", "get", lambda: 1 / 0), + ] + func_api_wrapper.add_endpoints(endpoints) + + client = TestClient(func_api_wrapper.app) + response = client.get("/test_exception") + assert response.status_code == 500 + assert "division by zero" in response.text \ No newline at end of file