Commit

mega-commit

nebfield committed Oct 21, 2024
1 parent bda1ccc commit 7616343
Showing 12 changed files with 249 additions and 791 deletions.
1 change: 0 additions & 1 deletion pyvatti/.pre-commit-config.yaml
@@ -9,4 +9,3 @@ repos:
rev: 'v1.11.2' # Use the sha / tag you want to point at
hooks:
- id: mypy
args: [--config-file, mypy.ini]
781 changes: 31 additions & 750 deletions pyvatti/poetry.lock

Large diffs are not rendered by default.

13 changes: 6 additions & 7 deletions pyvatti/pyproject.toml
@@ -11,20 +11,19 @@ packages = [

[tool.poetry.dependencies]
python = "^3.12"
uvicorn = {extras = ["standard"], version = "^0.29.0"}
pgscatalog-core = "^0.1.0"
transitions = "^0.9.0"
transitions = "^0.9.2"
google-cloud-storage = "^2.18.2"
pydantic-settings = "^2.2.1"
types-pyyaml = "^6.0.12.20240917"
schedule = "^1.2.2"
httpx = "^0.27.2"
pydantic = "^2.9.2"
pyyaml = "^6.0.2"
kafka-python-ng = "^2.2.3"

[tool.poetry.group.dev.dependencies]
pytest = "^8.1.1"
mypy = "^1.12.0"

[tool.poetry.group.standard.dependencies]
uvicorn = "^0.29.0"

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
9 changes: 0 additions & 9 deletions pyvatti/src/pyvatti/__init__.py
@@ -1,14 +1,5 @@
import logging
import pathlib
import tempfile

import httpx

logger = logging.getLogger(__name__)
log_fmt = "%(name)s: %(asctime)s %(levelname)-8s %(message)s"
logging.basicConfig(format=log_fmt, datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)

CLIENT = httpx.AsyncClient()
TEMP_DIR = tempfile.mkdtemp()
SHELF_PATH = str(pathlib.Path(TEMP_DIR) / "shelve.dat")
logger.info(f"Created temporary shelf file {SHELF_PATH}")
18 changes: 17 additions & 1 deletion pyvatti/src/pyvatti/config.py
@@ -1,5 +1,7 @@
import enum
import pathlib
import sys
from tempfile import NamedTemporaryFile
from typing import Optional

from pydantic import Field, DirectoryPath, AnyHttpUrl
@@ -45,9 +47,23 @@ class Settings(BaseSettings):
description="Number of seconds to wait before polling Seqera platform API",
)
NOTIFY_TOKEN: str = Field(description="Token for backend notifications")
SQLITE_DB_PATH: pathlib.Path = Field(
description="Path to a sqlite database",
default_factory=lambda: NamedTemporaryFile(delete=False).name,
)


if "pytest" in sys.modules:
settings = None
settings = Settings(
HELM_CHART_PATH="/tmp",
TOWER_TOKEN="test",
TOWER_WORKSPACE="test",
GLOBUS_DOMAIN="https://example.com",
GLOBUS_CLIENT_ID="test",
GLOBUS_CLIENT_SECRET="test",
GLOBUS_SCOPES="test",
NOTIFY_URL="https://example.com",
NOTIFY_TOKEN="test",
)
else:
settings = Settings()
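
For context, the Settings class above is a pydantic-settings model: field values are read from environment variables, and the new SQLITE_DB_PATH field falls back to a throwaway temporary file via default_factory. A minimal, self-contained sketch of that pattern (not part of this commit; DemoSettings and the example values are illustrative):

import os
from tempfile import NamedTemporaryFile

from pydantic import Field
from pydantic_settings import BaseSettings


class DemoSettings(BaseSettings):
    # Required field: read from the NOTIFY_TOKEN environment variable
    NOTIFY_TOKEN: str = Field(description="Token for backend notifications")
    # Optional field: defaults to a fresh temporary file path
    SQLITE_DB_PATH: str = Field(
        default_factory=lambda: NamedTemporaryFile(delete=False).name
    )


os.environ["NOTIFY_TOKEN"] = "example-token"
demo_settings = DemoSettings()       # NOTIFY_TOKEN is read from the environment
print(demo_settings.SQLITE_DB_PATH)  # a throwaway sqlite path, as in config.py
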
2 changes: 1 addition & 1 deletion pyvatti/src/pyvatti/helm.py
@@ -12,7 +12,7 @@
import yaml
from pydantic import BaseModel, Field, field_validator
from pyvatti.config import settings
from pyvatti.models import JobRequest
from pyvatti.messagemodels import JobRequest


@lru_cache
41 changes: 39 additions & 2 deletions pyvatti/src/pyvatti/job.py
@@ -2,17 +2,23 @@
"""This module contains a state machine that represents job states and their transitions"""

import logging
from typing import Optional
from functools import lru_cache
from typing import Optional, ClassVar

import httpx
from transitions import Machine, EventData, MachineError

from pyvatti.config import settings
from pyvatti.jobstates import States
from pyvatti.models import JobRequest
from pyvatti.messagemodels import JobRequest
from pyvatti.notifymodels import SeqeraLog

from pyvatti.resources import GoogleResourceHandler, DummyResourceHandler

logger = logging.getLogger(__name__)

API_ROOT = "https://api.cloud.seqera.io"


class PolygenicScoreJob(Machine):
"""A state machine for polygenic score calculation jobs
@@ -108,6 +114,13 @@ class PolygenicScoreJob(Machine):
},
]

# map from destination states to triggers
state_trigger_map: ClassVar[dict] = {
States.FAILED: "error",
States.SUCCEEDED: "succeed",
States.DEPLOYED: "deploy",
}

def __init__(self, intp_id, dry_run=False):
states = [
# a dummy initial state: /launch got POSTed
@@ -165,5 +178,29 @@ def notify(self, event: Optional[EventData]):
logger.info(f"Sending state notification: {self.state}")
# TODO: add kafka

def get_job_state(self) -> Optional[States]:
"""Get the state of a job by checking the Seqera Platform API
The returned state maps onto a state machine trigger via state_trigger_map"""
params = {
"workspaceId": settings.TOWER_WORKSPACE,
"search": f"{settings.NAMESPACE}-{self.intp_id}",
}

with httpx.Client() as client:
response = client.get(
f"{API_ROOT}/workflow", headers=get_headers(), params=params
)

return SeqeraLog.from_response(response).get_job_state()

def __repr__(self):
return f"{self.__class__.__name__}(id={self.intp_id!r})"


@lru_cache
def get_headers():
return {
"Authorization": f"Bearer {settings.TOWER_TOKEN}",
"Accept": "application/json",
}
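
For readers unfamiliar with the transitions library, a small self-contained sketch of the trigger-per-destination-state pattern that state_trigger_map encodes (not part of this commit; the real PolygenicScoreJob defines more states, callbacks, and resource handling):

from transitions import Machine


class DemoJob:
    pass


job = DemoJob()
Machine(
    model=job,
    states=["requested", "created", "deployed", "succeeded", "failed"],
    transitions=[
        {"trigger": "create", "source": "requested", "dest": "created"},
        {"trigger": "deploy", "source": "created", "dest": "deployed"},
        {"trigger": "succeed", "source": "deployed", "dest": "succeeded"},
        {"trigger": "error", "source": "*", "dest": "failed"},
    ],
    initial="requested",
)

job.create()      # requested -> created (resources requested)
job.deploy()      # created -> deployed (workflow running)
job.succeed()     # deployed -> succeeded
print(job.state)  # "succeeded"
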
119 changes: 119 additions & 0 deletions pyvatti/src/pyvatti/main.py
@@ -1,5 +1,124 @@
import json
import logging
import signal
import sys
import threading
import time
from typing import Optional

import pydantic
import schedule

from pyvatti.config import settings
from pyvatti.db import SqliteJobDatabase

from kafka import KafkaConsumer

from pyvatti.job import PolygenicScoreJob
from pyvatti.jobstates import States
from pyvatti.messagemodels import JobRequest

logger = logging.getLogger()
logger.setLevel(logging.INFO)

JOB_DATABASE = SqliteJobDatabase(settings.SQLITE_DB_PATH)
SHUTDOWN_EVENT = threading.Event()


def check_job_state() -> None:
"""Check the state of the job on the Seqera Platform and update active jobs in the database if the state has changed
Created (resources requested) -> Deployed (running) -> Succeeded / Failed
"""
# active jobs: haven't succeeded or failed
jobs: list[PolygenicScoreJob] = JOB_DATABASE.get_active_jobs()
logger.info(f"{len(jobs)} active jobs found")
for job in jobs:
logger.info(f"Checking {job=} state")
job_state: Optional[States] = job.get_job_state()
if job_state is not None:
if job_state != job.state:
logger.info(
f"Job state change detected: From {job_state} to {job.state}"
)
# get the trigger from the destination state enum
# e.g. "deploy" -> "succeed" / "error"
trigger: str = PolygenicScoreJob.state_trigger_map[job_state]
job.trigger(trigger)
JOB_DATABASE.update_job(job)


def kafka_consumer() -> None:
consumer = KafkaConsumer(
settings.KAFKA_CONSUMER_TOPIC,  # topics are positional in kafka-python
bootstrap_servers=settings.KAFKA_BOOTSTRAP_SERVERS,
enable_auto_commit=False,
value_deserializer=lambda m: json.loads(m.decode("ascii")),
)

# want to avoid partially processing a commit if the thread is terminated
try:
for message in consumer:
if SHUTDOWN_EVENT.is_set():
logger.info("Shutdown event received")
break
process_message(message.value)
consumer.commit()
finally:
logger.info("Closing kafka connection")
consumer.close()


def process_message(msg_value: dict) -> None:
"""Each kafka message:
- Gets validated by the pydantic model JobRequest
- Instantiate a PolygenicScoreJob object
- Trigger the "create" state where compute resources are provisioned
- Adds the job object to the database
"""
try:
job_message: JobRequest = JobRequest(**msg_value)
job: PolygenicScoreJob = PolygenicScoreJob(
intp_id=job_message.pipeline_param.id
)
job.create(job_model=job_message)
JOB_DATABASE.insert_job(job)
except pydantic.ValidationError as e:
logger.critical("Job request message validation failed")
logger.critical(f"{e}")


def graceful_shutdown(*args):
logger.info("Shutdown signal received")
SHUTDOWN_EVENT.set()


def main():
# handle shutdowns gracefully (partially processing a job request would be bad)
signal.signal(signal.SIGINT, graceful_shutdown)
signal.signal(signal.SIGTERM, graceful_shutdown)

# create the job database if it does not exist (if it exists, nothing happens here)
JOB_DATABASE.create()

# consume new kafka messages and insert them into the database in a background thread
consumer_thread: threading.Thread = threading.Thread(
target=kafka_consumer, daemon=True
)
consumer_thread.start()

# check for timed out jobs with schedule
schedule.every(15).minutes.do(JOB_DATABASE.timeout_jobs)

# check if job states have changed and produce new messages
schedule.every(1).minutes.do(check_job_state)

# run scheduled tasks:
while True:
schedule.run_pending()
time.sleep(1)


if __name__ == "__main__":
sys.exit(main())
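
The consumer loop above disables auto-commit and commits offsets only after a message has been fully processed, giving at-least-once semantics: a crash or shutdown mid-message means the message is re-delivered rather than lost. A minimal sketch of that pattern with kafka-python (not part of this commit; the topic name and broker address are placeholders):

import json

from kafka import KafkaConsumer

consumer = KafkaConsumer(
    "pipeline-launch",                   # hypothetical topic
    bootstrap_servers="localhost:9092",  # hypothetical broker
    enable_auto_commit=False,
    value_deserializer=lambda m: json.loads(m.decode("ascii")),
)

try:
    for message in consumer:
        print(message.value)  # process the message first ...
        consumer.commit()     # ... then commit its offset
finally:
    consumer.close()
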
File renamed without changes.
@@ -1,11 +1,15 @@
""" This module contains pydantic models that represent responses from the Seqera platform API
The platform API is queried to poll and monitor the state of running jobs
"""
import logging
from datetime import datetime
import enum
from functools import lru_cache
from typing import Optional

import httpx
from pydantic import BaseModel, field_serializer
from pydantic import BaseModel, PastDatetime

from pyvatti.config import settings
from pyvatti.jobstates import States
@@ -16,24 +20,24 @@
logger = logging.getLogger(__name__)


@lru_cache
@lru_cache(maxsize=1)
def get_headers():
"""Headers that authorise querying the Seqera platform API"""
return {
"Authorization": f"Bearer {settings.TOWER_TOKEN}",
"Accept": "application/json",
}


class SeqeraJobStatus(enum.Enum):
class SeqeraJobStatus(str, enum.Enum):
"""Job states on the Seqera platform"""

SUBMITTED = "SUBMITTED"
RUNNING = "RUNNING"
SUCCEEDED = "SUCCEEDED"
FAILED = "FAILED"
UNKNOWN = "UNKNOWN"

def __str__(self):
return str(self.value)


class SeqeraLog(BaseModel):
runName: str
@@ -49,28 +53,30 @@ def from_response(cls, response: httpx.Response):
else:
return None

def get_job_state(self):
"""Get valid state machine trigger strings from states"""
def get_job_state(self) -> Optional[States]:
"""Get valid job states"""
match self.status:
case SeqeraJobStatus.SUCCEEDED:
state = "succeed"
state = States.SUCCEEDED
case SeqeraJobStatus.FAILED | SeqeraJobStatus.UNKNOWN:
state = "error"
state = States.FAILED
case SeqeraJobStatus.RUNNING:
state = "deploy"
state = States.DEPLOYED
case _:
logger.warning(f"Unknown state: {self.status}")
raise Exception
state = None
return state


class BackendStatusMessage(BaseModel):
"""A message updating the backend about job state"""
"""A message updating the backend about job state
>>> from datetime import datetime
>>> d = {"run_name": "INTP123456", "utc_time": datetime(1999, 12, 31), "event": States.SUCCEEDED}
>>> BackendStatusMessage(**d).model_dump_json()
'{"run_name":"INTP123456","utc_time":"1999-12-31T00:00:00","event":"succeeded"}'
"""

run_name: str
utc_time: datetime
utc_time: PastDatetime
event: States

@field_serializer("event")
def event_field(self, event):
return str(event)
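
One change worth noting in this file: SeqeraJobStatus now subclasses both str and enum.Enum, so its members compare equal to the raw status strings returned by the Seqera API. A short illustrative sketch (not part of this commit; DemoStatus is a placeholder):

import enum


class DemoStatus(str, enum.Enum):
    RUNNING = "RUNNING"
    FAILED = "FAILED"


assert DemoStatus.RUNNING == "RUNNING"            # plain-string comparison
assert DemoStatus("FAILED") is DemoStatus.FAILED  # lookup by value
print(DemoStatus.RUNNING.value)                   # RUNNING
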
10 changes: 10 additions & 0 deletions pyvatti/src/pyvatti/render_cli.py
@@ -0,0 +1,10 @@
import sys


def main():
# TODO: build a CLI that renders a helm template from a message for testing
pass


if __name__ == "__main__":
sys.exit(main())