Skip to content

Commit

Permalink
Add just command helper tool to repository
Browse files Browse the repository at this point in the history
  • Loading branch information
andreas-el authored and jonathan-eq committed Nov 12, 2024
1 parent b5d3671 commit 710af8e
Show file tree
Hide file tree
Showing 11 changed files with 170 additions and 100 deletions.
13 changes: 13 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# configuration for `just`

# run poly example test-case
poly:
ert gui test-data/ert/poly_example/poly.ert

# run snake oil test-case
snake_oil:
ert gui test-data/ert/snake_oil/snake_oil.ert

# execute rapid unittests
rapid-tests:
pytest -n logical tests/ert/unit_tests -m "not integration_tests"
51 changes: 36 additions & 15 deletions src/_ert/forward_model_runner/cli.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import asyncio
import json
import logging
import os
Expand Down Expand Up @@ -90,7 +91,7 @@ def _read_jobs_file(retry=True):
raise e


def main(args):
async def main(args):
parser = argparse.ArgumentParser(
description=(
"Run all the jobs specified in jobs.json, "
Expand Down Expand Up @@ -137,19 +138,39 @@ def main(args):
)

job_runner = ForwardModelRunner(jobs_data)
job_task = asyncio.create_task(_main(job_runner, parsed_args, reporters))

for job_status in job_runner.run(parsed_args.job):
logger.info(f"Job status: {job_status}")
def handle_sigterm(_, __):
nonlocal reporters, job_task
print("CALLED SIGTERM")
job_task.cancel()
for reporter in reporters:
try:
reporter.report(job_status)
except OSError as oserror:
print(
f"job_dispatch failed due to {oserror}. Stopping and cleaning up."
)
pgid = os.getpgid(os.getpid())
os.killpg(pgid, signal.SIGKILL)

if isinstance(job_status, Finish) and not job_status.success():
pgid = os.getpgid(os.getpid())
os.killpg(pgid, signal.SIGKILL)
reporter.cancel_synced()

signal.signal(signal.SIGTERM, handle_sigterm)
await job_task


# Drive the forward-model run: consume every status yielded by
# job_runner.run(parsed_args.job), log it, and pass it to each reporter.
# Returns early (without raising) when a reporter raises OSError or when a
# Finish status reports failure; cancellation of this coroutine is absorbed.
async def _main(
job_runner: ForwardModelRunner,
parsed_args,
reporters: typing.Sequence[reporting.Reporter],
):
try:
async for job_status in job_runner.run(parsed_args.job):
logger.info(f"Job status: {job_status}")

for reporter in reporters:
try:
await reporter.report(job_status)
# Zero-length sleep yields control back to the event loop between reports.
await asyncio.sleep(0)
except OSError as oserror:
print(
f"job_dispatch failed due to {oserror}. Stopping and cleaning up."
)
return

if isinstance(job_status, Finish) and not job_status.success():
return
except asyncio.CancelledError:
# NOTE(review): cancellation (e.g. job_task.cancel() issued by the SIGTERM
# handler in main) is swallowed here so the task ends quietly - confirm
# that not re-raising is intended.
pass
5 changes: 1 addition & 4 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ async def get_websocket(self) -> WebSocketClientProtocol:
close_timeout=self.CONNECTION_TIMEOUT,
)

async def _send(self, msg: AnyStr) -> None:
async def send(self, msg: AnyStr) -> None:
for retry in range(self._max_retries + 1):
try:
if self.websocket is None:
Expand Down Expand Up @@ -133,6 +133,3 @@ async def _send(self, msg: AnyStr) -> None:
raise ClientConnectionError(_error_msg) from exception
await asyncio.sleep(0.2 + self._timeout_multiplier * retry)
self.websocket = None

def send(self, msg: AnyStr) -> None:
self.loop.run_until_complete(self._send(msg))
8 changes: 2 additions & 6 deletions src/_ert/forward_model_runner/job_dispatch.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import os
import signal
import sys
Expand All @@ -13,12 +14,7 @@ def sigterm_handler(_signo, _stack_frame):
def main():
os.nice(19)
signal.signal(signal.SIGTERM, sigterm_handler)
try:
job_runner_main(sys.argv)
except Exception as e:
pgid = os.getpgid(os.getpid())
os.killpg(pgid, signal.SIGTERM)
raise e
asyncio.run(job_runner_main(sys.argv))


if __name__ == "__main__":
Expand Down
10 changes: 9 additions & 1 deletion src/_ert/forward_model_runner/reporting/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,13 @@

class Reporter(ABC):
@abstractmethod
def report(self, msg: Message):
async def report(self, msg: Message):
"""Report a message."""

@abstractmethod
async def cancel(self):
"""Safely shut down the reporter"""

@abstractmethod
def cancel_synced(self):
"""Safely shut down the reporter"""
70 changes: 42 additions & 28 deletions src/_ert/forward_model_runner/reporting/event.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
import logging
import queue
import threading
from datetime import datetime, timedelta
from pathlib import Path
Expand Down Expand Up @@ -32,7 +32,6 @@
Start,
)
from _ert.forward_model_runner.reporting.statemachine import StateMachine
from _ert.threading import ErtThread

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -76,22 +75,24 @@ def __init__(self, evaluator_url, token=None, cert_path=None):

self._ens_id = None
self._real_id = None
self._event_queue: queue.Queue[events.Event | EventSentinel] = queue.Queue()
self._event_publisher_thread = ErtThread(target=self._event_publisher)
self._event_queue: asyncio.Queue[events.Event | EventSentinel] = asyncio.Queue()
# self._event_publisher_thread = ErtThread(target=self._event_publisher)
self._timeout_timestamp = None
self._timestamp_lock = threading.Lock()
# seconds until the reporter times out after Finish() was received
self._reporter_timeout = 60
self._running = True
self._event_publishing_task = asyncio.create_task(self.async_event_publisher())

def _event_publisher(self):
async def async_event_publisher(self):
logger.debug("Publishing event.")
with Client(
async with Client(
url=self._evaluator_url,
token=self._token,
cert=self._cert,
) as client:
event = None
while True:
while self._running:
with self._timestamp_lock:
if (
self._timeout_timestamp is not None
Expand All @@ -102,11 +103,17 @@ def _event_publisher(self):
if event is None:
# if we successfully sent the event we can proceed
# to next one
event = self._event_queue.get()
print("GETTING NEW EVENT FROM QUEUE")
event = await self._event_queue.get()
print("GOT NEW EVENT FORM QUEUE")
if event is self._sentinel:
break
print("NEW EVENT WAS SENTINEL :))")
return
else:
print(f"{event=}")
try:
client.send(event_to_json(event))
await client.send(event_to_json(event))
self._event_queue.task_done()
event = None
except ClientConnectionError as exception:
# Possible intermittent failure, we retry sending the event
Expand All @@ -115,21 +122,21 @@ def _event_publisher(self):
# The receiving end has closed the connection, we stop
# sending events
logger.debug(str(exception))
self._event_queue.task_done()
break

def report(self, msg):
self._statemachine.transition(msg)
async def report(self, msg):
await self._statemachine.transition(msg)

def _dump_event(self, event: events.Event):
async def _dump_event(self, event: events.Event):
logger.debug(f'Schedule "{type(event)}" for delivery')
self._event_queue.put(event)
await self._event_queue.put(event)

def _init_handler(self, msg: Init):
async def _init_handler(self, msg: Init):
self._ens_id = str(msg.ens_id)
self._real_id = str(msg.real_id)
self._event_publisher_thread.start()

def _job_handler(self, msg: Union[Start, Running, Exited]):
async def _job_handler(self, msg: Union[Start, Running, Exited]):
assert msg.job
job_name = msg.job.name()
job_msg = {
Expand All @@ -144,16 +151,18 @@ def _job_handler(self, msg: Union[Start, Running, Exited]):
std_out=str(Path(msg.job.std_out).resolve()),
std_err=str(Path(msg.job.std_err).resolve()),
)
self._dump_event(event)
print("DUMPING EVENT")
await self._dump_event(event)
print("DUMPED EVENT")
if not msg.success():
logger.error(f"Job {job_name} FAILED to start")
event = ForwardModelStepFailure(**job_msg, error_msg=msg.error_message)
self._dump_event(event)
await self._dump_event(event)

elif isinstance(msg, Exited):
if msg.success():
logger.debug(f"Job {job_name} exited successfully")
self._dump_event(ForwardModelStepSuccess(**job_msg))
await self._dump_event(ForwardModelStepSuccess(**job_msg))
else:
logger.error(
_JOB_EXIT_FAILED_STRING.format(
Expand All @@ -165,7 +174,7 @@ def _job_handler(self, msg: Union[Start, Running, Exited]):
event = ForwardModelStepFailure(
**job_msg, exit_code=msg.exit_code, error_msg=msg.error_message
)
self._dump_event(event)
await self._dump_event(event)

elif isinstance(msg, Running):
logger.debug(f"{job_name} job is running")
Expand All @@ -175,21 +184,26 @@ def _job_handler(self, msg: Union[Start, Running, Exited]):
current_memory_usage=msg.memory_status.rss,
cpu_seconds=msg.memory_status.cpu_seconds,
)
self._dump_event(event)
await self._dump_event(event)

def _finished_handler(self, _):
self._event_queue.put(Event._sentinel)
async def _finished_handler(self, _):
await self._event_queue.put(Event._sentinel)
with self._timestamp_lock:
self._timeout_timestamp = datetime.now() + timedelta(
seconds=self._reporter_timeout
)
if self._event_publisher_thread.is_alive():
self._event_publisher_thread.join()

def _checksum_handler(self, msg: Checksum):
async def _checksum_handler(self, msg: Checksum):
fm_checksum = ForwardModelStepChecksum(
ensemble=self._ens_id,
real=self._real_id,
checksums={msg.run_path: msg.data},
)
self._dump_event(fm_checksum)
await self._dump_event(fm_checksum)

def cancel_synced(self):
# Synchronous entry point for shutdown: schedules cancel() as an asyncio task
# and keeps a reference to it on self._cancel_task. It does not wait for that
# task to complete.
# NOTE(review): asyncio.create_task requires a running event loop in the
# calling thread - confirm every caller satisfies that.
self._cancel_task = asyncio.create_task(self.cancel())

async def cancel(self):
# Enqueue the sentinel on the event queue, then wait for the event publishing
# task (created from async_event_publisher) to finish.
# NOTE(review): this relies on the publisher stopping when it dequeues the
# sentinel; events queued before the sentinel are still ahead of it.
await self._event_queue.put(Event._sentinel)
await self._event_publishing_task
8 changes: 7 additions & 1 deletion src/_ert/forward_model_runner/reporting/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self):
self.status_dict = {}
self.node = socket.gethostname()

def report(self, msg: Message):
async def report(self, msg: Message):
fm_step_status = {}

if msg.job:
Expand Down Expand Up @@ -217,3 +217,9 @@ def _dump_ok_file():
def _dump_status_json(self):
with open(STATUS_json, "wb") as fp:
fp.write(orjson.dumps(self.status_dict, option=orjson.OPT_INDENT_2))

def cancel_synced(self):
# No-op: this implementation performs no work on synchronous shutdown.
pass

async def cancel(self):
# No-op: this implementation performs no work on asynchronous shutdown.
pass
12 changes: 9 additions & 3 deletions src/_ert/forward_model_runner/reporting/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class Interactive(Reporter):
@staticmethod
def _report(msg: Message) -> Optional[str]:
async def _report(msg: Message) -> Optional[str]:
if not isinstance(msg, (Start, Finish)):
return None
if isinstance(msg, Finish):
Expand All @@ -26,7 +26,13 @@ def _report(msg: Message) -> Optional[str]:
)
return f"Running job: {msg.job.name()} ... "

def report(self, msg: Message):
_msg = self._report(msg)
async def report(self, msg: Message):
_msg = await self._report(msg)
if _msg is not None:
print(_msg)

async def cancel(self):
# No-op: this implementation performs no work on asynchronous shutdown.
pass

def cancel_synced(self):
# No-op: this implementation performs no work on synchronous shutdown.
pass
10 changes: 6 additions & 4 deletions src/_ert/forward_model_runner/reporting/statemachine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Callable, Dict, Tuple, Type
from typing import Awaitable, Callable, Dict, Tuple, Type

from _ert.forward_model_runner.reporting.message import (
Checksum,
Expand Down Expand Up @@ -35,13 +35,15 @@ def __init__(self) -> None:
self._state = None

def add_handler(
self, states: Tuple[Type[Message], ...], handler: Callable[[Message], None]
self,
states: Tuple[Type[Message], ...],
handler: Callable[[Message], Awaitable[None]],
) -> None:
if states in self._handler:
raise ValueError(f"{states} already handled by {self._handler[states]}")
self._handler[states] = handler

def transition(self, message: Message):
async def transition(self, message: Message):
new_state = None
for state in self._handler:
if isinstance(message, state):
Expand All @@ -58,5 +60,5 @@ def transition(self, message: Message):
f"expected to transition into {self._transitions[self._state]}"
)

self._handler[new_state](message)
await self._handler[new_state](message)
self._state = new_state
Loading

0 comments on commit 710af8e

Please sign in to comment.