Skip to content

Commit

Permalink
Fix: ASGI processes did not support lifetime events
Browse files Browse the repository at this point in the history
  • Loading branch information
hoh committed Jul 6, 2023
1 parent 0a429f5 commit 1c72da3
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 588 deletions.
16 changes: 16 additions & 0 deletions examples/example_fastapi/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@
app = AlephApp(http_app=http_app)
cache = VmCache()

startup_lifespan_executed: bool = False


@app.on_event("startup")
async def startup_event():
global startup_lifespan_executed
startup_lifespan_executed = True


@app.get("/")
async def index():
Expand Down Expand Up @@ -59,6 +67,14 @@ async def index():
}


@app.get("/lifespan")
async def check_lifespan():
"""
Check that ASGI lifespan startup signal has been received
"""
return {"Lifetime": startup_lifespan_executed}


@app.get("/environ")
async def environ() -> Dict[str, str]:
"""List environment variables"""
Expand Down
85 changes: 70 additions & 15 deletions runtimes/aleph-debian-11-python/init1.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,18 @@
from io import StringIO
from os import system
from shutil import make_archive
from typing import Any, AsyncIterable, Dict, List, NewType, Optional, Tuple, Union
from typing import (
Any,
AsyncIterable,
Dict,
List,
Literal,
NewType,
Optional,
Tuple,
Union,
cast,
)

import aiohttp
import msgpack
Expand Down Expand Up @@ -62,9 +73,9 @@ class ConfigurationPayload:
input_data: bytes
interface: Interface
vm_hash: str
code: bytes = None
encoding: Encoding = None
entrypoint: str = None
code: bytes
encoding: Encoding
entrypoint: str
ip: Optional[str] = None
ipv6: Optional[str] = None
route: Optional[str] = None
Expand Down Expand Up @@ -171,7 +182,9 @@ def setup_input_data(input_data: bytes):
os.system("unzip -q /opt/input.zip -d /data")


def setup_authorized_keys(authorized_keys: List[str]) -> None:
def setup_authorized_keys(authorized_keys: Optional[List[str]]) -> None:
if authorized_keys is None:
return
path = Path("/root/.ssh/authorized_keys")
path.parent.mkdir(exist_ok=True)
path.write_text("\n".join(key for key in authorized_keys))
Expand All @@ -189,7 +202,40 @@ def setup_volumes(volumes: List[Volume]):
system("mount")


def setup_code_asgi(
async def wait_for_lifespan_event_completion(
application: ASGIApplication, event: Union[Literal["startup", "shutdown"]]
):
"""
Send the startup lifespan signal to the ASGI app.
Specification: https://asgi.readthedocs.io/en/latest/specs/lifespan.html
"""

lifespan_completion = asyncio.Event()

async def receive():
return {
"type": f"lifespan.{event}",
}

async def send(response: Dict):
response_type = response.get("type")
if response_type == f"lifespan.{event}.complete":
lifespan_completion.set()
return
else:
logger.warning(f"Unexpected response to {event}: {response_type}")

while not lifespan_completion.is_set():
await application(
scope={
"type": "lifespan",
},
receive=receive,
send=send,
)


async def setup_code_asgi(
code: bytes, encoding: Encoding, entrypoint: str
) -> ASGIApplication:
# Allow importing packages from /opt/packages
Expand Down Expand Up @@ -225,6 +271,8 @@ def setup_code_asgi(
app = locals[entrypoint]
else:
raise ValueError(f"Unknown encoding '{encoding}'")

await wait_for_lifespan_event_completion(application=app, event="startup")
return app


Expand Down Expand Up @@ -260,14 +308,16 @@ def setup_code_executable(
return process


def setup_code(
code: Optional[bytes],
encoding: Optional[Encoding],
entrypoint: Optional[str],
async def setup_code(
code: bytes,
encoding: Encoding,
entrypoint: str,
interface: Interface,
) -> Union[ASGIApplication, subprocess.Popen]:
if interface == Interface.asgi:
return setup_code_asgi(code=code, encoding=encoding, entrypoint=entrypoint)
return await setup_code_asgi(
code=code, encoding=encoding, entrypoint=entrypoint
)
elif interface == Interface.executable:
return setup_code_executable(
code=code, encoding=encoding, entrypoint=entrypoint
Expand All @@ -284,15 +334,15 @@ async def run_python_code_http(
# Execute in the same process, saves ~20ms than a subprocess

# The body should not be part of the ASGI scope itself
body: bytes = scope.pop("body")
scope_body: bytes = scope.pop("body")

async def receive():
type_ = (
"http.request"
if scope["type"] in ("http", "websocket")
else "aleph.message"
)
return {"type": type_, "body": body, "more_body": False}
return {"type": type_, "body": scope_body, "more_body": False}

send_queue: asyncio.Queue = asyncio.Queue()

Expand All @@ -317,7 +367,7 @@ async def send(dico):
output = buf.getvalue()

logger.debug(f"Headers {headers}")
logger.debug(f"Body {body}")
logger.debug(f"Body {body!r}")
logger.debug(f"Output {output}")

logger.debug("Getting output data")
Expand Down Expand Up @@ -402,6 +452,10 @@ async def process_instruction(
application.terminate()
logger.debug("Application terminated")
# application.communicate()
else:
await wait_for_lifespan_event_completion(
application=application, event="shutdown"
)
yield b"STOP\n"
logger.debug("Supervisor informed of halt")
raise ShutdownException
Expand Down Expand Up @@ -429,6 +483,7 @@ async def process_instruction(
output_data: Optional[bytes]

if interface == Interface.asgi:
application = cast(ASGIApplication, application)
headers, body, output, output_data = await run_python_code_http(
application=application, scope=payload.scope
)
Expand Down Expand Up @@ -520,7 +575,7 @@ async def main() -> None:
setup_system(config)

try:
app: Union[ASGIApplication, subprocess.Popen] = setup_code(
app: Union[ASGIApplication, subprocess.Popen] = await setup_code(
config.code, config.encoding, config.entrypoint, config.interface
)
client.send(msgpack.dumps({"success": True}))
Expand Down
Loading

0 comments on commit 1c72da3

Please sign in to comment.