From 1ec0782425be79fb6ed99bebb40bcaeccbaeecde Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 12 Sep 2024 02:04:51 +0545 Subject: [PATCH 01/35] Adds initial function definition --- src/litserve/server.py | 51 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index d4e76000..3706014c 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -27,7 +27,7 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import asynccontextmanager from queue import Empty -from typing import Callable, Dict, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import uvicorn from fastapi import Depends, FastAPI, HTTPException, Request, Response @@ -472,3 +472,52 @@ def setup_auth(self): if LIT_SERVER_API_KEY: return api_key_auth return no_auth + + +def run_all( + servers: List[LitServer], + port: Union[str, int] = 8000, + num_api_servers: Optional[int] = None, + log_level: str = "info", + generate_client_file: bool = True, + api_server_worker_type: Optional[str] = None, + **kwargs, +): + """ + Run multiple LitServers on the same port. + """ + if not servers: + raise ValueError("No servers provided to run_all") + + if any(not isinstance(server, LitServer) for server in servers): + raise ValueError("All elements in the servers list must be instances of LitServer") + + if generate_client_file: + for server in servers: + server.generate_client_file() + + port_msg = f"port must be a value from 1024 to 65535 but got {port}" + try: + port = int(port) + except ValueError: + raise ValueError(port_msg) + if not (1024 <= port <= 65535): + raise ValueError(port_msg) + + if num_api_servers is None: + num_api_servers = sum(len(server.workers) for server in servers) + if num_api_servers < 1: + raise ValueError("num_api_servers must be greater than 0") + + if sys.platform == "win32": + print("Windows does not support forking. Using threads api_server_worker_type will be set to 'thread'") + api_server_worker_type = "thread" + elif api_server_worker_type is None: + api_server_worker_type = "process" + + litserve_workers = [] + # TODO: Implement this fn for multiple servers + + # for server in servers: + # server_manager, server_workers = server.launch_inference_worker(num_api_servers) + # litserve_workers.extend(server_workers) From 88b6b2ff070e6a36ca37c410546ca4e4737635b5 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 12 Sep 2024 13:37:16 +0545 Subject: [PATCH 02/35] adds run all method --- src/litserve/server.py | 28 +++++++++++++++++++++++----- 1 file changed, 23 insertions(+), 5 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 3706014c..8fb1ba24 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -515,9 +515,27 @@ def run_all( elif api_server_worker_type is None: api_server_worker_type = "process" - litserve_workers = [] - # TODO: Implement this fn for multiple servers + app = servers[0].app + config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs) + sockets = [config.bind_socket()] - # for server in servers: - # server_manager, server_workers = server.launch_inference_worker(num_api_servers) - # litserve_workers.extend(server_workers) + managers, all_workers = [], [] + try: + all_servers = [] + for server in servers: + manager, litserve_workers = server.launch_inference_worker(num_api_servers) + managers.append(manager) + all_workers.extend(litserve_workers) + + _servers = server._start_server(port, num_api_servers, log_level, sockets, api_server_worker_type, **kwargs) + all_servers.extend(_servers) + print(f"Swagger UI is available at http://0.0.0.0:{port}/docs") + for s in all_servers: + s.join() + finally: + print("Shutting down LitServe") + for w in all_workers: + w.terminate() + w.join() + for manager in managers: + manager.shutdown() From 30d2c35c3dcb688ea99f4307aee05ecbe64d1791 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 08:12:01 +0000 Subject: [PATCH 03/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litserve/server.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 8fb1ba24..4ab44343 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -483,9 +483,7 @@ def run_all( api_server_worker_type: Optional[str] = None, **kwargs, ): - """ - Run multiple LitServers on the same port. - """ + """Run multiple LitServers on the same port.""" if not servers: raise ValueError("No servers provided to run_all") From 266b2719aa24ca3e9df6723d3ed91a44db1de7df Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 12 Sep 2024 14:06:19 +0545 Subject: [PATCH 04/35] adding a test for multi endpoint servers --- tests/test_lit_server_multi_endpoints.py | 30 ++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 tests/test_lit_server_multi_endpoints.py diff --git a/tests/test_lit_server_multi_endpoints.py b/tests/test_lit_server_multi_endpoints.py new file mode 100644 index 00000000..394905e4 --- /dev/null +++ b/tests/test_lit_server_multi_endpoints.py @@ -0,0 +1,30 @@ +import threading + +from httpx import AsyncClient + +from litserve.server import LitServer, run_all + + +async def test_lit_server_with_multi_endpoints(simple_litapi): + server1 = LitServer(simple_litapi, api_path="/predict-1", timeout=10) + server2 = LitServer(simple_litapi, api_path="/predict-2", timeout=10) + servers = [server1, server2] + # TODO: update test to use run_all + + # Run the servers in a separate thread + port = 8000 + # server_thread = threading.Thread( + # target=run_all, args=(servers,), kwargs={"port": port, "num_api_servers": 2, "log_level": "debug"} + # ) + # server_thread.start() + + # async with AsyncClient(base_url=f"http://localhost:{port}") as client: + # # Test server1 endpoint + # response1 = await client.post("/predict-1", json={"input": 1}) + # assert response1.status_code == 200 + # assert response1.json() == {"output": 1} + + # # Test server2 endpoint + # response2 = await client.post("/predict-2", json={"input": 2}) + # assert response2.status_code == 200 + # assert response2.json() == {"output": 2} From 3cf17a07cbc4e601d31b0f77141bd71060ef8065 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 08:24:52 +0000 Subject: [PATCH 05/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_lit_server_multi_endpoints.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_lit_server_multi_endpoints.py b/tests/test_lit_server_multi_endpoints.py index 394905e4..7779d425 100644 --- a/tests/test_lit_server_multi_endpoints.py +++ b/tests/test_lit_server_multi_endpoints.py @@ -1,8 +1,6 @@ -import threading -from httpx import AsyncClient -from litserve.server import LitServer, run_all +from litserve.server import LitServer async def test_lit_server_with_multi_endpoints(simple_litapi): From f4b47150c2f098766044c89b6196172a3866af15 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Thu, 12 Sep 2024 14:35:26 +0545 Subject: [PATCH 06/35] fixes multi client generation --- src/litserve/server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 4ab44343..c34d3e84 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -491,8 +491,7 @@ def run_all( raise ValueError("All elements in the servers list must be instances of LitServer") if generate_client_file: - for server in servers: - server.generate_client_file() + servers[0].generate_client_file() port_msg = f"port must be a value from 1024 to 65535 but got {port}" try: From ab909593f40aa77107ad7b06445851c3fe2d5150 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Sep 2024 08:51:23 +0000 Subject: [PATCH 07/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_lit_server_multi_endpoints.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_lit_server_multi_endpoints.py b/tests/test_lit_server_multi_endpoints.py index 7779d425..ba4ddd24 100644 --- a/tests/test_lit_server_multi_endpoints.py +++ b/tests/test_lit_server_multi_endpoints.py @@ -1,5 +1,3 @@ - - from litserve.server import LitServer From e3ee2b9763c0be85cf275db5228d7c5a168c4c17 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 13 Sep 2024 14:32:13 +0545 Subject: [PATCH 08/35] testing mounting of apps --- src/litserve/server.py | 26 +++++++++++++++++++++++--- 1 file changed, 23 insertions(+), 3 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index c34d3e84..c78c6e03 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -491,7 +491,7 @@ def run_all( raise ValueError("All elements in the servers list must be instances of LitServer") if generate_client_file: - servers[0].generate_client_file() + LitServer.generate_client_file() port_msg = f"port must be a value from 1024 to 65535 but got {port}" try: @@ -524,8 +524,28 @@ def run_all( managers.append(manager) all_workers.extend(litserve_workers) - _servers = server._start_server(port, num_api_servers, log_level, sockets, api_server_worker_type, **kwargs) - all_servers.extend(_servers) + # Start the servers + for response_queue_id, server in enumerate(servers): + server.app.response_queue_id = response_queue_id + if server.lit_spec: + server.lit_spec.response_queue_id = response_queue_id + + main_app = copy.copy(app) + # mount all other apps + for index, server in enumerate(servers[1:], 1): + main_app.mount(f"/subapp-{index}", server.app) # TODO: Update Mounting Path + + config = uvicorn.Config(app=main_app, host="0.0.0.0", port=port, log_level=log_level, **kwargs) + server = uvicorn.Server(config=config) + if api_server_worker_type == "process": + ctx = mp.get_context("fork") + w = ctx.Process(target=server.run, args=(sockets,)) + elif api_server_worker_type == "thread": + w = threading.Thread(target=server.run, args=(sockets,)) + else: + raise ValueError("Invalid value for api_server_worker_type. Must be 'process' or 'thread'") + w.start() + all_servers.append(w) print(f"Swagger UI is available at http://0.0.0.0:{port}/docs") for s in all_servers: s.join() From e1865df7644a492196ee781d2aa639d16fcaf39f Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Fri, 13 Sep 2024 14:47:52 +0545 Subject: [PATCH 09/35] mounted on root --- src/litserve/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index c78c6e03..cb2144be 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -533,7 +533,7 @@ def run_all( main_app = copy.copy(app) # mount all other apps for index, server in enumerate(servers[1:], 1): - main_app.mount(f"/subapp-{index}", server.app) # TODO: Update Mounting Path + main_app.mount("/", server.app) # TODO: Update Mounting Path config = uvicorn.Config(app=main_app, host="0.0.0.0", port=port, log_level=log_level, **kwargs) server = uvicorn.Server(config=config) From 49865f260ec7d39e30eda6ba4dfd912af1397497 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 15 Sep 2024 22:45:56 +0545 Subject: [PATCH 10/35] Refactor run_all function to support multiple LitServers --- src/litserve/server.py | 101 ++++++++++++++++++++++++----------------- 1 file changed, 60 insertions(+), 41 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index cb2144be..fb9966ed 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -474,20 +474,40 @@ def setup_auth(self): return no_auth +@asynccontextmanager +async def manage_lifespan(app: FastAPI, servers: List[LitServer]): + """ + Context manager to handle the lifespan events of multiple FastAPI servers. + """ + # Start lifespan events for each server + lifespans = [server.lifespan(server.app) for server in servers] + for lifespan in lifespans: + await lifespan.__aenter__() + + try: + yield + finally: + # Exit lifespan events for each server + for lifespan in lifespans: + await lifespan.__aexit__(None, None, None) + + def run_all( - servers: List[LitServer], + litservers: List[LitServer], port: Union[str, int] = 8000, - num_api_servers: Optional[int] = None, + num_api_servers: Optional[int] = 1, log_level: str = "info", generate_client_file: bool = True, api_server_worker_type: Optional[str] = None, **kwargs, ): - """Run multiple LitServers on the same port.""" - if not servers: + """ + Run multiple LitServers on the same port. + """ + if not litservers: raise ValueError("No servers provided to run_all") - if any(not isinstance(server, LitServer) for server in servers): + if any(not isinstance(server, LitServer) for server in litservers): raise ValueError("All elements in the servers list must be instances of LitServer") if generate_client_file: @@ -501,8 +521,6 @@ def run_all( if not (1024 <= port <= 65535): raise ValueError(port_msg) - if num_api_servers is None: - num_api_servers = sum(len(server.workers) for server in servers) if num_api_servers < 1: raise ValueError("num_api_servers must be greater than 0") @@ -512,47 +530,48 @@ def run_all( elif api_server_worker_type is None: api_server_worker_type = "process" - app = servers[0].app + # Create the main FastAPI app + app = FastAPI(lifespan=lambda app: manage_lifespan(app, litservers)) config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs) sockets = [config.bind_socket()] - managers, all_workers = [], [] + managers, workers = [], [] try: - all_servers = [] - for server in servers: - manager, litserve_workers = server.launch_inference_worker(num_api_servers) + for litserver in litservers: + manager, litserve_workers = litserver.launch_inference_worker(num_api_servers) managers.append(manager) - all_workers.extend(litserve_workers) - - # Start the servers - for response_queue_id, server in enumerate(servers): - server.app.response_queue_id = response_queue_id - if server.lit_spec: - server.lit_spec.response_queue_id = response_queue_id - - main_app = copy.copy(app) - # mount all other apps - for index, server in enumerate(servers[1:], 1): - main_app.mount("/", server.app) # TODO: Update Mounting Path - - config = uvicorn.Config(app=main_app, host="0.0.0.0", port=port, log_level=log_level, **kwargs) - server = uvicorn.Server(config=config) - if api_server_worker_type == "process": - ctx = mp.get_context("fork") - w = ctx.Process(target=server.run, args=(sockets,)) - elif api_server_worker_type == "thread": - w = threading.Thread(target=server.run, args=(sockets,)) - else: - raise ValueError("Invalid value for api_server_worker_type. Must be 'process' or 'thread'") - w.start() - all_servers.append(w) + workers.extend(litserve_workers) + + # include routes from each litserver's app into the main app + app.include_router(litserver.app.router) + + server_processes = [] + for response_queue_id in range(num_api_servers): + for litserver in litservers: + litserver.app.response_queue_id = response_queue_id + if litserver.lit_spec: + litserver.lit_spec.response_queue_id = response_queue_id + + app = copy.copy(app) + config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs) + uvicorn_server = uvicorn.Server(config=config) + + if api_server_worker_type == "process": + ctx = mp.get_context("fork") + worker = ctx.Process(target=uvicorn_server.run, args=(sockets,)) + elif api_server_worker_type == "thread": + worker = threading.Thread(target=uvicorn_server.run, args=(sockets,)) + else: + raise ValueError("Invalid value for api_server_worker_type. Must be 'process' or 'thread'") + worker.start() + server_processes.append(worker) print(f"Swagger UI is available at http://0.0.0.0:{port}/docs") - for s in all_servers: - s.join() + for process in server_processes: + process.join() finally: print("Shutting down LitServe") - for w in all_workers: - w.terminate() - w.join() + for worker in workers: + worker.terminate() + worker.join() for manager in managers: manager.shutdown() From 368f91c1017089cbf57b38eea8e8b6713d99493a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 15 Sep 2024 17:01:20 +0000 Subject: [PATCH 11/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/litserve/server.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index fb9966ed..a24eb66b 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -476,9 +476,7 @@ def setup_auth(self): @asynccontextmanager async def manage_lifespan(app: FastAPI, servers: List[LitServer]): - """ - Context manager to handle the lifespan events of multiple FastAPI servers. - """ + """Context manager to handle the lifespan events of multiple FastAPI servers.""" # Start lifespan events for each server lifespans = [server.lifespan(server.app) for server in servers] for lifespan in lifespans: @@ -501,9 +499,7 @@ def run_all( api_server_worker_type: Optional[str] = None, **kwargs, ): - """ - Run multiple LitServers on the same port. - """ + """Run multiple LitServers on the same port.""" if not litservers: raise ValueError("No servers provided to run_all") From 885ab06972d64e61707938ad74830110076143db Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 15 Sep 2024 23:00:48 +0545 Subject: [PATCH 12/35] adds test for the manage_lifespan --- tests/test_lit_server.py | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index 74dc854a..a17032ab 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -26,12 +26,12 @@ from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware -from unittest.mock import patch, MagicMock +from unittest.mock import patch, MagicMock, AsyncMock import pytest from litserve.connector import _Connector -from litserve.server import LitServer +from litserve.server import LitServer, manage_lifespan import litserve as ls from fastapi.testclient import TestClient from starlette.types import ASGIApp @@ -429,3 +429,28 @@ def test_middlewares_inputs(): with pytest.raises(ValueError, match="middlewares must be a list of tuples"): ls.LitServer(ls.test_examples.SimpleLitAPI(), middlewares=(RequestIdMiddleware, {"length": 5})) + + +@pytest.mark.asyncio +@patch("litserve.server.LitServer") +async def test_manage_lifespan(mock_litserver): + # Mock the LitServer instance + mock_server_instance = MagicMock(spec=LitServer) + mock_server_instance.app = MagicMock() + + # Create an async context manager mock for the lifespan method + mock_lifespan_cm = MagicMock() + mock_lifespan_cm.__aenter__ = AsyncMock() + mock_lifespan_cm.__aexit__ = AsyncMock() + mock_server_instance.lifespan.return_value = mock_lifespan_cm + + mock_litserver.return_value = mock_server_instance + + servers = [mock_server_instance, mock_server_instance] + + async with manage_lifespan(None, servers): + # Assertions to ensure the function behaves as expected + assert mock_lifespan_cm.__aenter__.call_count == 2 + + # Ensure the __aexit__ method was called + assert mock_lifespan_cm.__aexit__.call_count == 2 From 7ba3ab72e0fb13f7b553b531647011c5096cf492 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 15 Sep 2024 23:01:24 +0545 Subject: [PATCH 13/35] ref: format imports --- tests/test_lit_server.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index a17032ab..19c0f3e1 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -15,27 +15,25 @@ import pickle import re import sys +from unittest.mock import AsyncMock, MagicMock, patch -from asgi_lifespan import LifespanManager -from litserve import LitAPI -from fastapi import Request, Response, HTTPException +import pytest import torch import torch.nn as nn +from asgi_lifespan import LifespanManager +from fastapi import HTTPException, Request, Response +from fastapi.testclient import TestClient from httpx import AsyncClient -from litserve.utils import wrap_litserve_start +from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.httpsredirect import HTTPSRedirectMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware +from starlette.types import ASGIApp -from unittest.mock import patch, MagicMock, AsyncMock -import pytest - +import litserve as ls +from litserve import LitAPI from litserve.connector import _Connector - from litserve.server import LitServer, manage_lifespan -import litserve as ls -from fastapi.testclient import TestClient -from starlette.types import ASGIApp -from starlette.middleware.base import BaseHTTPMiddleware +from litserve.utils import wrap_litserve_start def test_index(sync_testclient): From d7cb562c09d40898d0be53ce1bb9693964148917 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 15 Sep 2024 23:46:20 +0545 Subject: [PATCH 14/35] removed the extra test file --- tests/test_lit_server_multi_endpoints.py | 26 ------------------------ 1 file changed, 26 deletions(-) delete mode 100644 tests/test_lit_server_multi_endpoints.py diff --git a/tests/test_lit_server_multi_endpoints.py b/tests/test_lit_server_multi_endpoints.py deleted file mode 100644 index ba4ddd24..00000000 --- a/tests/test_lit_server_multi_endpoints.py +++ /dev/null @@ -1,26 +0,0 @@ -from litserve.server import LitServer - - -async def test_lit_server_with_multi_endpoints(simple_litapi): - server1 = LitServer(simple_litapi, api_path="/predict-1", timeout=10) - server2 = LitServer(simple_litapi, api_path="/predict-2", timeout=10) - servers = [server1, server2] - # TODO: update test to use run_all - - # Run the servers in a separate thread - port = 8000 - # server_thread = threading.Thread( - # target=run_all, args=(servers,), kwargs={"port": port, "num_api_servers": 2, "log_level": "debug"} - # ) - # server_thread.start() - - # async with AsyncClient(base_url=f"http://localhost:{port}") as client: - # # Test server1 endpoint - # response1 = await client.post("/predict-1", json={"input": 1}) - # assert response1.status_code == 200 - # assert response1.json() == {"output": 1} - - # # Test server2 endpoint - # response2 = await client.post("/predict-2", json={"input": 2}) - # assert response2.status_code == 200 - # assert response2.json() == {"output": 2} From d502af4c17bd2eb2e2bde6080e2ae4174db40730 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 15 Sep 2024 23:46:35 +0545 Subject: [PATCH 15/35] adds simple server with multi endpoints --- tests/simple_server_with_multi_endpoints.py | 43 +++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 tests/simple_server_with_multi_endpoints.py diff --git a/tests/simple_server_with_multi_endpoints.py b/tests/simple_server_with_multi_endpoints.py new file mode 100644 index 00000000..0f1c61e9 --- /dev/null +++ b/tests/simple_server_with_multi_endpoints.py @@ -0,0 +1,43 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from litserve.server import LitServer, run_all +from litserve.test_examples import SimpleLitAPI + + +class SimpleLitAPI1(SimpleLitAPI): + def setup(self, device): + self.model = lambda x: x**1 + + +class SimpleLitAPI2(SimpleLitAPI): + def setup(self, device): + self.model = lambda x: x**2 + + +class SimpleLitAPI3(SimpleLitAPI): + def setup(self, device): + self.model = lambda x: x**3 + + +class SimpleLitAPI4(SimpleLitAPI): + def setup(self, device): + self.model = lambda x: x**4 + + +if __name__ == "__main__": + server1 = LitServer(SimpleLitAPI1(), accelerator="cpu", devices=1, timeout=10, api_path="/predict-1") + server2 = LitServer(SimpleLitAPI2(), accelerator="cpu", devices=1, timeout=10, api_path="/predict-2") + server3 = LitServer(SimpleLitAPI3(), accelerator="cpu", devices=1, timeout=10, api_path="/predict-3") + server4 = LitServer(SimpleLitAPI4(), accelerator="cpu", devices=1, timeout=10, api_path="/predict-4") + run_all([server1, server2, server3, server4], port=8000) From 17a37cc6d34a5acb46b3643c1ef6c14f21254194 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 15 Sep 2024 23:46:57 +0545 Subject: [PATCH 16/35] adds e2e test for multi endpoints --- tests/e2e/test_e2e.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index f2618f13..70ecdb38 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -57,6 +57,17 @@ def test_run(): os.remove("client.py") +@e2e_from_file("tests/simple_server_with_multi_endpoints.py") +def test_run_with_multi_endpoints(): + assert os.path.exists("client.py"), f"Expected client file to be created at {os.getcwd()} after starting the server" + for i in range(1, 5): + resp = requests.post(f"http://127.0.0.1:8000/predict-{i}", json={"input": 4.0}, headers=None) + assert resp.status_code == 200, f"Expected response to be 200 but got {resp.status_code}" + assert resp.json() == { + "output": 4.0**i + }, "tests/simple_server_with_multi_endpoints.py didn't return expected output" + + @e2e_from_file("tests/e2e/default_api.py") def test_e2e_default_api(): resp = requests.post("http://127.0.0.1:8000/predict", json={"input": 4.0}, headers=None) From 452a509394b1d084b63ae3fbdb1be0fb31e7904b Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Mon, 16 Sep 2024 00:02:46 +0545 Subject: [PATCH 17/35] Refactor run_all function to remove unnecessary check for empty litservers list --- src/litserve/server.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index a24eb66b..765e5157 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -500,8 +500,6 @@ def run_all( **kwargs, ): """Run multiple LitServers on the same port.""" - if not litservers: - raise ValueError("No servers provided to run_all") if any(not isinstance(server, LitServer) for server in litservers): raise ValueError("All elements in the servers list must be instances of LitServer") From 391cefda38f2f63c3028cf5e4ff6887b511d0920 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Mon, 16 Sep 2024 00:05:25 +0545 Subject: [PATCH 18/35] adds test to runnall litservers --- tests/test_lit_server.py | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index 19c0f3e1..07f6742f 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -32,7 +32,7 @@ import litserve as ls from litserve import LitAPI from litserve.connector import _Connector -from litserve.server import LitServer, manage_lifespan +from litserve.server import LitServer, manage_lifespan, run_all from litserve.utils import wrap_litserve_start @@ -452,3 +452,24 @@ async def test_manage_lifespan(mock_litserver): # Ensure the __aexit__ method was called assert mock_lifespan_cm.__aexit__.call_count == 2 + + +@patch("litserve.server.uvicorn") +def test_run_all_litservers(mock_uvicorn): + server1 = LitServer(SimpleLitAPI(), api_path="/predict-1") + server2 = LitServer(SimpleLitAPI(), api_path="/predict-2") + + with pytest.raises(ValueError, match="All elements in the servers list must be instances of LitServer"): + run_all([server1, "server2"]) + + with pytest.raises(ValueError, match="port must be a value from 1024 to 65535 but got"): + run_all([server1, server2], port="invalid port") + + with pytest.raises(ValueError, match="port must be a value from 1024 to 65535 but got"): + run_all([server1, server2], port=65536) + + run_all([server1, server2], port=8000) + mock_uvicorn.Config.assert_called() + mock_uvicorn.reset_mock() + run_all([server1, server2], port="8001") + mock_uvicorn.Config.assert_called() From ea0ac77445b77abb5b9837b248d2aee1c337daa1 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Mon, 16 Sep 2024 00:23:42 +0545 Subject: [PATCH 19/35] Refactor test_e2e_with_multi_endpoints function name --- tests/e2e/test_e2e.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 70ecdb38..bab0c115 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -58,7 +58,7 @@ def test_run(): @e2e_from_file("tests/simple_server_with_multi_endpoints.py") -def test_run_with_multi_endpoints(): +def test_e2e_with_multi_endpoints(): assert os.path.exists("client.py"), f"Expected client file to be created at {os.getcwd()} after starting the server" for i in range(1, 5): resp = requests.post(f"http://127.0.0.1:8000/predict-{i}", json={"input": 4.0}, headers=None) From b82733e59a671cc5f0561ed2736900db57a6595d Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Mon, 16 Sep 2024 00:36:56 +0545 Subject: [PATCH 20/35] Refactor test_e2e_with_multi_endpoints function name --- tests/e2e/test_e2e.py | 2 +- tests/simple_server_with_multi_endpoints.py | 14 +------------- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index bab0c115..40a13ba6 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -60,7 +60,7 @@ def test_run(): @e2e_from_file("tests/simple_server_with_multi_endpoints.py") def test_e2e_with_multi_endpoints(): assert os.path.exists("client.py"), f"Expected client file to be created at {os.getcwd()} after starting the server" - for i in range(1, 5): + for i in range(1, 3): resp = requests.post(f"http://127.0.0.1:8000/predict-{i}", json={"input": 4.0}, headers=None) assert resp.status_code == 200, f"Expected response to be 200 but got {resp.status_code}" assert resp.json() == { diff --git a/tests/simple_server_with_multi_endpoints.py b/tests/simple_server_with_multi_endpoints.py index 0f1c61e9..cd311af9 100644 --- a/tests/simple_server_with_multi_endpoints.py +++ b/tests/simple_server_with_multi_endpoints.py @@ -25,19 +25,7 @@ def setup(self, device): self.model = lambda x: x**2 -class SimpleLitAPI3(SimpleLitAPI): - def setup(self, device): - self.model = lambda x: x**3 - - -class SimpleLitAPI4(SimpleLitAPI): - def setup(self, device): - self.model = lambda x: x**4 - - if __name__ == "__main__": server1 = LitServer(SimpleLitAPI1(), accelerator="cpu", devices=1, timeout=10, api_path="/predict-1") server2 = LitServer(SimpleLitAPI2(), accelerator="cpu", devices=1, timeout=10, api_path="/predict-2") - server3 = LitServer(SimpleLitAPI3(), accelerator="cpu", devices=1, timeout=10, api_path="/predict-3") - server4 = LitServer(SimpleLitAPI4(), accelerator="cpu", devices=1, timeout=10, api_path="/predict-4") - run_all([server1, server2, server3, server4], port=8000) + run_all([server1, server2], port=8000) From f8491a1072400d634a30950ab74af15e82c7ec22 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Mon, 16 Sep 2024 00:42:01 +0545 Subject: [PATCH 21/35] update tests to include more litservers --- tests/e2e/test_e2e.py | 2 +- tests/simple_server_with_multi_endpoints.py | 18 +++++++++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index 40a13ba6..bab0c115 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -60,7 +60,7 @@ def test_run(): @e2e_from_file("tests/simple_server_with_multi_endpoints.py") def test_e2e_with_multi_endpoints(): assert os.path.exists("client.py"), f"Expected client file to be created at {os.getcwd()} after starting the server" - for i in range(1, 3): + for i in range(1, 5): resp = requests.post(f"http://127.0.0.1:8000/predict-{i}", json={"input": 4.0}, headers=None) assert resp.status_code == 200, f"Expected response to be 200 but got {resp.status_code}" assert resp.json() == { diff --git a/tests/simple_server_with_multi_endpoints.py b/tests/simple_server_with_multi_endpoints.py index cd311af9..25455732 100644 --- a/tests/simple_server_with_multi_endpoints.py +++ b/tests/simple_server_with_multi_endpoints.py @@ -25,7 +25,19 @@ def setup(self, device): self.model = lambda x: x**2 +class SimpleLitAPI3(SimpleLitAPI): + def setup(self, device): + self.model = lambda x: x**3 + + +class SimpleLitAPI4(SimpleLitAPI): + def setup(self, device): + self.model = lambda x: x**4 + + if __name__ == "__main__": - server1 = LitServer(SimpleLitAPI1(), accelerator="cpu", devices=1, timeout=10, api_path="/predict-1") - server2 = LitServer(SimpleLitAPI2(), accelerator="cpu", devices=1, timeout=10, api_path="/predict-2") - run_all([server1, server2], port=8000) + server1 = LitServer(SimpleLitAPI1(), api_path="/predict-1") + server2 = LitServer(SimpleLitAPI2(), api_path="/predict-2") + server3 = LitServer(SimpleLitAPI3(), api_path="/predict-3") + server4 = LitServer(SimpleLitAPI4(), api_path="/predict-4") + run_all([server1, server2, server3, server4], port=8000) From e061c88f75158c16d295ae1477aa54c2a0f1fb37 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Tue, 17 Sep 2024 11:43:54 +0545 Subject: [PATCH 22/35] Refactor manage_lifespan to multi_server_lifespan --- src/litserve/server.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 765e5157..91d19096 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -25,7 +25,7 @@ import warnings from collections import deque from concurrent.futures import ThreadPoolExecutor -from contextlib import asynccontextmanager +from contextlib import AsyncExitStack, asynccontextmanager from queue import Empty from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -475,19 +475,13 @@ def setup_auth(self): @asynccontextmanager -async def manage_lifespan(app: FastAPI, servers: List[LitServer]): +async def multi_server_lifespan(app: FastAPI, servers: List[LitServer]): """Context manager to handle the lifespan events of multiple FastAPI servers.""" # Start lifespan events for each server - lifespans = [server.lifespan(server.app) for server in servers] - for lifespan in lifespans: - await lifespan.__aenter__() - - try: + async with AsyncExitStack() as stack: + for server in servers: + await stack.enter_async_context(server.lifespan(server.app)) yield - finally: - # Exit lifespan events for each server - for lifespan in lifespans: - await lifespan.__aexit__(None, None, None) def run_all( @@ -525,7 +519,7 @@ def run_all( api_server_worker_type = "process" # Create the main FastAPI app - app = FastAPI(lifespan=lambda app: manage_lifespan(app, litservers)) + app = FastAPI(lifespan=lambda app: multi_server_lifespan(app, litservers)) config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs) sockets = [config.bind_socket()] From 1de6935ebdcd3af50e97a81c9cf122d5a01e30e2 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Tue, 17 Sep 2024 12:06:36 +0545 Subject: [PATCH 23/35] Refactor test multi_server_lifespan --- tests/test_lit_server.py | 31 +++++++++---------------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index 07f6742f..f593fbc1 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -32,7 +32,7 @@ import litserve as ls from litserve import LitAPI from litserve.connector import _Connector -from litserve.server import LitServer, manage_lifespan, run_all +from litserve.server import LitServer, multi_server_lifespan, run_all from litserve.utils import wrap_litserve_start @@ -431,27 +431,14 @@ def test_middlewares_inputs(): @pytest.mark.asyncio @patch("litserve.server.LitServer") -async def test_manage_lifespan(mock_litserver): - # Mock the LitServer instance - mock_server_instance = MagicMock(spec=LitServer) - mock_server_instance.app = MagicMock() - - # Create an async context manager mock for the lifespan method - mock_lifespan_cm = MagicMock() - mock_lifespan_cm.__aenter__ = AsyncMock() - mock_lifespan_cm.__aexit__ = AsyncMock() - mock_server_instance.lifespan.return_value = mock_lifespan_cm - - mock_litserver.return_value = mock_server_instance - - servers = [mock_server_instance, mock_server_instance] - - async with manage_lifespan(None, servers): - # Assertions to ensure the function behaves as expected - assert mock_lifespan_cm.__aenter__.call_count == 2 - - # Ensure the __aexit__ method was called - assert mock_lifespan_cm.__aexit__.call_count == 2 +async def test_multi_server_lifespan(mock_litserver): + # List of servers + servers = [mock_litserver, mock_litserver] + # Use the async context manager + async with multi_server_lifespan(MagicMock(), servers): + # Check if the lifespan method was called for each server + assert mock_litserver.lifespan.call_count == 2 + assert mock_litserver.lifespan.return_value.__aexit__.call_count == 2 @patch("litserve.server.uvicorn") From 0697222de00632c25cdfaa34c3a8784aec073f45 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Tue, 17 Sep 2024 12:12:58 +0545 Subject: [PATCH 24/35] refactor classnames --- ...ulti_endpoints.py => multiple_litserver.py} | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) rename tests/{simple_server_with_multi_endpoints.py => multiple_litserver.py} (64%) diff --git a/tests/simple_server_with_multi_endpoints.py b/tests/multiple_litserver.py similarity index 64% rename from tests/simple_server_with_multi_endpoints.py rename to tests/multiple_litserver.py index 25455732..edd3f937 100644 --- a/tests/simple_server_with_multi_endpoints.py +++ b/tests/multiple_litserver.py @@ -15,29 +15,29 @@ from litserve.test_examples import SimpleLitAPI -class SimpleLitAPI1(SimpleLitAPI): +class MultipleLitServerAPI1(SimpleLitAPI): def setup(self, device): self.model = lambda x: x**1 -class SimpleLitAPI2(SimpleLitAPI): +class MultipleLitServerAPI2(SimpleLitAPI): def setup(self, device): self.model = lambda x: x**2 -class SimpleLitAPI3(SimpleLitAPI): +class MultipleLitServerAPI3(SimpleLitAPI): def setup(self, device): self.model = lambda x: x**3 -class SimpleLitAPI4(SimpleLitAPI): +class MultipleLitServerAPI4(SimpleLitAPI): def setup(self, device): self.model = lambda x: x**4 if __name__ == "__main__": - server1 = LitServer(SimpleLitAPI1(), api_path="/predict-1") - server2 = LitServer(SimpleLitAPI2(), api_path="/predict-2") - server3 = LitServer(SimpleLitAPI3(), api_path="/predict-3") - server4 = LitServer(SimpleLitAPI4(), api_path="/predict-4") - run_all([server1, server2, server3, server4], port=8000) + server1 = LitServer(MultipleLitServerAPI1(), api_path="/predict-1") + server2 = LitServer(MultipleLitServerAPI2(), api_path="/predict-2") + server3 = LitServer(MultipleLitServerAPI3(), api_path="/predict-3") + server4 = LitServer(MultipleLitServerAPI4(), api_path="/predict-4") + run_all([server1, server2, server3, server4], port=8000, num_api_servers=2) From 4ac180d4801d9ee205bb975cf78ced3456860566 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Tue, 17 Sep 2024 12:13:08 +0545 Subject: [PATCH 25/35] Refactor test_e2e_combined_multiple_litserver to use multiple_litserver.py --- tests/e2e/test_e2e.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/e2e/test_e2e.py b/tests/e2e/test_e2e.py index bab0c115..ea9360bf 100644 --- a/tests/e2e/test_e2e.py +++ b/tests/e2e/test_e2e.py @@ -57,8 +57,8 @@ def test_run(): os.remove("client.py") -@e2e_from_file("tests/simple_server_with_multi_endpoints.py") -def test_e2e_with_multi_endpoints(): +@e2e_from_file("tests/multiple_litserver.py") +def test_e2e_combined_multiple_litserver(): assert os.path.exists("client.py"), f"Expected client file to be created at {os.getcwd()} after starting the server" for i in range(1, 5): resp = requests.post(f"http://127.0.0.1:8000/predict-{i}", json={"input": 4.0}, headers=None) From 3e9dc6e8b9e0fa15ae245987588e02d351a43d00 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Tue, 17 Sep 2024 12:14:12 +0545 Subject: [PATCH 26/35] adds default queue id for the main app --- src/litserve/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/litserve/server.py b/src/litserve/server.py index 91d19096..9ae33fa9 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -478,6 +478,7 @@ def setup_auth(self): async def multi_server_lifespan(app: FastAPI, servers: List[LitServer]): """Context manager to handle the lifespan events of multiple FastAPI servers.""" # Start lifespan events for each server + app.response_queue_id = 0 async with AsyncExitStack() as stack: for server in servers: await stack.enter_async_context(server.lifespan(server.app)) From def24f86d58c9ee04af34d3fc0d2eabb21846f42 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 17 Sep 2024 06:29:26 +0000 Subject: [PATCH 27/35] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_lit_server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index f593fbc1..0707cf0e 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -15,7 +15,7 @@ import pickle import re import sys -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest import torch From d19e1aeec9800e730ac6d7c73c10d23e308de327 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Tue, 17 Sep 2024 12:21:03 +0545 Subject: [PATCH 28/35] rm num api server --- tests/multiple_litserver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/multiple_litserver.py b/tests/multiple_litserver.py index edd3f937..b10d267e 100644 --- a/tests/multiple_litserver.py +++ b/tests/multiple_litserver.py @@ -40,4 +40,4 @@ def setup(self, device): server2 = LitServer(MultipleLitServerAPI2(), api_path="/predict-2") server3 = LitServer(MultipleLitServerAPI3(), api_path="/predict-3") server4 = LitServer(MultipleLitServerAPI4(), api_path="/predict-4") - run_all([server1, server2, server3, server4], port=8000, num_api_servers=2) + run_all([server1, server2, server3, server4], port=8000) From 3705505e66b99d8ab79127a7ca941c5f2b701481 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Tue, 17 Sep 2024 15:09:48 +0530 Subject: [PATCH 29/35] Update server.py --- src/litserve/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 9ae33fa9..0459c75f 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -257,7 +257,7 @@ async def lifespan(self, app: FastAPI): "the LitServer class to initialize the response queues." ) - response_queue = self.response_queues[app.response_queue_id] + response_queue = self.response_queues[self.app.response_queue_id] response_executor = ThreadPoolExecutor(max_workers=len(self.devices * self.workers_per_device)) future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor) task = loop.create_task(future) From fa8dd7dcd95b6cfde04352763d6c97b9b5cf48b9 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Tue, 17 Sep 2024 15:10:07 +0530 Subject: [PATCH 30/35] fix --- src/litserve/server.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 0459c75f..7a49c4ea 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -478,7 +478,6 @@ def setup_auth(self): async def multi_server_lifespan(app: FastAPI, servers: List[LitServer]): """Context manager to handle the lifespan events of multiple FastAPI servers.""" # Start lifespan events for each server - app.response_queue_id = 0 async with AsyncExitStack() as stack: for server in servers: await stack.enter_async_context(server.lifespan(server.app)) From 724d80ba9e9339dbadeab4c25ccce7f087955799 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Wed, 18 Sep 2024 00:52:37 +0545 Subject: [PATCH 31/35] Refactor server.py to use 'servers' instead of 'litservers' for consistency --- src/litserve/server.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 7a49c4ea..68e34ee0 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -485,7 +485,7 @@ async def multi_server_lifespan(app: FastAPI, servers: List[LitServer]): def run_all( - litservers: List[LitServer], + servers: List[LitServer], port: Union[str, int] = 8000, num_api_servers: Optional[int] = 1, log_level: str = "info", @@ -495,7 +495,7 @@ def run_all( ): """Run multiple LitServers on the same port.""" - if any(not isinstance(server, LitServer) for server in litservers): + if any(not isinstance(server, LitServer) for server in servers): raise ValueError("All elements in the servers list must be instances of LitServer") if generate_client_file: @@ -519,13 +519,13 @@ def run_all( api_server_worker_type = "process" # Create the main FastAPI app - app = FastAPI(lifespan=lambda app: multi_server_lifespan(app, litservers)) + app = FastAPI(lifespan=lambda app: multi_server_lifespan(app, servers)) config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs) sockets = [config.bind_socket()] managers, workers = [], [] try: - for litserver in litservers: + for litserver in servers: manager, litserve_workers = litserver.launch_inference_worker(num_api_servers) managers.append(manager) workers.extend(litserve_workers) @@ -535,7 +535,7 @@ def run_all( server_processes = [] for response_queue_id in range(num_api_servers): - for litserver in litservers: + for litserver in servers: litserver.app.response_queue_id = response_queue_id if litserver.lit_spec: litserver.lit_spec.response_queue_id = response_queue_id From 2ac6388c55e2c00495b86f414ad82b7c34d660fe Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Wed, 18 Sep 2024 01:00:41 +0545 Subject: [PATCH 32/35] added more test cases to increase the coverage --- tests/test_lit_server.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index 0707cf0e..9d8e0805 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -33,6 +33,7 @@ from litserve import LitAPI from litserve.connector import _Connector from litserve.server import LitServer, multi_server_lifespan, run_all +from litserve.test_examples.openai_spec_example import TestAPI from litserve.utils import wrap_litserve_start @@ -445,6 +446,7 @@ async def test_multi_server_lifespan(mock_litserver): def test_run_all_litservers(mock_uvicorn): server1 = LitServer(SimpleLitAPI(), api_path="/predict-1") server2 = LitServer(SimpleLitAPI(), api_path="/predict-2") + server3 = LitServer(TestAPI(), spec=ls.OpenAISpec()) with pytest.raises(ValueError, match="All elements in the servers list must be instances of LitServer"): run_all([server1, "server2"]) @@ -455,8 +457,14 @@ def test_run_all_litservers(mock_uvicorn): with pytest.raises(ValueError, match="port must be a value from 1024 to 65535 but got"): run_all([server1, server2], port=65536) - run_all([server1, server2], port=8000) + with pytest.raises(ValueError, match="num_api_servers must be greater than 0"): + run_all([server1, server2], num_api_servers=0) + + with pytest.raises(ValueError, match="Must be 'process' or 'thread'"): + run_all([server1, server2], api_server_worker_type="invalid") + + run_all([server1, server2, server3], port=8000) mock_uvicorn.Config.assert_called() mock_uvicorn.reset_mock() - run_all([server1, server2], port="8001") + run_all([server1, server2, server3], port="8001") mock_uvicorn.Config.assert_called() From a09f5409de69f24096a2c40a177fc95f5505cf3c Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Wed, 18 Sep 2024 01:09:42 +0545 Subject: [PATCH 33/35] updated test --- tests/test_lit_server.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index 9d8e0805..d85c05c6 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -460,9 +460,6 @@ def test_run_all_litservers(mock_uvicorn): with pytest.raises(ValueError, match="num_api_servers must be greater than 0"): run_all([server1, server2], num_api_servers=0) - with pytest.raises(ValueError, match="Must be 'process' or 'thread'"): - run_all([server1, server2], api_server_worker_type="invalid") - run_all([server1, server2, server3], port=8000) mock_uvicorn.Config.assert_called() mock_uvicorn.reset_mock() From e5db967a29fe6aacfc9c57950543b0f9d86d921a Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Wed, 18 Sep 2024 09:40:30 +0545 Subject: [PATCH 34/35] Refactor server.py to use 'servers' instead of 'litservers' for consistency and minor cleanup --- src/litserve/server.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index 68e34ee0..a11656df 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -523,22 +523,22 @@ def run_all( config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs) sockets = [config.bind_socket()] - managers, workers = [], [] + managers, inference_workers = [], [] try: - for litserver in servers: - manager, litserve_workers = litserver.launch_inference_worker(num_api_servers) + for server in servers: + manager, workers = server.launch_inference_worker(num_api_servers) managers.append(manager) - workers.extend(litserve_workers) + inference_workers.extend(workers) # include routes from each litserver's app into the main app - app.include_router(litserver.app.router) + app.include_router(server.app.router) server_processes = [] for response_queue_id in range(num_api_servers): - for litserver in servers: - litserver.app.response_queue_id = response_queue_id - if litserver.lit_spec: - litserver.lit_spec.response_queue_id = response_queue_id + for server in servers: + server.app.response_queue_id = response_queue_id + if server.lit_spec: + server.lit_spec.response_queue_id = response_queue_id app = copy.copy(app) config = uvicorn.Config(app=app, host="0.0.0.0", port=port, log_level=log_level, **kwargs) @@ -558,7 +558,7 @@ def run_all( process.join() finally: print("Shutting down LitServe") - for worker in workers: + for worker in inference_workers: worker.terminate() worker.join() for manager in managers: From 33e3b12f7cbb3ea62e46a29bf032346460fd1b6e Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Wed, 18 Sep 2024 15:38:28 +0545 Subject: [PATCH 35/35] update --- src/litserve/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/litserve/server.py b/src/litserve/server.py index f756f4fb..feb44e26 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -251,7 +251,7 @@ async def lifespan(self, app: FastAPI): "the LitServer class to initialize the response queues." ) - response_queue = self.response_queues[app.response_queue_id] + response_queue = self.response_queues[self.app.response_queue_id] response_executor = ThreadPoolExecutor(max_workers=len(self.inference_workers)) future = response_queue_to_buffer(response_queue, self.response_buffer, self.stream, response_executor) task = loop.create_task(future)