From abb7f9ab10e9ffb9e3f6c14255c595e74021483d Mon Sep 17 00:00:00 2001 From: Sidharth Shanker Date: Fri, 27 Sep 2024 05:41:31 +0000 Subject: [PATCH 1/7] Cancellation should happen before any response sent back. --- truss/tests/test_model_inference.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/truss/tests/test_model_inference.py b/truss/tests/test_model_inference.py index f499e9f3c..9aaa73dd7 100644 --- a/truss/tests/test_model_inference.py +++ b/truss/tests/test_model_inference.py @@ -1138,6 +1138,35 @@ async def predict(self, inputs, request: fastapi.Request): assert "Cancelled (during gen)." in container.logs() +@pytest.mark.integration +def test_async_streaming_with_cancellation_before_generation(): + model = """ + import fastapi, asyncio, logging + + class Model: + async def predict(self, inputs, request: fastapi.Request): + await asyncio.sleep(2) + if await request.is_disconnected(): + logging.warning("Cancelled (before gen).") + return + return "Done" + """ + with ensure_kill_all(), temp_truss(model, "") as tr: + container = tr.docker_run( + local_port=8090, detach=True, wait_for_server_ready=True + ) + # For hard cancellation we need to use httpx, requests' timeouts don't work. + with pytest.raises(httpx.ReadTimeout): + with httpx.Client( + timeout=httpx.Timeout(1.0, connect=1.0, read=1.0) + ) as client: + response = client.post(PREDICT_URL, json={}, timeout=1.0) + response.raise_for_status() + + time.sleep(2) # Wait a bit to get all logs. + assert "Cancelled (before gen)." in container.logs() + + @pytest.mark.integration def test_limit_concurrency_with_sse(): # It seems that the "builtin" functionality of the FastAPI server already buffers From 7cb94d5b4c431dd995ba89a7f6ed8e9a557e1078 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Mon, 30 Sep 2024 15:13:04 -0700 Subject: [PATCH 2/7] Use pure ASGI middleware for termination handling. Skip disconnected requests always in predict method --- .../common/termination_handler_middleware.py | 85 ++++--- truss/templates/server/truss_server.py | 14 +- .../test_termination_handler_middleware.py | 236 ++++++++++++------ truss/tests/test_model_inference.py | 4 + 4 files changed, 221 insertions(+), 118 deletions(-) diff --git a/truss/templates/server/common/termination_handler_middleware.py b/truss/templates/server/common/termination_handler_middleware.py index f0a8f6e43..2c09fd7db 100644 --- a/truss/templates/server/common/termination_handler_middleware.py +++ b/truss/templates/server/common/termination_handler_middleware.py @@ -1,64 +1,71 @@ import asyncio +import logging import signal from typing import Callable -from fastapi import Request +from starlette.types import ASGIApp, Receive, Scope, Send # This is to allow the last request's response to finish handling. There may be more # middlewares that the response goes through, and then there's the time for the bytes -# to be pushed to the caller. +# to be sent to the caller. DEFAULT_TERM_DELAY_SECS = 5.0 class TerminationHandlerMiddleware: """ - This middleware allows for swiftly and safely terminating the server. It - listens to a set of termination signals. On receiving such a signal, it - first informs on the on_stop callback, then waits for currently executing - requests to finish, before informing on the on_term callback. - - Stop means that the process to stop the server has started. As soon as - outstading requests go to zero after this, on_term will be called. - - Term means that this is the right time to terminate the server process, no - outstanding requests at this point. + Implements https://www.starlette.io/middleware/#pure-asgi-middleware - The caller would typically handle on_stop by stop sending more requests to - the FastApi server. And on_term by exiting the server process. + This middleware allows for swiftly and safely terminating the server. It + listens to a set of termination signals. On receiving such a signal, it terminates + immediately if there are no outstanding requests and otherwise "marks" the server + to be terminated when all outstanding requests are done. """ def __init__( self, - on_stop: Callable[[], None], - on_term: Callable[[], None], + app: ASGIApp, + on_termination: Callable[[], None], termination_delay_secs: float = DEFAULT_TERM_DELAY_SECS, ): - self._outstanding_request_count = 0 - self._on_stop = on_stop - self._on_term = on_term + self._app = app + self._outstanding_requests_semaphore = asyncio.Semaphore(0) + self._on_termination = on_termination self._termination_delay_secs = termination_delay_secs - self._stopped = False + self._should_terminate_soon = False + + loop = asyncio.get_event_loop() for sig in [signal.SIGINT, signal.SIGTERM, signal.SIGQUIT]: - signal.signal(sig, self._stop) + loop.add_signal_handler(sig, self._handle_stop_signal) - async def __call__(self, request: Request, call_next): - self._outstanding_request_count += 1 - try: - response = await call_next(request) - finally: - self._outstanding_request_count -= 1 - if self._outstanding_request_count == 0 and self._stopped: - # There's a delay in term to allow some time for current - # response flow to finish. - asyncio.create_task(self._term()) - return response + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: + if scope["type"] == "http": + self._outstanding_requests_semaphore.release() # Increment. + try: + await self._app(scope, receive, send) + finally: + await self._outstanding_requests_semaphore.acquire() # Decrement. + # Check if it's time to terminate after all requests finish + if ( + self._should_terminate_soon + and self._outstanding_requests_semaphore.locked() + ): + logging.info("Termination after finishing outstanding requests.") + # Run in background, to not block the current request handling. + asyncio.create_task(self._terminate()) + else: + await self._app(scope, receive, send) - def _stop(self, sig, frame): - self._on_stop() - self._stopped = True - if self._outstanding_request_count == 0: - self._on_term() + def _handle_stop_signal(self) -> None: + logging.info("Received termination signal.") + self._should_terminate_soon = True + if self._outstanding_requests_semaphore.locked(): + logging.info("No outstanding requests. Terminate immediately.") + asyncio.create_task(self._terminate()) + else: + logging.info("Will terminate when all requests are processed.") - async def _term(self): + async def _terminate(self) -> None: + logging.info("Sleeping before termination.") await asyncio.sleep(self._termination_delay_secs) - self._on_term() + logging.info("Terminating") + self._on_termination() diff --git a/truss/templates/server/truss_server.py b/truss/templates/server/truss_server.py index 776849f89..e4ffedc70 100644 --- a/truss/templates/server/truss_server.py +++ b/truss/templates/server/truss_server.py @@ -27,7 +27,6 @@ from shared import serialization, util from shared.logging import setup_logging from shared.secrets_resolver import SecretsResolver -from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import ClientDisconnect from starlette.responses import Response @@ -168,8 +167,13 @@ async def predict( self, model_name: str, request: Request, body_raw: bytes = Depends(parse_body) ) -> Response: """ - This method calls the user-provided predict method + This method calls the user-provided predict method. """ + if await request.is_disconnected(): + msg = "Skipping `predict`, client disconnected." + logging.info(msg) + raise ClientDisconnect(msg) + model: ModelWrapper = self._safe_lookup_model(model_name) self.check_healthy(model) @@ -336,11 +340,7 @@ def exit_self(): util.kill_child_processes(os.getpid()) sys.exit() - termination_handler_middleware = TerminationHandlerMiddleware( - on_stop=lambda: None, - on_term=exit_self, - ) - app.add_middleware(BaseHTTPMiddleware, dispatch=termination_handler_middleware) + app.add_middleware(TerminationHandlerMiddleware, on_term=exit_self) return app def start(self): diff --git a/truss/tests/templates/server/common/test_termination_handler_middleware.py b/truss/tests/templates/server/common/test_termination_handler_middleware.py index cd98edead..f4bb2607c 100644 --- a/truss/tests/templates/server/common/test_termination_handler_middleware.py +++ b/truss/tests/templates/server/common/test_termination_handler_middleware.py @@ -1,93 +1,185 @@ +import asyncio +import logging import multiprocessing +import os +import signal +import socket +import sys import tempfile import time from pathlib import Path -from typing import Awaitable, Callable, List +import httpx import pytest -from truss.templates.server.common.termination_handler_middleware import ( - TerminationHandlerMiddleware, -) +from fastapi import FastAPI +from starlette.responses import PlainTextResponse -async def noop(*args, **kwargs): - return +def _get_free_port() -> int: + """Find and return a free port on localhost.""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("localhost", 0)) # Bind to localhost on an arbitrary free port + return s.getsockname()[1] # Return the assigned port -@pytest.mark.integration -def test_termination_sequence_no_pending_requests(tmp_path): - # Create middleware in separate process, on sending term signal to process, - # it should print the right messages. - def main_coro_gen(middleware: TerminationHandlerMiddleware): - import asyncio +HOST = "localhost" +PORT = _get_free_port() - async def main(*args, **kwargs): - await middleware(1, call_next=noop) - await asyncio.sleep(1) - print("should not print due to termination") - return main() +async def _mock_asgi_app(): + await asyncio.sleep(1) + return PlainTextResponse("OK") - _verify_term(main_coro_gen, ["stopped", "terminated"]) +def _on_termination(): + from truss.templates.shared import util -@pytest.mark.integration -def test_termination_sequence_with_pending_requests(tmp_path): - def main_coro_gen(middleware: TerminationHandlerMiddleware): - import asyncio + logging.info("Server is shutting down...") + util.kill_child_processes(os.getpid()) + os.kill(os.getpid(), signal.SIGKILL) - async def main(*args, **kwargs): - async def call_next(req): - await asyncio.sleep(1.0) - return "call_next_called" - resp = await middleware(1, call_next=call_next) - print(f"call_next response: {resp}") - await asyncio.sleep(1) - print("should not print due to termination") +def run_server(log_file_path: Path): + import logging - return main() - - _verify_term( - main_coro_gen, - [ - "stopped", - "call_next response: call_next_called", - "terminated", - ], + logging.basicConfig( + filename=log_file_path, + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", # Add timestamps + datefmt="%Y-%m-%d %H:%M:%S", # Optional: specify date format + force=True, # Force reconfiguration of logging if already configured + ) + import uvicorn + from truss.templates.server.common.termination_handler_middleware import ( + TerminationHandlerMiddleware, ) - -def _verify_term( - main_coro_gen: Callable[[TerminationHandlerMiddleware], Awaitable], - expected_lines: List[str], -): - def run(stdout_capture_file_path): - import asyncio - import os - import signal - import sys - - sys.stdout = open(stdout_capture_file_path, "w") - - def term(): - print("terminated", flush=True) - os.kill(os.getpid(), signal.SIGKILL) - - middleware = TerminationHandlerMiddleware( - on_stop=lambda: print("stopped", flush=True), - on_term=term, - termination_delay_secs=0.1, + app = FastAPI() + app.get("/")(_mock_asgi_app) + app.add_middleware( + TerminationHandlerMiddleware, + on_termination=_on_termination, + termination_delay_secs=1, + ) + # Simple hack to get *all* output to a file. + sys.stderr = open(log_file_path, "a+") + sys.stdout = open(log_file_path, "a+") + uvicorn.run(app, host=HOST, port=PORT) + + +@pytest.mark.asyncio +async def test_no_outstanding_requests_immediate_termination(): + """Test that the server terminates immediately when no outstanding requests.""" + with tempfile.NamedTemporaryFile( + delete=False, prefix="test-term.", suffix=".txt" + ) as tmp_log: + log_file_path = Path(tmp_log.name) + server_process = multiprocessing.Process( + target=run_server, args=(log_file_path,) + ) + server_process.start() + time.sleep(1) + server_process.terminate() + server_process.join() + + with log_file_path.open() as log: + log_lines = log.readlines() + assert any("Received termination signal." in line for line in log_lines) + assert any( + "No outstanding requests. Terminate immediately." in line + for line in log_lines + ) + assert any("Terminating" in line for line in log_lines) + assert any("Server is shutting down" in line for line in log_lines) + + +@pytest.mark.asyncio +async def test_outstanding_requests_delayed_termination(): + """Test that the server waits for outstanding requests to finish before terminating.""" + with tempfile.NamedTemporaryFile( + delete=False, prefix="test-term.", suffix=".txt" + ) as tmp_log: + log_file_path = Path(tmp_log.name) + + server_process = multiprocessing.Process( + target=run_server, args=(log_file_path,) + ) + server_process.start() + time.sleep(1) + + # Send a long-running request to the server + async with httpx.AsyncClient() as client: + task = asyncio.create_task(client.get(f"http://{HOST}:{PORT}/")) + # Give the request some time to be in progress + await asyncio.sleep(0.5) + # Send termination signal (SIGTERM) during the request + server_process.terminate() + response = await task + assert response.status_code == 200 + + server_process.join() + with log_file_path.open() as log: + log_lines = log.readlines() + assert any("Received termination signal." in line for line in log_lines) + assert any( + "Will terminate when all requests are processed." in line + for line in log_lines + ) + assert any("Terminating" in line for line in log_lines) + assert any("Server is shutting down" in line for line in log_lines) + + +@pytest.mark.asyncio +async def test_multiple_outstanding_requests(): + """Test that the server waits for multiple concurrent requests before terminating. + + Logs something like: + + INFO: Started server process [1820944] + INFO: Waiting for application startup. + INFO: Application startup complete. + INFO: Uvicorn running on http://localhost:37311 (Press CTRL+C to quit) + 2024-09-30 15:03:05 - root - INFO - Received termination signal. + 2024-09-30 15:03:05 - root - INFO - Will terminate when all requests are processed. + INFO: 127.0.0.1:58184 - "GET / HTTP/1.1" 200 OK + INFO: 127.0.0.1:58192 - "GET / HTTP/1.1" 200 OK + 2024-09-30 15:03:06 - root - INFO - Termination after finishing outstanding requests. + 2024-09-30 15:03:06 - root - INFO - Sleeping before termination. + 2024-09-30 15:03:07 - root - INFO - Terminating + 2024-09-30 15:03:07 - root - INFO - Server is shutting down... + """ + with tempfile.NamedTemporaryFile( + delete=False, prefix="test-term.", suffix=".txt" + ) as tmp_log: + log_file_path = Path(tmp_log.name) + + server_process = multiprocessing.Process( + target=run_server, args=(log_file_path,) ) - asyncio.run(main_coro_gen(middleware)) - - stdout_capture_file = tempfile.NamedTemporaryFile() - proc = multiprocessing.Process(target=run, args=(stdout_capture_file.name,)) - proc.start() - time.sleep(1) - proc.terminate() - proc.join(timeout=6.0) - with Path(stdout_capture_file.name).open() as file: - lines = [line.strip() for line in file] - - assert lines == expected_lines + server_process.start() + time.sleep(1) + + # Send multiple concurrent long-running requests + async with httpx.AsyncClient() as client: + tasks = [ + asyncio.create_task(client.get(f"http://{HOST}:{PORT}/")), + asyncio.create_task(client.get(f"http://{HOST}:{PORT}/")), + ] + # Give the requests some time to be in progress + await asyncio.sleep(0.5) + server_process.terminate() + # Wait for both requests to finish + results = await asyncio.gather(*tasks) + for r in results: + assert r.status_code == 200 + + server_process.join() + with log_file_path.open() as log: + log_lines = log.readlines() + assert any("Received termination signal." in line for line in log_lines) + assert any( + "Will terminate when all requests are processed." in line + for line in log_lines + ) + assert any("Terminating" in line for line in log_lines) + assert any("Server is shutting down" in line for line in log_lines) diff --git a/truss/tests/test_model_inference.py b/truss/tests/test_model_inference.py index 9aaa73dd7..2df106cda 100644 --- a/truss/tests/test_model_inference.py +++ b/truss/tests/test_model_inference.py @@ -1109,6 +1109,7 @@ def test_async_streaming_with_cancellation(): class Model: async def predict(self, inputs, request: fastapi.Request): + logging.warning("Starting sleep.") await asyncio.sleep(1) if await request.is_disconnected(): logging.warning("Cancelled (before gen).") @@ -1145,10 +1146,13 @@ def test_async_streaming_with_cancellation_before_generation(): class Model: async def predict(self, inputs, request: fastapi.Request): + logging.info("start sleep") await asyncio.sleep(2) + logging.info("done sleep, check request.") if await request.is_disconnected(): logging.warning("Cancelled (before gen).") return + logging.info("not cancelled.") return "Done" """ with ensure_kill_all(), temp_truss(model, "") as tr: From 50b58304b2643df55c456594d1d7b3bfda11611e Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Mon, 30 Sep 2024 15:50:38 -0700 Subject: [PATCH 3/7] Create RC --- poetry.lock | 139 +++++++++++++------------ pyproject.toml | 2 +- truss/templates/server/truss_server.py | 2 +- 3 files changed, 76 insertions(+), 67 deletions(-) diff --git a/poetry.lock b/poetry.lock index 8e635884b..974d5175a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.0 and should not be changed by hand. [[package]] name = "annotated-types" @@ -168,17 +168,17 @@ files = [ [[package]] name = "boto3" -version = "1.35.26" +version = "1.35.29" description = "The AWS SDK for Python" optional = false python-versions = ">=3.8" files = [ - {file = "boto3-1.35.26-py3-none-any.whl", hash = "sha256:c31db992655db233d98762612690cfe60723c9e1503b5709aad92c1c564877bb"}, - {file = "boto3-1.35.26.tar.gz", hash = "sha256:b04087afd3570ba540fd293823c77270ec675672af23da9396bd5988a3f8128b"}, + {file = "boto3-1.35.29-py3-none-any.whl", hash = "sha256:2244044cdfa8ac345d7400536dc15a4824835e7ec5c55bc267e118af66bb27db"}, + {file = "boto3-1.35.29.tar.gz", hash = "sha256:7bbb1ee649e09e956952285782cfdebd7e81fc78384f48dfab3d66c6eaf3f63f"}, ] [package.dependencies] -botocore = ">=1.35.26,<1.36.0" +botocore = ">=1.35.29,<1.36.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.10.0,<0.11.0" @@ -187,13 +187,13 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.35.26" +version = "1.35.29" description = "Low-level, data-driven core of boto 3." optional = false python-versions = ">=3.8" files = [ - {file = "botocore-1.35.26-py3-none-any.whl", hash = "sha256:0b9dee5e4a3314e251e103585837506b17fcc7485c3c8adb61a9a913f46da1e7"}, - {file = "botocore-1.35.26.tar.gz", hash = "sha256:19efc3a22c9df77960712b4e203f912486f8bcd3794bff0fd7b2a0f5f1d5712d"}, + {file = "botocore-1.35.29-py3-none-any.whl", hash = "sha256:f8e3ae0d84214eff3fb69cb4dc51cea6c43d3bde82027a94d00c52b941d6c3d5"}, + {file = "botocore-1.35.29.tar.gz", hash = "sha256:4ed28ab03675bb008a290c452c5ddd7aaa5d4e3fa1912aadbdf93057ee84362b"}, ] [package.dependencies] @@ -948,61 +948,70 @@ grpc = ["grpcio (>=1.44.0,<2.0.0.dev0)"] [[package]] name = "grpcio" -version = "1.66.1" +version = "1.66.2" description = "HTTP/2-based RPC framework" optional = false python-versions = ">=3.8" files = [ - {file = "grpcio-1.66.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:4877ba180591acdf127afe21ec1c7ff8a5ecf0fe2600f0d3c50e8c4a1cbc6492"}, - {file = "grpcio-1.66.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:3750c5a00bd644c75f4507f77a804d0189d97a107eb1481945a0cf3af3e7a5ac"}, - {file = "grpcio-1.66.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:a013c5fbb12bfb5f927444b477a26f1080755a931d5d362e6a9a720ca7dbae60"}, - {file = "grpcio-1.66.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b1b24c23d51a1e8790b25514157d43f0a4dce1ac12b3f0b8e9f66a5e2c4c132f"}, - {file = "grpcio-1.66.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b7ffb8ea674d68de4cac6f57d2498fef477cef582f1fa849e9f844863af50083"}, - {file = "grpcio-1.66.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:307b1d538140f19ccbd3aed7a93d8f71103c5d525f3c96f8616111614b14bf2a"}, - {file = "grpcio-1.66.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:1c17ebcec157cfb8dd445890a03e20caf6209a5bd4ac5b040ae9dbc59eef091d"}, - {file = "grpcio-1.66.1-cp310-cp310-win32.whl", hash = "sha256:ef82d361ed5849d34cf09105d00b94b6728d289d6b9235513cb2fcc79f7c432c"}, - {file = "grpcio-1.66.1-cp310-cp310-win_amd64.whl", hash = "sha256:292a846b92cdcd40ecca46e694997dd6b9be6c4c01a94a0dfb3fcb75d20da858"}, - {file = "grpcio-1.66.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:c30aeceeaff11cd5ddbc348f37c58bcb96da8d5aa93fed78ab329de5f37a0d7a"}, - {file = "grpcio-1.66.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8a1e224ce6f740dbb6b24c58f885422deebd7eb724aff0671a847f8951857c26"}, - {file = "grpcio-1.66.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:a66fe4dc35d2330c185cfbb42959f57ad36f257e0cc4557d11d9f0a3f14311df"}, - {file = "grpcio-1.66.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e3ba04659e4fce609de2658fe4dbf7d6ed21987a94460f5f92df7579fd5d0e22"}, - {file = "grpcio-1.66.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4573608e23f7e091acfbe3e84ac2045680b69751d8d67685ffa193a4429fedb1"}, - {file = "grpcio-1.66.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7e06aa1f764ec8265b19d8f00140b8c4b6ca179a6dc67aa9413867c47e1fb04e"}, - {file = "grpcio-1.66.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3885f037eb11f1cacc41f207b705f38a44b69478086f40608959bf5ad85826dd"}, - {file = "grpcio-1.66.1-cp311-cp311-win32.whl", hash = "sha256:97ae7edd3f3f91480e48ede5d3e7d431ad6005bfdbd65c1b56913799ec79e791"}, - {file = "grpcio-1.66.1-cp311-cp311-win_amd64.whl", hash = "sha256:cfd349de4158d797db2bd82d2020554a121674e98fbe6b15328456b3bf2495bb"}, - {file = "grpcio-1.66.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:a92c4f58c01c77205df6ff999faa008540475c39b835277fb8883b11cada127a"}, - {file = "grpcio-1.66.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:fdb14bad0835914f325349ed34a51940bc2ad965142eb3090081593c6e347be9"}, - {file = "grpcio-1.66.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:f03a5884c56256e08fd9e262e11b5cfacf1af96e2ce78dc095d2c41ccae2c80d"}, - {file = "grpcio-1.66.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2ca2559692d8e7e245d456877a85ee41525f3ed425aa97eb7a70fc9a79df91a0"}, - {file = "grpcio-1.66.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:84ca1be089fb4446490dd1135828bd42a7c7f8421e74fa581611f7afdf7ab761"}, - {file = "grpcio-1.66.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:d639c939ad7c440c7b2819a28d559179a4508783f7e5b991166f8d7a34b52815"}, - {file = "grpcio-1.66.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:b9feb4e5ec8dc2d15709f4d5fc367794d69277f5d680baf1910fc9915c633524"}, - {file = "grpcio-1.66.1-cp312-cp312-win32.whl", hash = "sha256:7101db1bd4cd9b880294dec41a93fcdce465bdbb602cd8dc5bd2d6362b618759"}, - {file = "grpcio-1.66.1-cp312-cp312-win_amd64.whl", hash = "sha256:b0aa03d240b5539648d996cc60438f128c7f46050989e35b25f5c18286c86734"}, - {file = "grpcio-1.66.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:ecfe735e7a59e5a98208447293ff8580e9db1e890e232b8b292dc8bd15afc0d2"}, - {file = "grpcio-1.66.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:4825a3aa5648010842e1c9d35a082187746aa0cdbf1b7a2a930595a94fb10fce"}, - {file = "grpcio-1.66.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:f517fd7259fe823ef3bd21e508b653d5492e706e9f0ef82c16ce3347a8a5620c"}, - {file = "grpcio-1.66.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f1fe60d0772831d96d263b53d83fb9a3d050a94b0e94b6d004a5ad111faa5b5b"}, - {file = "grpcio-1.66.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31a049daa428f928f21090403e5d18ea02670e3d5d172581670be006100db9ef"}, - {file = "grpcio-1.66.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6f914386e52cbdeb5d2a7ce3bf1fdfacbe9d818dd81b6099a05b741aaf3848bb"}, - {file = "grpcio-1.66.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:bff2096bdba686019fb32d2dde45b95981f0d1490e054400f70fc9a8af34b49d"}, - {file = "grpcio-1.66.1-cp38-cp38-win32.whl", hash = "sha256:aa8ba945c96e73de29d25331b26f3e416e0c0f621e984a3ebdb2d0d0b596a3b3"}, - {file = "grpcio-1.66.1-cp38-cp38-win_amd64.whl", hash = "sha256:161d5c535c2bdf61b95080e7f0f017a1dfcb812bf54093e71e5562b16225b4ce"}, - {file = "grpcio-1.66.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:d0cd7050397b3609ea51727b1811e663ffda8bda39c6a5bb69525ef12414b503"}, - {file = "grpcio-1.66.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0e6c9b42ded5d02b6b1fea3a25f036a2236eeb75d0579bfd43c0018c88bf0a3e"}, - {file = "grpcio-1.66.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:c9f80f9fad93a8cf71c7f161778ba47fd730d13a343a46258065c4deb4b550c0"}, - {file = "grpcio-1.66.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5dd67ed9da78e5121efc5c510f0122a972216808d6de70953a740560c572eb44"}, - {file = "grpcio-1.66.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:48b0d92d45ce3be2084b92fb5bae2f64c208fea8ceed7fccf6a7b524d3c4942e"}, - {file = "grpcio-1.66.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:4d813316d1a752be6f5c4360c49f55b06d4fe212d7df03253dfdae90c8a402bb"}, - {file = "grpcio-1.66.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9c9bebc6627873ec27a70fc800f6083a13c70b23a5564788754b9ee52c5aef6c"}, - {file = "grpcio-1.66.1-cp39-cp39-win32.whl", hash = "sha256:30a1c2cf9390c894c90bbc70147f2372130ad189cffef161f0432d0157973f45"}, - {file = "grpcio-1.66.1-cp39-cp39-win_amd64.whl", hash = "sha256:17663598aadbedc3cacd7bbde432f541c8e07d2496564e22b214b22c7523dac8"}, - {file = "grpcio-1.66.1.tar.gz", hash = "sha256:35334f9c9745add3e357e3372756fd32d925bd52c41da97f4dfdafbde0bf0ee2"}, -] - -[package.extras] -protobuf = ["grpcio-tools (>=1.66.1)"] + {file = "grpcio-1.66.2-cp310-cp310-linux_armv7l.whl", hash = "sha256:fe96281713168a3270878255983d2cb1a97e034325c8c2c25169a69289d3ecfa"}, + {file = "grpcio-1.66.2-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:73fc8f8b9b5c4a03e802b3cd0c18b2b06b410d3c1dcbef989fdeb943bd44aff7"}, + {file = "grpcio-1.66.2-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:03b0b307ba26fae695e067b94cbb014e27390f8bc5ac7a3a39b7723fed085604"}, + {file = "grpcio-1.66.2-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d69ce1f324dc2d71e40c9261d3fdbe7d4c9d60f332069ff9b2a4d8a257c7b2b"}, + {file = "grpcio-1.66.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:05bc2ceadc2529ab0b227b1310d249d95d9001cd106aa4d31e8871ad3c428d73"}, + {file = "grpcio-1.66.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:8ac475e8da31484efa25abb774674d837b343afb78bb3bcdef10f81a93e3d6bf"}, + {file = "grpcio-1.66.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:0be4e0490c28da5377283861bed2941d1d20ec017ca397a5df4394d1c31a9b50"}, + {file = "grpcio-1.66.2-cp310-cp310-win32.whl", hash = "sha256:4e504572433f4e72b12394977679161d495c4c9581ba34a88d843eaf0f2fbd39"}, + {file = "grpcio-1.66.2-cp310-cp310-win_amd64.whl", hash = "sha256:2018b053aa15782db2541ca01a7edb56a0bf18c77efed975392583725974b249"}, + {file = "grpcio-1.66.2-cp311-cp311-linux_armv7l.whl", hash = "sha256:2335c58560a9e92ac58ff2bc5649952f9b37d0735608242973c7a8b94a6437d8"}, + {file = "grpcio-1.66.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:45a3d462826f4868b442a6b8fdbe8b87b45eb4f5b5308168c156b21eca43f61c"}, + {file = "grpcio-1.66.2-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:a9539f01cb04950fd4b5ab458e64a15f84c2acc273670072abe49a3f29bbad54"}, + {file = "grpcio-1.66.2-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce89f5876662f146d4c1f695dda29d4433a5d01c8681fbd2539afff535da14d4"}, + {file = "grpcio-1.66.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d25a14af966438cddf498b2e338f88d1c9706f3493b1d73b93f695c99c5f0e2a"}, + {file = "grpcio-1.66.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:6001e575b8bbd89eee11960bb640b6da6ae110cf08113a075f1e2051cc596cae"}, + {file = "grpcio-1.66.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:4ea1d062c9230278793820146c95d038dc0f468cbdd172eec3363e42ff1c7d01"}, + {file = "grpcio-1.66.2-cp311-cp311-win32.whl", hash = "sha256:38b68498ff579a3b1ee8f93a05eb48dc2595795f2f62716e797dc24774c1aaa8"}, + {file = "grpcio-1.66.2-cp311-cp311-win_amd64.whl", hash = "sha256:6851de821249340bdb100df5eacfecfc4e6075fa85c6df7ee0eb213170ec8e5d"}, + {file = "grpcio-1.66.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:802d84fd3d50614170649853d121baaaa305de7b65b3e01759247e768d691ddf"}, + {file = "grpcio-1.66.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:80fd702ba7e432994df208f27514280b4b5c6843e12a48759c9255679ad38db8"}, + {file = "grpcio-1.66.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:12fda97ffae55e6526825daf25ad0fa37483685952b5d0f910d6405c87e3adb6"}, + {file = "grpcio-1.66.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:950da58d7d80abd0ea68757769c9db0a95b31163e53e5bb60438d263f4bed7b7"}, + {file = "grpcio-1.66.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e636ce23273683b00410f1971d209bf3689238cf5538d960adc3cdfe80dd0dbd"}, + {file = "grpcio-1.66.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a917d26e0fe980b0ac7bfcc1a3c4ad6a9a4612c911d33efb55ed7833c749b0ee"}, + {file = "grpcio-1.66.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:49f0ca7ae850f59f828a723a9064cadbed90f1ece179d375966546499b8a2c9c"}, + {file = "grpcio-1.66.2-cp312-cp312-win32.whl", hash = "sha256:31fd163105464797a72d901a06472860845ac157389e10f12631025b3e4d0453"}, + {file = "grpcio-1.66.2-cp312-cp312-win_amd64.whl", hash = "sha256:ff1f7882e56c40b0d33c4922c15dfa30612f05fb785074a012f7cda74d1c3679"}, + {file = "grpcio-1.66.2-cp313-cp313-linux_armv7l.whl", hash = "sha256:3b00efc473b20d8bf83e0e1ae661b98951ca56111feb9b9611df8efc4fe5d55d"}, + {file = "grpcio-1.66.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:1caa38fb22a8578ab8393da99d4b8641e3a80abc8fd52646f1ecc92bcb8dee34"}, + {file = "grpcio-1.66.2-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:c408f5ef75cfffa113cacd8b0c0e3611cbfd47701ca3cdc090594109b9fcbaed"}, + {file = "grpcio-1.66.2-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c806852deaedee9ce8280fe98955c9103f62912a5b2d5ee7e3eaa284a6d8d8e7"}, + {file = "grpcio-1.66.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f145cc21836c332c67baa6fc81099d1d27e266401565bf481948010d6ea32d46"}, + {file = "grpcio-1.66.2-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:73e3b425c1e155730273f73e419de3074aa5c5e936771ee0e4af0814631fb30a"}, + {file = "grpcio-1.66.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:9c509a4f78114cbc5f0740eb3d7a74985fd2eff022971bc9bc31f8bc93e66a3b"}, + {file = "grpcio-1.66.2-cp313-cp313-win32.whl", hash = "sha256:20657d6b8cfed7db5e11b62ff7dfe2e12064ea78e93f1434d61888834bc86d75"}, + {file = "grpcio-1.66.2-cp313-cp313-win_amd64.whl", hash = "sha256:fb70487c95786e345af5e854ffec8cb8cc781bcc5df7930c4fbb7feaa72e1cdf"}, + {file = "grpcio-1.66.2-cp38-cp38-linux_armv7l.whl", hash = "sha256:a18e20d8321c6400185b4263e27982488cb5cdd62da69147087a76a24ef4e7e3"}, + {file = "grpcio-1.66.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:02697eb4a5cbe5a9639f57323b4c37bcb3ab2d48cec5da3dc2f13334d72790dd"}, + {file = "grpcio-1.66.2-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:99a641995a6bc4287a6315989ee591ff58507aa1cbe4c2e70d88411c4dcc0839"}, + {file = "grpcio-1.66.2-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ed71e81782966ffead60268bbda31ea3f725ebf8aa73634d5dda44f2cf3fb9c"}, + {file = "grpcio-1.66.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbd27c24a4cc5e195a7f56cfd9312e366d5d61b86e36d46bbe538457ea6eb8dd"}, + {file = "grpcio-1.66.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d9a9724a156c8ec6a379869b23ba3323b7ea3600851c91489b871e375f710bc8"}, + {file = "grpcio-1.66.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d8d4732cc5052e92cea2f78b233c2e2a52998ac40cd651f40e398893ad0d06ec"}, + {file = "grpcio-1.66.2-cp38-cp38-win32.whl", hash = "sha256:7b2c86457145ce14c38e5bf6bdc19ef88e66c5fee2c3d83285c5aef026ba93b3"}, + {file = "grpcio-1.66.2-cp38-cp38-win_amd64.whl", hash = "sha256:e88264caad6d8d00e7913996030bac8ad5f26b7411495848cc218bd3a9040b6c"}, + {file = "grpcio-1.66.2-cp39-cp39-linux_armv7l.whl", hash = "sha256:c400ba5675b67025c8a9f48aa846f12a39cf0c44df5cd060e23fda5b30e9359d"}, + {file = "grpcio-1.66.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:66a0cd8ba6512b401d7ed46bb03f4ee455839957f28b8d61e7708056a806ba6a"}, + {file = "grpcio-1.66.2-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:06de8ec0bd71be123eec15b0e0d457474931c2c407869b6c349bd9bed4adbac3"}, + {file = "grpcio-1.66.2-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb57870449dfcfac428afbb5a877829fcb0d6db9d9baa1148705739e9083880e"}, + {file = "grpcio-1.66.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b672abf90a964bfde2d0ecbce30f2329a47498ba75ce6f4da35a2f4532b7acbc"}, + {file = "grpcio-1.66.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:ad2efdbe90c73b0434cbe64ed372e12414ad03c06262279b104a029d1889d13e"}, + {file = "grpcio-1.66.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9c3a99c519f4638e700e9e3f83952e27e2ea10873eecd7935823dab0c1c9250e"}, + {file = "grpcio-1.66.2-cp39-cp39-win32.whl", hash = "sha256:78fa51ebc2d9242c0fc5db0feecc57a9943303b46664ad89921f5079e2e4ada7"}, + {file = "grpcio-1.66.2-cp39-cp39-win_amd64.whl", hash = "sha256:728bdf36a186e7f51da73be7f8d09457a03061be848718d0edf000e709418987"}, + {file = "grpcio-1.66.2.tar.gz", hash = "sha256:563588c587b75c34b928bc428548e5b00ea38c46972181a4d8b75ba7e3f24231"}, +] + +[package.extras] +protobuf = ["grpcio-tools (>=1.66.2)"] [[package]] name = "h11" @@ -2190,13 +2199,13 @@ virtualenv = ">=20.10.0" [[package]] name = "prompt-toolkit" -version = "3.0.47" +version = "3.0.48" description = "Library for building powerful interactive command lines in Python" optional = false python-versions = ">=3.7.0" files = [ - {file = "prompt_toolkit-3.0.47-py3-none-any.whl", hash = "sha256:0d7bfa67001d5e39d02c224b663abc33687405033a8c422d0d675a5a13361d10"}, - {file = "prompt_toolkit-3.0.47.tar.gz", hash = "sha256:1e1b29cb58080b1e69f207c893a1a7bf16d127a5c30c9d17a25a5d77792e5360"}, + {file = "prompt_toolkit-3.0.48-py3-none-any.whl", hash = "sha256:f49a827f90062e411f1ce1f854f2aedb3c23353244f8108b89283587397ac10e"}, + {file = "prompt_toolkit-3.0.48.tar.gz", hash = "sha256:d6623ab0477a80df74e646bdbc93621143f5caf104206aa29294d53de1a03d90"}, ] [package.dependencies] @@ -3430,13 +3439,13 @@ test = ["Cython (>=0.29.36,<0.30.0)", "aiohttp (==3.9.0b0)", "aiohttp (>=3.8.1)" [[package]] name = "virtualenv" -version = "20.26.5" +version = "20.26.6" description = "Virtual Python Environment builder" optional = false python-versions = ">=3.7" files = [ - {file = "virtualenv-20.26.5-py3-none-any.whl", hash = "sha256:4f3ac17b81fba3ce3bd6f4ead2749a72da5929c01774948e243db9ba41df4ff6"}, - {file = "virtualenv-20.26.5.tar.gz", hash = "sha256:ce489cac131aa58f4b25e321d6d186171f78e6cb13fafbf32a840cee67733ff4"}, + {file = "virtualenv-20.26.6-py3-none-any.whl", hash = "sha256:7345cc5b25405607a624d8418154577459c3e0277f5466dd79c49d5e492995f2"}, + {file = "virtualenv-20.26.6.tar.gz", hash = "sha256:280aede09a2a5c317e409a00102e7077c6432c5a38f0ef938e643805a7ad2c48"}, ] [package.dependencies] diff --git a/pyproject.toml b/pyproject.toml index 5e32d721e..fa9c2168b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.40" +version = "0.9.41rc002" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss/templates/server/truss_server.py b/truss/templates/server/truss_server.py index e4ffedc70..74f489907 100644 --- a/truss/templates/server/truss_server.py +++ b/truss/templates/server/truss_server.py @@ -340,7 +340,7 @@ def exit_self(): util.kill_child_processes(os.getpid()) sys.exit() - app.add_middleware(TerminationHandlerMiddleware, on_term=exit_self) + app.add_middleware(TerminationHandlerMiddleware, on_termination=exit_self) return app def start(self): From 39b8527b11f55f07545e1131ed88099d33502a97 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Mon, 30 Sep 2024 17:32:54 -0700 Subject: [PATCH 4/7] Remove sleep when no requests, fix test --- .../server/common/termination_handler_middleware.py | 6 +++--- .../server/common/test_termination_handler_middleware.py | 1 - truss/tests/test_model_inference.py | 6 +++--- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/truss/templates/server/common/termination_handler_middleware.py b/truss/templates/server/common/termination_handler_middleware.py index 2c09fd7db..cd69bdd57 100644 --- a/truss/templates/server/common/termination_handler_middleware.py +++ b/truss/templates/server/common/termination_handler_middleware.py @@ -51,7 +51,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: ): logging.info("Termination after finishing outstanding requests.") # Run in background, to not block the current request handling. - asyncio.create_task(self._terminate()) + asyncio.create_task(self._terminate_with_sleep()) else: await self._app(scope, receive, send) @@ -60,11 +60,11 @@ def _handle_stop_signal(self) -> None: self._should_terminate_soon = True if self._outstanding_requests_semaphore.locked(): logging.info("No outstanding requests. Terminate immediately.") - asyncio.create_task(self._terminate()) + self._on_termination() else: logging.info("Will terminate when all requests are processed.") - async def _terminate(self) -> None: + async def _terminate_with_sleep(self) -> None: logging.info("Sleeping before termination.") await asyncio.sleep(self._termination_delay_secs) logging.info("Terminating") diff --git a/truss/tests/templates/server/common/test_termination_handler_middleware.py b/truss/tests/templates/server/common/test_termination_handler_middleware.py index f4bb2607c..2f9c66d61 100644 --- a/truss/tests/templates/server/common/test_termination_handler_middleware.py +++ b/truss/tests/templates/server/common/test_termination_handler_middleware.py @@ -89,7 +89,6 @@ async def test_no_outstanding_requests_immediate_termination(): "No outstanding requests. Terminate immediately." in line for line in log_lines ) - assert any("Terminating" in line for line in log_lines) assert any("Server is shutting down" in line for line in log_lines) diff --git a/truss/tests/test_model_inference.py b/truss/tests/test_model_inference.py index 2df106cda..364efd5cc 100644 --- a/truss/tests/test_model_inference.py +++ b/truss/tests/test_model_inference.py @@ -1140,19 +1140,19 @@ async def predict(self, inputs, request: fastapi.Request): @pytest.mark.integration -def test_async_streaming_with_cancellation_before_generation(): +def test_async_non_streaming_with_cancellation(): model = """ import fastapi, asyncio, logging class Model: async def predict(self, inputs, request: fastapi.Request): - logging.info("start sleep") + logging.info("Start sleep") await asyncio.sleep(2) logging.info("done sleep, check request.") if await request.is_disconnected(): logging.warning("Cancelled (before gen).") return - logging.info("not cancelled.") + logging.info("Not cancelled.") return "Done" """ with ensure_kill_all(), temp_truss(model, "") as tr: From 8fda03c530d61c9913121e47b7037339dde172f0 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Tue, 1 Oct 2024 12:07:57 -0700 Subject: [PATCH 5/7] Bump Version for CTX builder --- pyproject.toml | 2 +- .../server/common/termination_handler_middleware.py | 2 +- truss/templates/server/truss_server.py | 5 ++++- .../server/common/test_termination_handler_middleware.py | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fa9c2168b..ca97f03bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.41rc002" +version = "0.9.41rc005" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss/templates/server/common/termination_handler_middleware.py b/truss/templates/server/common/termination_handler_middleware.py index cd69bdd57..4d8bc0d54 100644 --- a/truss/templates/server/common/termination_handler_middleware.py +++ b/truss/templates/server/common/termination_handler_middleware.py @@ -57,11 +57,11 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: def _handle_stop_signal(self) -> None: logging.info("Received termination signal.") - self._should_terminate_soon = True if self._outstanding_requests_semaphore.locked(): logging.info("No outstanding requests. Terminate immediately.") self._on_termination() else: + self._should_terminate_soon = True logging.info("Will terminate when all requests are processed.") async def _terminate_with_sleep(self) -> None: diff --git a/truss/templates/server/truss_server.py b/truss/templates/server/truss_server.py index 74f489907..5acf544da 100644 --- a/truss/templates/server/truss_server.py +++ b/truss/templates/server/truss_server.py @@ -338,7 +338,10 @@ def exit_self(): # Note that this kills the current process, the worker process, not # the main truss_server process. util.kill_child_processes(os.getpid()) - sys.exit() + try: + sys.exit(0) + except SystemExit: + pass # Exit cleanly without printing the stack trace app.add_middleware(TerminationHandlerMiddleware, on_termination=exit_self) return app diff --git a/truss/tests/templates/server/common/test_termination_handler_middleware.py b/truss/tests/templates/server/common/test_termination_handler_middleware.py index 2f9c66d61..d393aff25 100644 --- a/truss/tests/templates/server/common/test_termination_handler_middleware.py +++ b/truss/tests/templates/server/common/test_termination_handler_middleware.py @@ -132,7 +132,7 @@ async def test_outstanding_requests_delayed_termination(): async def test_multiple_outstanding_requests(): """Test that the server waits for multiple concurrent requests before terminating. - Logs something like: + Logs something like (note termination signal before processing requests): INFO: Started server process [1820944] INFO: Waiting for application startup. From 850519a6be7c7ebed1327345d294ff619448f358 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Tue, 1 Oct 2024 16:57:36 -0700 Subject: [PATCH 6/7] Bump Version for CTX builder --- pyproject.toml | 2 +- truss/templates/server/truss_server.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ca97f03bc..3f7fe2690 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.41rc005" +version = "0.9.41rc006" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss/templates/server/truss_server.py b/truss/templates/server/truss_server.py index 5acf544da..1e85821a4 100644 --- a/truss/templates/server/truss_server.py +++ b/truss/templates/server/truss_server.py @@ -169,6 +169,7 @@ async def predict( """ This method calls the user-provided predict method. """ + logging.warning("Marius006") if await request.is_disconnected(): msg = "Skipping `predict`, client disconnected." logging.info(msg) From 23ececb3a2abcf382314c955b6f156cf5d029651 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Tue, 1 Oct 2024 17:29:59 -0700 Subject: [PATCH 7/7] Bump Version for CTX builder --- pyproject.toml | 2 +- .../common/termination_handler_middleware.py | 2 ++ truss/templates/server/truss_server.py | 18 ++++++++++++++---- 3 files changed, 17 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3f7fe2690..052f30372 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.41rc006" +version = "0.9.41rc007" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" diff --git a/truss/templates/server/common/termination_handler_middleware.py b/truss/templates/server/common/termination_handler_middleware.py index 4d8bc0d54..0a43d57f2 100644 --- a/truss/templates/server/common/termination_handler_middleware.py +++ b/truss/templates/server/common/termination_handler_middleware.py @@ -68,4 +68,6 @@ async def _terminate_with_sleep(self) -> None: logging.info("Sleeping before termination.") await asyncio.sleep(self._termination_delay_secs) logging.info("Terminating") + loop = asyncio.get_event_loop() + loop.stop() # Stop the event loop gracefully self._on_termination() diff --git a/truss/templates/server/truss_server.py b/truss/templates/server/truss_server.py index 1e85821a4..24fc7a4de 100644 --- a/truss/templates/server/truss_server.py +++ b/truss/templates/server/truss_server.py @@ -30,6 +30,8 @@ from starlette.requests import ClientDisconnect from starlette.responses import Response +# from starlette.routing import Lifespan + if sys.version_info >= (3, 9): from typing import AsyncGenerator, Generator else: @@ -75,6 +77,15 @@ def run(self): asyncio.run(server.serve(sockets=self.sockets)) +# class CustomLifespan(Lifespan): +# async def __call__(self, scope, receive, send): +# try: +# await super().__call__(scope, receive, send) +# except asyncio.CancelledError: +# # Handle asyncio cancellation error gracefully +# pass + + class BasetenEndpoints: """The implementation of the model server endpoints. @@ -338,11 +349,10 @@ def create_application(self): def exit_self(): # Note that this kills the current process, the worker process, not # the main truss_server process. + # loop = asyncio.get_event_loop() + # loop.stop() # Stop the event loop gracefully util.kill_child_processes(os.getpid()) - try: - sys.exit(0) - except SystemExit: - pass # Exit cleanly without printing the stack trace + sys.exit(0) app.add_middleware(TerminationHandlerMiddleware, on_termination=exit_self) return app