diff --git a/jupyter_client/ioloop/restarter.py b/jupyter_client/ioloop/restarter.py index d0c70396..6a3b9fe2 100644 --- a/jupyter_client/ioloop/restarter.py +++ b/jupyter_client/ioloop/restarter.py @@ -55,9 +55,9 @@ async def poll(self): """Poll the kernel.""" if self.debug: self.log.debug("Polling kernel...") - is_alive = await self.kernel_manager.is_alive() + exit_status = await self.kernel_manager.exit_status() now = time.time() - if not is_alive: + if exit_status is not None: self._last_dead = now if self._restarting: self._restart_count += 1 @@ -66,7 +66,7 @@ async def poll(self): if self._restart_count > self.restart_limit: self.log.warning("AsyncIOLoopKernelRestarter: restart failed") - self._fire_callbacks("dead") + self._fire_callbacks("dead", exit_status) self._restarting = False self._restart_count = 0 self.stop() @@ -78,7 +78,7 @@ async def poll(self): self.restart_limit, "new" if newports else "keep", ) - self._fire_callbacks("restart") + self._fire_callbacks("restart", exit_status) await self.kernel_manager.restart_kernel(now=True, newports=newports) self._restarting = True else: diff --git a/jupyter_client/manager.py b/jupyter_client/manager.py index f04bd987..ecd69a6d 100644 --- a/jupyter_client/manager.py +++ b/jupyter_client/manager.py @@ -231,11 +231,17 @@ def stop_restarter(self) -> None: """Stop the kernel restarter.""" pass - def add_restart_callback(self, callback: t.Callable, event: str = "restart") -> None: + def add_restart_callback( + self, + callback: t.Callable[[], object] | t.Callable[[int], object], + event: str = "restart", + *, + accepts_exit_code: bool = False, + ) -> None: """Register a callback to be called when a kernel is restarted""" if self._restarter is None: return - self._restarter.add_callback(callback, event) + self._restarter.add_callback(callback, event, accepts_exit_code=accepts_exit_code) def remove_restart_callback(self, callback: t.Callable, event: str = "restart") -> None: """Unregister a callback to be called when a kernel is restarted""" @@ -655,6 +661,17 @@ async def _async_is_alive(self) -> bool: is_alive = run_sync(_async_is_alive) + async def _async_exit_status(self) -> t.Optional[int]: + """Returns 0 if there's no kernel or it exited gracefully, + None if the kernel is running, or a negative value `-N` if the + kernel was killed by signal `N` (posix only).""" + if not self.has_kernel: + return 0 + assert self.provisioner is not None + return await self.provisioner.poll() + + exit_status = run_sync(_async_exit_status) + async def _async_wait(self, pollinterval: float = 0.1) -> None: # Use busy loop at 100ms intervals, polling until the process is # not alive. If we find the process is no longer alive, complete diff --git a/jupyter_client/restarter.py b/jupyter_client/restarter.py index 194ba907..fcdcc411 100644 --- a/jupyter_client/restarter.py +++ b/jupyter_client/restarter.py @@ -8,6 +8,7 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. import time +import typing as t from traitlets import Bool, Dict, Float, Instance, Integer, default from traitlets.config.configurable import LoggingConfigurable @@ -55,7 +56,8 @@ class KernelRestarter(LoggingConfigurable): def _default_last_dead(self): return time.time() - callbacks = Dict() + # traitlets.Dict is not typed generic + callbacks: t.Dict[str, t.List[t.Tuple[t.Callable[[int], object], t.Literal[True]] | t.Tuple[t.Callable[[], object], t.Literal[False]]]] = Dict() # type: ignore[assignment] def _callbacks_default(self): return {"restart": [], "dead": []} @@ -70,8 +72,34 @@ def stop(self): msg = "Must be implemented in a subclass" raise NotImplementedError(msg) - def add_callback(self, f, event="restart"): - """register a callback to fire on a particular event + @t.overload + def add_callback( + self, + f: t.Callable[[int], object], + event: str = "restart", + *, + accepts_exit_code: t.Literal[True], + ) -> None: + ... + + @t.overload + def add_callback( + self, + f: t.Callable[[], object], + event: str = "restart", + *, + accepts_exit_code: t.Literal[False] = False, + ) -> None: + ... + + def add_callback( + self, + f: t.Callable[[], object] | t.Callable[[int], object], + event: str = "restart", + *, + accepts_exit_code: bool = False, + ) -> None: + """register a callback to fire on a particular event. If ``accepts_exit_code`` is set, the callable will be passed the exit code as reported by `KernelManager.exit_status` Possible values for event: @@ -79,7 +107,8 @@ def add_callback(self, f, event="restart"): 'dead': restart has failed, kernel will be left dead. """ - self.callbacks[event].append(f) + # the type correlation from overloads is not tracked to here by mypy + self.callbacks[event].append((f, accepts_exit_code)) # type: ignore[arg-type] def remove_callback(self, f, event="restart"): """unregister a callback to fire on a particular event @@ -95,16 +124,20 @@ def remove_callback(self, f, event="restart"): except ValueError: pass - def _fire_callbacks(self, event): + def _fire_callbacks(self, event, status): """fire our callbacks for a particular event""" + # unpacking in the loop breaks the connection between the variables for mypy for callback in self.callbacks[event]: try: - callback() + if callback[1] is True: + callback[0](status) + else: + callback[0]() except Exception: self.log.error( "KernelRestarter: %s callback %r failed", event, - callback, + callback[0], exc_info=True, ) @@ -115,7 +148,8 @@ def poll(self): self.log.debug("Kernel shutdown in progress...") return now = time.time() - if not self.kernel_manager.is_alive(): + status = self.kernel_manager.exit_status() + if status is not None: self._last_dead = now if self._restarting: self._restart_count += 1 @@ -124,7 +158,7 @@ def poll(self): if self._restart_count > self.restart_limit: self.log.warning("KernelRestarter: restart failed") - self._fire_callbacks("dead") + self._fire_callbacks("dead", status) self._restarting = False self._restart_count = 0 self.stop() @@ -136,7 +170,7 @@ def poll(self): self.restart_limit, "new" if newports else "keep", ) - self._fire_callbacks("restart") + self._fire_callbacks("restart", status) self.kernel_manager.restart_kernel(now=True, newports=newports) self._restarting = True else: diff --git a/tests/test_kernelmanager.py b/tests/test_kernelmanager.py index f2d749eb..986bdda1 100644 --- a/tests/test_kernelmanager.py +++ b/tests/test_kernelmanager.py @@ -160,6 +160,46 @@ async def test_async_signal_kernel_subprocesses(self, name, install, expected): assert km._shutdown_status in expected +class TestKernelManagerExitStatus: + @pytest.mark.skipif(sys.platform == "win32", reason="Windows doesn't support signals") + @pytest.mark.parametrize('_signal', [signal.SIGILL, signal.SIGSEGV, signal.SIGTERM]) + async def test_exit_status(self, _signal): + # install kernel + _install_kernel(name="test_exit_status") + + # start kernel + km, kc = start_new_kernel(kernel_name="test_exit_status") + + # stop restarter - not needed? + # km.stop_restarter() + + # check that process is running + assert km.exit_status() is None + + # get the provisioner + # send signal + provisioner = km.provisioner + assert provisioner is not None + assert provisioner.has_process + await provisioner.send_signal(_signal) + + # wait for the process to exit + try: + await asyncio.wait_for(km._async_wait(), timeout=3.0) + except TimeoutError: + assert False, f'process never stopped for signal {signal}' + + # check that the signal is correct + assert km.exit_status() == -_signal + + # doing a proper shutdown now wipes the status, might be bad? + km.shutdown_kernel(now=True) + assert km.exit_status() == 0 + + # stop channels so cleanup doesn't complain + kc.stop_channels() + + class TestKernelManager: def test_lifecycle(self, km): km.start_kernel(stdout=PIPE, stderr=PIPE) diff --git a/tests/test_restarter.py b/tests/test_restarter.py index b216842f..078a300e 100644 --- a/tests/test_restarter.py +++ b/tests/test_restarter.py @@ -4,7 +4,9 @@ import asyncio import json import os +import signal import sys +import typing as t from concurrent.futures import Future import pytest @@ -88,16 +90,16 @@ def debug_logging(): @win_skip async def test_restart_check(config, install_kernel, debug_logging): """Test that the kernel is restarted and recovers""" - # If this test failes, run it with --log-cli-level=DEBUG to inspect + # If this test fails, run it with --log-cli-level=DEBUG to inspect N_restarts = 1 config.KernelRestarter.restart_limit = N_restarts config.KernelRestarter.debug = True km = IOLoopKernelManager(kernel_name=install_kernel, config=config) cbs = 0 - restarts: list = [Future() for i in range(N_restarts)] + restarts: t.List[Future[bool]] = [Future() for i in range(N_restarts)] - def cb(): + def cb() -> None: nonlocal cbs if cbs >= N_restarts: raise RuntimeError("Kernel restarted more than %d times!" % N_restarts) @@ -141,10 +143,66 @@ def cb(): assert km.context.closed +@win_skip +async def test_restart_check_exit_status(config, install_kernel, debug_logging): + """Test that the kernel is restarted and recovers, and validates the exit code.""" + # If this test fails, run it with --log-cli-level=DEBUG to inspect + N_restarts = 1 + config.KernelRestarter.restart_limit = N_restarts + config.KernelRestarter.debug = True + km = IOLoopKernelManager(kernel_name=install_kernel, config=config) + + cbs = 0 + restarts: t.List[Future[int]] = [Future() for i in range(N_restarts)] + + def cb(exit_status: int) -> None: + nonlocal cbs + if cbs >= N_restarts: + raise RuntimeError("Kernel restarted more than %d times!" % N_restarts) + restarts[cbs].set_result(exit_status) + cbs += 1 + + try: + km.start_kernel() + km.add_restart_callback(cb, 'restart', accepts_exit_code=True) + except BaseException: + if km.has_kernel: + km.shutdown_kernel() + raise + + try: + for i in range(N_restarts + 1): + kc = km.client() + kc.start_channels() + kc.wait_for_ready(timeout=60) + kc.stop_channels() + if i < N_restarts: + # Kill without cleanup to simulate crash: + assert km.provisioner is not None + await km.provisioner.kill() + assert restarts[i].result() == -signal.SIGKILL + # Wait for kill + restart + max_wait = 10.0 + waited = 0.0 + while waited < max_wait and km.is_alive(): + await asyncio.sleep(0.1) + waited += 0.1 + while waited < max_wait and not km.is_alive(): + await asyncio.sleep(0.1) + waited += 0.1 + + assert cbs == N_restarts + assert km.is_alive() + + finally: + km.shutdown_kernel(now=True) + assert km.context.closed + + @win_skip async def test_restarter_gives_up(config, install_fail_kernel, debug_logging): """Test that the restarter gives up after reaching the restart limit""" - # If this test failes, run it with --log-cli-level=DEBUG to inspect + # If this test fails, run it with --log-cli-level=DEBUG to inspect N_restarts = 1 config.KernelRestarter.restart_limit = N_restarts config.KernelRestarter.debug = True @@ -188,7 +246,7 @@ def on_death(): async def test_async_restart_check(config, install_kernel, debug_logging): """Test that the kernel is restarted and recovers""" - # If this test failes, run it with --log-cli-level=DEBUG to inspect + # If this test fails, run it with --log-cli-level=DEBUG to inspect N_restarts = 1 config.KernelRestarter.restart_limit = N_restarts config.KernelRestarter.debug = True @@ -243,7 +301,7 @@ def cb(): async def test_async_restarter_gives_up(config, install_slow_fail_kernel, debug_logging): """Test that the restarter gives up after reaching the restart limit""" - # If this test failes, run it with --log-cli-level=DEBUG to inspect + # If this test fails, run it with --log-cli-level=DEBUG to inspect N_restarts = 2 config.KernelRestarter.restart_limit = N_restarts config.KernelRestarter.debug = True