Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Decouple broker and coordinator interface #6675

Draft
wants to merge 23 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e8856a0
ProcessFuture may okay
unkcpz Dec 19, 2024
a1e8f87
Debug enhancer
unkcpz Dec 19, 2024
b837d72
fix
unkcpz Dec 19, 2024
92f1683
plumpy.futures.CancelledError is an alias of concurrent.futures.Cance…
unkcpz Dec 20, 2024
6085bca
Adopt with new message type and solve import issues
unkcpz Dec 20, 2024
61934f3
apply timeout for pytest to 30s
unkcpz Dec 20, 2024
d5733b2
use aiida_profile_clean for test_input_code test
unkcpz Dec 20, 2024
b1f446a
bind and test against the corresponded rmq-out branch of unkcpz/plump…
unkcpz Dec 20, 2024
36f1a86
Simplipy create_runner function signature
unkcpz Dec 21, 2024
28cdb1c
Runner use coordinator interface
unkcpz Dec 21, 2024
5d59e6a
Remove get_communicator calls
unkcpz Dec 21, 2024
c769906
Controller snuck into broker
unkcpz Dec 21, 2024
03f7a5b
Construct and use RmqLooCoordinator directly
unkcpz Dec 27, 2024
02a939e
Separate create_broker and get_broker where get won't change state
unkcpz Dec 27, 2024
329c51c
Keep on mess up with coordinator loop
unkcpz Dec 27, 2024
82f14dd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 27, 2024
2c00fb7
coordinator decouple usage in process.py
unkcpz Dec 28, 2024
1857656
Move create_daemon_runner to daemon worker.py
unkcpz Dec 28, 2024
01f92f7
Remove outer usage of runner.coordinator
unkcpz Dec 28, 2024
69db549
Start runner in a dedicated thread
unkcpz Dec 28, 2024
a078400
Exclude .python-version of pyenv from gitignore
unkcpz Dec 29, 2024
5746ae8
find run_task not await coro crux
unkcpz Dec 29, 2024
5572660
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 29, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,6 @@ pplot_out/

# docker
docker-bake.override.json

# pyenv
.python-version
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dependencies:
- importlib-metadata~=6.0
- numpy~=1.21
- paramiko~=3.0
- plumpy~=0.22.3
- plumpy
- pgsu~=0.3.0
- psutil~=5.6
- psycopg[binary]~=3.0
Expand Down
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
'importlib-metadata~=6.0',
'numpy~=1.21',
'paramiko~=3.0',
'plumpy~=0.22.3',
'plumpy',
'pgsu~=0.3.0',
'psutil~=5.6',
'psycopg[binary]~=3.0',
Expand Down Expand Up @@ -246,6 +246,7 @@ tests = [
'pympler~=1.0',
'coverage~=7.0',
'sphinx~=7.2.0',
'watchdog~=6.0',
'docutils~=0.20'
]
tui = [
Expand Down Expand Up @@ -387,6 +388,7 @@ minversion = '7.0'
testpaths = [
'tests'
]
timeout = 30
xfail_strict = true

[tool.ruff]
Expand Down Expand Up @@ -509,3 +511,6 @@ passenv =
AIIDA_TEST_WORKERS
commands = molecule {posargs:test}
"""

[tool.uv.sources]
plumpy = {git = "https://github.com/unkcpz/plumpy", branch = "rmq-out"}
26 changes: 18 additions & 8 deletions src/aiida/brokers/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,35 @@
import abc
import typing as t

from plumpy.controller import ProcessController

if t.TYPE_CHECKING:
from aiida.manage.configuration.profile import Profile
from plumpy.coordinator import Coordinator


__all__ = ('Broker',)


# FIXME: make me a protocol
class Broker:
"""Interface for a message broker that facilitates communication with and between process runners."""

def __init__(self, profile: 'Profile') -> None:
"""Construct a new instance.
# def __init__(self, profile: 'Profile') -> None:
# """Construct a new instance.
#
# :param profile: The profile.
# """
# self._profile = profile

:param profile: The profile.
"""
self._profile = profile
@abc.abstractmethod
# FIXME: make me a property
def get_coordinator(self) -> 'Coordinator':
"""Return an instance of coordinator."""

@abc.abstractmethod
def get_communicator(self):
"""Return an instance of :class:`kiwipy.Communicator`."""
def get_controller(self) -> ProcessController:
"""Return the process controller"""
...

@abc.abstractmethod
def iterate_tasks(self):
Expand Down
35 changes: 27 additions & 8 deletions src/aiida/brokers/rabbitmq/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

from __future__ import annotations

import asyncio
import functools
import typing as t

from plumpy import ProcessController
from plumpy.rmq import RemoteProcessThreadController

from aiida.brokers.broker import Broker
from aiida.brokers.rabbitmq.coordinator import RmqLoopCoordinator
from aiida.common.log import AIIDA_LOGGER
from aiida.manage.configuration import get_config_option

Expand All @@ -24,14 +29,16 @@
class RabbitmqBroker(Broker):
"""Implementation of the message broker interface using RabbitMQ through ``kiwipy``."""

def __init__(self, profile: Profile) -> None:
def __init__(self, profile: Profile, loop=None) -> None:
"""Construct a new instance.

:param profile: The profile.
"""
self._profile = profile
self._communicator: 'RmqThreadCommunicator' | None = None
self._communicator: 'RmqThreadCommunicator | None' = None
self._prefix = f'aiida-{self._profile.uuid}'
self._coordinator = None
self._loop = loop or asyncio.get_event_loop()

def __str__(self):
try:
Expand All @@ -47,24 +54,36 @@ def close(self):

def iterate_tasks(self):
"""Return an iterator over the tasks in the launch queue."""
for task in self.get_communicator().task_queue(get_launch_queue_name(self._prefix)):
for task in self.get_coordinator().communicator.task_queue(get_launch_queue_name(self._prefix)):
yield task

def get_communicator(self) -> 'RmqThreadCommunicator':
def get_coordinator(self):
if self._coordinator is not None:
return self._coordinator

return self.create_coordinator()

def create_coordinator(self):
if self._communicator is None:
self._communicator = self._create_communicator()
# Check whether a compatible version of RabbitMQ is being used.
self.check_rabbitmq_version()

return self._communicator
coordinator = RmqLoopCoordinator(self._communicator, self._loop)

return coordinator

def get_controller(self) -> ProcessController:
coordinator = self.get_coordinator()
return RemoteProcessThreadController(coordinator)

def _create_communicator(self) -> 'RmqThreadCommunicator':
"""Return an instance of :class:`kiwipy.Communicator`."""
from kiwipy.rmq import RmqThreadCommunicator

from aiida.orm.utils import serialize

self._communicator = RmqThreadCommunicator.connect(
_communicator = RmqThreadCommunicator.connect(
connection_params={'url': self.get_url()},
message_exchange=get_message_exchange_name(self._prefix),
encoder=functools.partial(serialize.serialize, encoding='utf-8'),
Expand All @@ -78,7 +97,7 @@ def _create_communicator(self) -> 'RmqThreadCommunicator':
testing_mode=self._profile.is_test_profile,
)

return self._communicator
return _communicator

def check_rabbitmq_version(self):
"""Check the version of RabbitMQ that is being connected to and emit warning if it is not compatible."""
Expand Down Expand Up @@ -122,4 +141,4 @@ def get_rabbitmq_version(self):
"""
from packaging.version import parse

return parse(self.get_communicator().server_properties['version'])
return parse(self.get_coordinator().communicator.server_properties['version'])
93 changes: 93 additions & 0 deletions src/aiida/brokers/rabbitmq/coordinator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import concurrent.futures
from asyncio import AbstractEventLoop
from typing import Generic, TypeVar, final

import kiwipy
from plumpy.exceptions import CoordinatorConnectionError
from plumpy.rmq.communications import convert_to_comm

__all__ = ['RmqCoordinator']

U = TypeVar('U', bound=kiwipy.Communicator)


@final
class RmqLoopCoordinator(Generic[U]):
def __init__(self, comm: U, loop: AbstractEventLoop):
self._comm = comm
self._loop = loop

@property
def communicator(self) -> U:
"""The inner communicator."""
return self._comm

def add_rpc_subscriber(self, subscriber, identifier=None):
subscriber = convert_to_comm(subscriber, self._loop)
return self._comm.add_rpc_subscriber(subscriber, identifier)

def add_broadcast_subscriber(
self,
subscriber,
subject_filters=None,
sender_filters=None,
identifier=None,
):
# XXX: this change behavior of create_task when decide whether the broadcast is_filtered.
# Need to understand the BroadcastFilter and make the improvement.
# To manifest the issue of run_task not await, run twice 'test_launch.py::test_submit_wait'.

# subscriber = kiwipy.BroadcastFilter(subscriber)
#
# subject_filters = subject_filters or []
# sender_filters = sender_filters or []
#
# for filter in subject_filters:
# subscriber.add_subject_filter(filter)
# for filter in sender_filters:
# subscriber.add_sender_filter(filter)

subscriber = convert_to_comm(subscriber, self._loop)
return self._comm.add_broadcast_subscriber(subscriber, identifier)

def add_task_subscriber(self, subscriber, identifier=None):
subscriber = convert_to_comm(subscriber, self._loop)
return self._comm.add_task_subscriber(subscriber, identifier)

def remove_rpc_subscriber(self, identifier):
return self._comm.remove_rpc_subscriber(identifier)

def remove_broadcast_subscriber(self, identifier):
return self._comm.remove_broadcast_subscriber(identifier)

def remove_task_subscriber(self, identifier):
return self._comm.remove_task_subscriber(identifier)

def rpc_send(self, recipient_id, msg):
return self._comm.rpc_send(recipient_id, msg)

def broadcast_send(
self,
body,
sender=None,
subject=None,
correlation_id=None,
):
from aio_pika.exceptions import AMQPConnectionError, ChannelInvalidStateError

try:
rsp = self._comm.broadcast_send(body, sender, subject, correlation_id)
except (ChannelInvalidStateError, AMQPConnectionError, concurrent.futures.TimeoutError) as exc:
raise CoordinatorConnectionError from exc
else:
return rsp

def task_send(self, task, no_reply=False):
return self._comm.task_send(task, no_reply)

def close(self):
self._comm.close()

def is_closed(self) -> bool:
"""Return `True` if the communicator was closed"""
return self._comm.is_closed()
19 changes: 10 additions & 9 deletions src/aiida/cmdline/commands/cmd_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import click

from aiida.brokers.broker import Broker
from aiida.cmdline.commands.cmd_verdi import verdi
from aiida.cmdline.params import arguments, options, types
from aiida.cmdline.utils import decorators, echo
Expand Down Expand Up @@ -340,8 +341,8 @@ def process_kill(processes, all_entries, timeout, wait):

with capture_logging() as stream:
try:
message = 'Killed through `verdi process kill`'
control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)
msg_text = 'Killed through `verdi process kill`'
control.kill_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, msg_text=msg_text)
except control.ProcessTimeoutException as exception:
echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}')

Expand Down Expand Up @@ -371,8 +372,8 @@ def process_pause(processes, all_entries, timeout, wait):

with capture_logging() as stream:
try:
message = 'Paused through `verdi process pause`'
control.pause_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, message=message)
msg_text = 'Paused through `verdi process pause`'
control.pause_processes(processes, all_entries=all_entries, timeout=timeout, wait=wait, msg_text=msg_text)
except control.ProcessTimeoutException as exception:
echo.echo_critical(f'{exception}\n{REPAIR_INSTRUCTIONS}')

Expand Down Expand Up @@ -416,7 +417,7 @@ def process_play(processes, all_entries, timeout, wait):
@decorators.with_dbenv()
@decorators.with_broker
@decorators.only_if_daemon_running(echo.echo_warning, 'daemon is not running, so process may not be reachable')
def process_watch(broker, processes, most_recent_node):
def process_watch(broker: Broker, processes, most_recent_node):
"""Watch the state transitions of processes.

Watch the state transitions for one or multiple running processes."""
Expand All @@ -436,7 +437,7 @@ def process_watch(broker, processes, most_recent_node):

from kiwipy import BroadcastFilter

def _print(communicator, body, sender, subject, correlation_id):
def _print(coordinator, body, sender, subject, correlation_id):
"""Format the incoming broadcast data into a message and echo it to stdout."""
if body is None:
body = 'No message specified'
Expand All @@ -446,7 +447,7 @@ def _print(communicator, body, sender, subject, correlation_id):

echo.echo(f'Process<{sender}> [{subject}|{correlation_id}]: {body}')

communicator = broker.get_communicator()
coordinator = broker.get_coordinator()
echo.echo_report('watching for broadcasted messages, press CTRL+C to stop...')

if most_recent_node:
Expand All @@ -457,7 +458,7 @@ def _print(communicator, body, sender, subject, correlation_id):
echo.echo_error(f'Process<{process.pk}> is already terminated')
continue

communicator.add_broadcast_subscriber(BroadcastFilter(_print, sender=process.pk))
coordinator.add_broadcast_subscriber(BroadcastFilter(_print, sender=process.pk))

try:
# Block this thread indefinitely until interrupt
Expand All @@ -467,7 +468,7 @@ def _print(communicator, body, sender, subject, correlation_id):
echo.echo('') # add a new line after the interrupt character
echo.echo_report('received interrupt, exiting...')
try:
communicator.close()
coordinator.close()
except RuntimeError:
pass

Expand Down
6 changes: 4 additions & 2 deletions src/aiida/cmdline/commands/cmd_rabbitmq.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from aiida.cmdline.commands.cmd_devel import verdi_devel
from aiida.cmdline.params import arguments, options
from aiida.cmdline.utils import decorators, echo, echo_tabulate
from aiida.manage.manager import Manager

if t.TYPE_CHECKING:
import requests
Expand Down Expand Up @@ -131,12 +132,13 @@ def with_client(ctx, wrapped, _, args, kwargs):

@cmd_rabbitmq.command('server-properties')
@decorators.with_manager
def cmd_server_properties(manager):
def cmd_server_properties(manager: Manager):
"""List the server properties."""
import yaml

data = {}
for key, value in manager.get_communicator().server_properties.items():
# FIXME: server_properties as an common API for coordinator?
for key, value in manager.get_coordinator().communicator.server_properties.items():
data[key] = value.decode('utf-8') if isinstance(value, bytes) else value
click.echo(yaml.dump(data, indent=4))

Expand Down
2 changes: 1 addition & 1 deletion src/aiida/cmdline/commands/cmd_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def verdi_status(print_traceback, no_rmq):

if broker:
try:
broker.get_communicator()
broker.get_coordinator()
except Exception as exc:
message = f'Unable to connect to broker: {broker}'
print_status(ServiceStatus.ERROR, 'broker', message, exception=exc, print_traceback=print_traceback)
Expand Down
Loading
Loading