Skip to content

Commit

Permalink
Implement a nicer interactive frontend to show which steps are executing
Browse files Browse the repository at this point in the history
This allows users to have a more dynamic view of what is happening and
overral gives better feedback
  • Loading branch information
BenjaminSchubert committed Feb 11, 2023
1 parent e34a16f commit 8cb3536
Show file tree
Hide file tree
Showing 5 changed files with 202 additions and 22 deletions.
6 changes: 5 additions & 1 deletion src/dwas/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ def __init__(
n_jobs = multiprocessing.cpu_count()
self.n_jobs = n_jobs

self.is_interactive = (
sys.__stdout__.isatty() and sys.__stderr__.isatty()
)

self.environ = {
# XXX: keep this list in sync with the above documentation
key: os.environ[key]
Expand Down Expand Up @@ -225,7 +229,7 @@ def _get_color_setting(self, colors: Optional[bool]) -> bool:
if "GITHUB_ACTION" in os.environ:
return True

return sys.stdin.isatty()
return self.is_interactive

def _prepare_and_clean_log_path(self) -> None:
self.log_path.mkdir(parents=True, exist_ok=True)
Expand Down
146 changes: 146 additions & 0 deletions src/dwas/_frontend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import shutil
import sys
from contextlib import contextmanager
from contextvars import copy_context
from datetime import datetime
from threading import Event, Lock, Thread
from typing import Dict, Iterator, List

from colorama import Cursor, Fore, ansi

from . import _io
from ._timing import format_timedelta


class StepSummary:
def __init__(self, all_steps: List[str]) -> None:
self._running_steps: Dict[str, datetime] = {}
self._lock = Lock()

self._start = datetime.now()

self._n_success = 0
self._n_failure = 0
self._waiting = all_steps

def mark_running(self, step: str) -> None:
with self._lock:
self._running_steps[step] = datetime.now()
self._waiting.remove(step)

def mark_success(self, step: str) -> None:
with self._lock:
del self._running_steps[step]
self._n_success += 1

def mark_failure(self, step: str) -> None:
with self._lock:
del self._running_steps[step]
self._n_failure += 1

def lines(self) -> List[str]:
update_at = datetime.now()

# 40 comes from the number of color codes * 5, as this is what is added
# to the real length of the array
term_width = shutil.get_terminal_size().columns + 40
headline = (
f" {Fore.YELLOW}Runtime: {format_timedelta(update_at - self._start)} "
f"["
f"{len(self._waiting)}/"
f"{Fore.CYAN}{len(self._running_steps)}{Fore.YELLOW}/"
f"{Fore.GREEN}{self._n_success}{Fore.YELLOW}/"
f"{Fore.RED}{self._n_failure}{Fore.YELLOW}"
f"]{Fore.RESET} "
).center(term_width, "~")

return (
[headline]
+ [
f"[{format_timedelta(update_at - since)}] {Fore.CYAN}{step}: running{Fore.RESET}"
for step, since in self._running_steps.items()
]
+ [
f"[-:--:--] {Fore.YELLOW}waiting: {' '.join(self._waiting)}{Fore.RESET}"
]
)


class Frontend:
def __init__(self, summary: StepSummary) -> None:
self._summary = summary

def _refresh_in_context() -> None:
with _io.redirect_streams(
sys.__stdout__, sys.__stderr__
), _io.log_file(None):
self._refresh()

self._refresh_thread = Thread(
target=copy_context().run, args=[_refresh_in_context]
)
self._stop = Event()

self._pipe_plexer = _io.PipePlexer(write_on_flush=False)

@contextmanager
def activate(self) -> Iterator[None]:
with _io.redirect_streams(
self._pipe_plexer.stdout, self._pipe_plexer.stderr
):
self._refresh_thread.start()

try:
yield
finally:
self._stop.set()
self._refresh_thread.join()

def _refresh(self) -> None:
previous_progress_height = 0
previous_last_line_length = 0

def refresh(skip_summary: bool = False) -> None:
nonlocal previous_progress_height
nonlocal previous_last_line_length

# Erase the current line
if previous_last_line_length != 0:
sys.stderr.write(
Cursor.BACK(previous_last_line_length) + ansi.clear_line()
)

# Erase the previous summary lines
if previous_progress_height >= 2:
sys.stderr.write(
f"{Cursor.UP(1)}{ansi.clear_line()}"
* (previous_progress_height - 1)
)

# Force a flush, to ensure that if the next line is printed on
# stdout, we pass the erasing first
sys.stderr.flush()

self._pipe_plexer.flush(force_write=True)

if skip_summary:
previous_last_line_length = 0
previous_progress_height = 0
else:
summary = self._summary.lines()

sys.stderr.write(
ansi.clear_line() + f"\n{ansi.clear_line()}".join(summary)
)
previous_progress_height = len(summary)
if previous_progress_height:
previous_last_line_length = len(summary[-1])

sys.stderr.flush()

refresh()
while not self._stop.is_set():
self._stop.wait(0.5)
refresh()

refresh(True)
23 changes: 14 additions & 9 deletions src/dwas/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,29 @@ def flush(self) -> None:


class PipePlexer:
def __init__(self) -> None:
def __init__(self, write_on_flush: bool = True) -> None:
self.stderr = MemoryPipe(self)
self.stdout = MemoryPipe(self)

self._buffer: deque[Tuple[MemoryPipe, str]] = deque()
self._write_on_flush = write_on_flush

def write(self, stream: MemoryPipe, data: str) -> int:
self._buffer.append((stream, data))
return len(data)

def flush(self) -> None:
with suppress(IndexError):
while True:
stream, line = self._buffer.popleft()
if stream == self.stdout:
sys.stdout.write(line)
else:
sys.stderr.write(line)
def flush(self, force_write: bool = False) -> None:
if self._write_on_flush or force_write:
with suppress(IndexError):
while True:
stream, line = self._buffer.popleft()
if stream == self.stdout:
sys.stdout.write(line)
else:
sys.stderr.write(line)

sys.stdout.flush()
sys.stderr.flush()


class StreamHandler(io.TextIOWrapper):
Expand Down
35 changes: 27 additions & 8 deletions src/dwas/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
UnavailableInterpreterException,
UnknownStepsException,
)
from ._frontend import Frontend, StepSummary
from ._subproc import ProcessManager
from ._timing import format_timedelta, get_timedelta_since

Expand Down Expand Up @@ -342,9 +343,18 @@ def request_stop(_signum: int, _frame: Optional[FrameType]) -> None:
previous_signal = signal.signal(signal.SIGINT, request_stop)

try:
results = self._execute(
sorter, running_futures, stop, lambda: should_stop
)
summary = StepSummary(list(graph))
with ExitStack() as stack:
if self.config.is_interactive:
stack.enter_context(Frontend(summary).activate())

results = self._execute(
sorter,
running_futures,
stop,
lambda: should_stop,
summary,
)
finally:
signal.signal(signal.SIGINT, previous_signal)

Expand Down Expand Up @@ -422,6 +432,7 @@ def _execute(
],
stop: Callable[[], None],
should_stop: Callable[[], bool],
summary: StepSummary,
) -> Dict[str, Tuple[Optional[Exception], timedelta]]:
results: Dict[str, Tuple[Optional[Exception], timedelta]] = {}

Expand Down Expand Up @@ -483,10 +494,11 @@ def _execute(
executor.submit(
# XXX: mypy gets confused here, but the result is
# sane
copy_context().run, # type: ignore
self._run_step, # type: ignore
name, # type: ignore
pipe_plexer, # type: ignore
copy_context().run, # type: ignore[arg-type]
self._run_step, # type: ignore[arg-type]
name, # type: ignore[arg-type]
pipe_plexer, # type: ignore[arg-type]
summary, # type: ignore[arg-type]
),
)
running_futures[future] = name, pipe_plexer
Expand Down Expand Up @@ -590,6 +602,7 @@ def _run_step(
self,
name: str,
pipe_plexer: Optional[_io.PipePlexer],
summary: StepSummary,
) -> timedelta:
with ExitStack() as stack:
if pipe_plexer is not None:
Expand All @@ -607,7 +620,13 @@ def _run_step(
LOGGER.debug("Log file can be found at %s", log_file)
stack.enter_context(_io.log_file(log_file))

time_taken = self._run_step_with_logging(name)
summary.mark_running(name)
try:
time_taken = self._run_step_with_logging(name)
except Exception:
summary.mark_failure(name)
raise
summary.mark_success(name)

return time_taken

Expand Down
14 changes: 10 additions & 4 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=redefined-outer-name

import itertools
import os

import pytest
Expand Down Expand Up @@ -45,12 +46,17 @@ def test_can_control_colors_explicitly(enable, kwargs):
)


@pytest.mark.parametrize("is_tty", (True, False))
def test_enables_colors_if_tty(monkeypatch, is_tty, kwargs):
monkeypatch.setattr("sys.stdin.isatty", lambda: is_tty)
@pytest.mark.parametrize(
("stdout_is_tty", "stderr_is_tty"), itertools.permutations([True, False])
)
def test_enables_colors_if_tty(
monkeypatch, stdout_is_tty, stderr_is_tty, kwargs
):
monkeypatch.setattr("sys.stdout.isatty", lambda: stdout_is_tty)
monkeypatch.setattr("sys.stderr.isatty", lambda: stderr_is_tty)

conf = Config(**kwargs, colors=None)
assert conf.colors == is_tty
assert conf.colors == (stdout_is_tty and stderr_is_tty)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 8cb3536

Please sign in to comment.