From 5acae396a2dddc1b35b8ff342d73777c13976f8b Mon Sep 17 00:00:00 2001 From: xjules Date: Mon, 11 Nov 2024 15:30:45 +0100 Subject: [PATCH] Implementing router-dealer pattern with custom acknowledgments with zmq - dispatcher now send messages in chunks - dispatcher always for acknolwedgment from the evaluator - removing websockets, no more wait_for_evaluator - Settup encryption with curve - each dealer (client, dispatcher) will get a unique name --- src/_ert/forward_model_runner/client.py | 152 ++++++------ .../forward_model_runner/reporting/event.py | 39 +-- src/ert/ensemble_evaluator/_ensemble.py | 20 +- .../ensemble_evaluator/_wait_for_evaluator.py | 7 +- src/ert/ensemble_evaluator/config.py | 26 +- src/ert/ensemble_evaluator/evaluator.py | 227 ++++++++---------- .../evaluator_connection_info.py | 14 +- src/ert/ensemble_evaluator/monitor.py | 80 +++--- src/ert/run_models/base_run_model.py | 8 +- src/ert/shared/net_utils.py | 1 + 10 files changed, 254 insertions(+), 320 deletions(-) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 60b1042ab91..153fb0bf9e5 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -1,17 +1,13 @@ +from __future__ import annotations + import asyncio import logging -import ssl -from typing import Any, AnyStr, Optional, Union +import uuid +from typing import Any, Optional, Union +import zmq +import zmq.asyncio from typing_extensions import Self -from websockets.asyncio.client import ClientConnection, connect -from websockets.datastructures import Headers -from websockets.exceptions import ( - ConnectionClosedError, - ConnectionClosedOK, - InvalidHandshake, - InvalidURI, -) from _ert.async_utils import new_event_loop @@ -35,18 +31,18 @@ def __enter__(self) -> Self: return self def __exit__(self, exc_type: Any, exc_value: Any, exc_traceback: Any) -> None: - if self.websocket is not None: - self.loop.run_until_complete(self.websocket.close()) - self.loop.close() + self.socket.close() + self.context.term() - async def __aenter__(self) -> "Client": + async def __aenter__(self) -> Self: return self async def __aexit__( self, exc_type: Any, exc_value: Any, exc_traceback: Any ) -> None: - if self.websocket is not None: - await self.websocket.close() + self.socket.close() + self.context.term() + self.loop.close() def __init__( self, @@ -55,84 +51,80 @@ def __init__( cert: Optional[Union[str, bytes]] = None, max_retries: Optional[int] = None, timeout_multiplier: Optional[int] = None, + dealer_name: str | None = None, ) -> None: if max_retries is None: max_retries = self.DEFAULT_MAX_RETRIES if timeout_multiplier is None: timeout_multiplier = self.DEFAULT_TIMEOUT_MULTIPLIER - if url is None: - raise ValueError("url was None") self.url = url self.token = token - self._additional_headers = Headers() + + # Set up ZeroMQ context and socket + self.context = zmq.asyncio.Context() # type: ignore + self.socket = self.context.socket(zmq.DEALER) + if dealer_name is None: + dispatch_id = f"dispatch-{uuid.uuid4().hex[:8]}" + else: + dispatch_id = dealer_name + self.socket.setsockopt_string(zmq.IDENTITY, dispatch_id) if token is not None: - self._additional_headers["token"] = token - - # Mimics the behavior of the ssl argument when connection to - # websockets. If none is specified it will deduce based on the url, - # if True it will enforce TLS, and if you want to use self signed - # certificates you need to pass an ssl_context with the certificate - # loaded. - self._ssl_context: Optional[Union[bool, ssl.SSLContext]] = None - if cert is not None: - self._ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - self._ssl_context.load_verify_locations(cadata=cert) - elif url.startswith("wss"): - self._ssl_context = True + client_public, client_secret = zmq.curve_keypair() + self.socket.curve_secretkey = client_secret + self.socket.curve_publickey = client_public + self.socket.curve_serverkey = token.encode("utf-8") + self.socket.connect(url) self._max_retries = max_retries self._timeout_multiplier = timeout_multiplier - self.websocket: Optional[ClientConnection] = None self.loop = new_event_loop() - async def get_websocket(self) -> ClientConnection: - return await connect( - self.url, - ssl=self._ssl_context, - additional_headers=self._additional_headers, - open_timeout=self.CONNECTION_TIMEOUT, - ping_timeout=self.CONNECTION_TIMEOUT, - ping_interval=self.CONNECTION_TIMEOUT, - close_timeout=self.CONNECTION_TIMEOUT, - ) - - async def _send(self, msg: AnyStr) -> None: - for retry in range(self._max_retries + 1): + async def reconnect(self): + """Connect to the server with exponential backoff.""" + retries = self._max_retries + while retries > 0: try: - if self.websocket is None: - self.websocket = await self.get_websocket() - await self.websocket.send(msg) - return - except ConnectionClosedOK as exception: - _error_msg = ( - f"Connection closed received from the server {self.url}! " - f" Exception from {type(exception)}: {exception!s}" + self.socket.connect(self.url) + break + except zmq.ZMQError as e: + logger.warning(f"Failed to connect to {self.url}: {e}") + retries -= 1 + if retries == 0: + raise e + # Exponential backoff + sleep_time = self._timeout_multiplier * (self._max_retries - retries) + await asyncio.sleep(sleep_time) + + def send(self, messages: str | list[str]) -> None: + self.loop.run_until_complete(self.send_async(messages)) + + async def send_async(self, messages: str | list[str]) -> None: + if isinstance(messages, str): + messages = [messages] + retries = 0 + max_retries = 5 + while retries < max_retries: + try: + logger.debug(f"sending messages: {messages}") + await self.socket.send_multipart( + [b""] + [message.encode("utf-8") for message in messages] ) - raise ClientConnectionClosedOK(_error_msg) from exception - except ( - InvalidHandshake, - InvalidURI, - OSError, - asyncio.TimeoutError, - ) as exception: - if retry == self._max_retries: - _error_msg = ( - f"Not able to establish the " - f"websocket connection {self.url}! Max retries reached!" - " Check for firewall issues." - f" Exception from {type(exception)}: {exception!s}" + try: + _, ack = await asyncio.wait_for( + self.socket.recv_multipart(), timeout=3 ) - raise ClientConnectionError(_error_msg) from exception - except ConnectionClosedError as exception: - if retry == self._max_retries: - _error_msg = ( - f"Not been able to send the event" - f" to {self.url}! Max retries reached!" - f" Exception from {type(exception)}: {exception!s}" + logger.debug(f"Got acknowledgment: {ack}") + if ack.decode() == "ACK": + break + logger.warning( + "Got acknowledgment but not the expected message. Resending" ) - raise ClientConnectionError(_error_msg) from exception - await asyncio.sleep(0.2 + self._timeout_multiplier * retry) - self.websocket = None - - def send(self, msg: AnyStr) -> None: - self.loop.run_until_complete(self._send(msg)) + retries += 1 + except asyncio.TimeoutError: + logger.warning( + "Failed to get acknowledgment on the message. Resending" + ) + retries += 1 + except zmq.ZMQError as e: + logger.warning(f"Failed to send message from {e} reconnecting ...") + await self.reconnect() diff --git a/src/_ert/forward_model_runner/reporting/event.py b/src/_ert/forward_model_runner/reporting/event.py index 8bf13dee238..11ef56374b6 100644 --- a/src/_ert/forward_model_runner/reporting/event.py +++ b/src/_ert/forward_model_runner/reporting/event.py @@ -3,6 +3,7 @@ import logging import queue import threading +import time from datetime import datetime, timedelta from pathlib import Path from typing import Final, Union @@ -18,8 +19,6 @@ ) from _ert.forward_model_runner.client import ( Client, - ClientConnectionClosedOK, - ClientConnectionError, ) from _ert.forward_model_runner.reporting.base import Reporter from _ert.forward_model_runner.reporting.message import ( @@ -90,7 +89,8 @@ def _event_publisher(self): token=self._token, cert=self._cert, ) as client: - event = None + events = [] + last_sent_time = time.time() while True: with self._timestamp_lock: if ( @@ -99,23 +99,28 @@ def _event_publisher(self): ): self._timeout_timestamp = None break - if event is None: - # if we successfully sent the event we can proceed - # to next one + + try: event = self._event_queue.get() + logger.debug(f"Got event for zmq: {event}") if event is self._sentinel: + if events: + logger.debug(f"Got event class for zmq: {events}") + client.send(events) + events.clear() break - try: - client.send(event_to_json(event)) - event = None - except ClientConnectionError as exception: - # Possible intermittent failure, we retry sending the event - logger.error(str(exception)) - except ClientConnectionClosedOK as exception: - # The receiving end has closed the connection, we stop - # sending events - logger.debug(str(exception)) - break + events.append(event_to_json(event)) + + current_time = time.time() + if current_time - last_sent_time >= 2: + if events: + logger.debug(f"Got event class for zmq: {events}") + client.send(events) + events.clear() + last_sent_time = current_time + except Exception as e: + logger.error(f"Failed to send event: {e}") + raise def report(self, msg): self._statemachine.transition(msg) diff --git a/src/ert/ensemble_evaluator/_ensemble.py b/src/ert/ensemble_evaluator/_ensemble.py index 877b40ad627..424b9c15fd1 100644 --- a/src/ert/ensemble_evaluator/_ensemble.py +++ b/src/ert/ensemble_evaluator/_ensemble.py @@ -31,13 +31,8 @@ from ert.run_arg import RunArg from ert.scheduler import Scheduler, create_driver -from ._wait_for_evaluator import wait_for_evaluator from .config import EvaluatorServerConfig -from .snapshot import ( - EnsembleSnapshot, - FMStepSnapshot, - RealizationSnapshot, -) +from .snapshot import EnsembleSnapshot, FMStepSnapshot, RealizationSnapshot from .state import ( ENSEMBLE_STATE_CANCELLED, ENSEMBLE_STATE_FAILED, @@ -122,6 +117,7 @@ def __post_init__(self) -> None: self._config: Optional[EvaluatorServerConfig] = None self.snapshot: EnsembleSnapshot = self._create_snapshot() self.status = self.snapshot.status + self._client: Client | None = None if self.snapshot.status: self._status_tracker = _EnsembleStateTracker(self.snapshot.status) else: @@ -208,7 +204,7 @@ async def send_event( retries: int = 10, ) -> None: async with Client(url, token, cert, max_retries=retries) as client: - await client._send(event_to_json(event)) + await client.send_async(event_to_json(event)) def generate_event_creator(self) -> Callable[[Id.ENSEMBLE_TYPES], Event]: def event_builder(status: str) -> Event: @@ -233,16 +229,12 @@ async def evaluate( ce_unary_send_method_name, partialmethod( self.__class__.send_event, - self._config.dispatch_uri, + self._config.get_connection_info().router_uri, token=self._config.token, cert=self._config.cert, ), ) - await wait_for_evaluator( - base_url=self._config.url, - token=self._config.token, - cert=self._config.cert, - ) + await self._evaluate_inner( event_unary_send=getattr(self, ce_unary_send_method_name), scheduler_queue=scheduler_queue, @@ -285,7 +277,7 @@ async def _evaluate_inner( # pylint: disable=too-many-branches max_running=self._queue_config.max_running, submit_sleep=self._queue_config.submit_sleep, ens_id=self.id_, - ee_uri=self._config.dispatch_uri, + ee_uri=self._config.get_connection_info().router_uri, ee_cert=self._config.cert, ee_token=self._config.token, ) diff --git a/src/ert/ensemble_evaluator/_wait_for_evaluator.py b/src/ert/ensemble_evaluator/_wait_for_evaluator.py index 9b5f5591292..c677875a55e 100644 --- a/src/ert/ensemble_evaluator/_wait_for_evaluator.py +++ b/src/ert/ensemble_evaluator/_wait_for_evaluator.py @@ -1,11 +1,7 @@ -import asyncio import logging import ssl -import time from typing import Optional, Union -import aiohttp - logger = logging.getLogger(__name__) WAIT_FOR_EVALUATOR_TIMEOUT = 60 @@ -17,6 +13,7 @@ def get_ssl_context(cert: Optional[Union[str, bytes]]) -> Union[ssl.SSLContext, ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context.load_verify_locations(cadata=cert) return ssl_context +<<<<<<< HEAD async def attempt_connection( @@ -76,3 +73,5 @@ async def wait_for_evaluator( cert=cert, connection_timeout=connection_timeout, ) +======= +>>>>>>> feac78628 (Implementing router-dealer pattern with custom acknowledgments with zmq) diff --git a/src/ert/ensemble_evaluator/config.py b/src/ert/ensemble_evaluator/config.py index 79c127cccdb..d6a4a99bcd9 100644 --- a/src/ert/ensemble_evaluator/config.py +++ b/src/ert/ensemble_evaluator/config.py @@ -11,6 +11,7 @@ from datetime import datetime, timedelta from typing import Optional +import zmq from cryptography import x509 from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import hashes, serialization @@ -129,29 +130,38 @@ def __init__( custom_host: typing.Optional[str] = None, ) -> None: self._socket_handle = find_available_socket( - custom_range=custom_port_range, custom_host=custom_host + custom_range=custom_port_range, + custom_host=custom_host, + will_close_then_reopen_socket=True, ) host, port = self._socket_handle.getsockname() - self.protocol = "wss" if generate_cert else "ws" - self.url = f"{self.protocol}://{host}:{port}" - self.client_uri = f"{self.url}/client" - self.dispatch_uri = f"{self.url}/dispatch" + self.host = host + self.router_port = port + + self._socket_handle = find_available_socket( + custom_range=custom_port_range, + custom_host=custom_host, + will_close_then_reopen_socket=True, + ) + if generate_cert: cert, key, pw = _generate_certificate(host) + self.server_public_key, self.server_secret_key = zmq.curve_keypair() + self.token = self.server_public_key.decode("utf-8") else: cert, key, pw = None, None, None + self.server_public_key, self.server_secret_key = None, None + self.token = None self.cert = cert self._key: Optional[bytes] = key self._key_pw = pw - self.token = _generate_authentication() if use_token else None - def get_socket(self) -> socket.socket: return self._socket_handle.dup() def get_connection_info(self) -> EvaluatorConnectionInfo: return EvaluatorConnectionInfo( - self.url, + f"tcp://{self.host}:{self.router_port}", self.cert, self.token, ) diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 104a830f5de..05c432bec4f 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -1,16 +1,14 @@ +from __future__ import annotations + import asyncio import datetime import logging import traceback -from contextlib import asynccontextmanager, contextmanager -from http import HTTPStatus from typing import ( Any, - AsyncIterator, Awaitable, Callable, Dict, - Generator, Iterable, List, Optional, @@ -22,10 +20,7 @@ get_args, ) -from pydantic_core._pydantic_core import ValidationError -from websockets.asyncio.server import ServerConnection, serve -from websockets.exceptions import ConnectionClosedError -from websockets.http11 import Request, Response +import zmq.asyncio from _ert.events import ( EESnapshot, @@ -69,15 +64,11 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._loop: Optional[asyncio.AbstractEventLoop] = None - self._clients: Set[ServerConnection] = set() - self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue() - self._events: asyncio.Queue[Event] = asyncio.Queue() self._events_to_send: asyncio.Queue[Event] = asyncio.Queue() self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue() self._ee_tasks: List[asyncio.Task[None]] = [] - self._server_started: asyncio.Event = asyncio.Event() self._server_done: asyncio.Event = asyncio.Event() # batching section @@ -87,14 +78,39 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig): self._max_batch_size: int = 500 self._batching_interval: float = 2.0 self._complete_batch: asyncio.Event = asyncio.Event() + self._zmq_context: zmq.asyncio.Context | None = None + self._clients_connected: set[bytes] = set() + self._clients_empty: asyncio.Event = asyncio.Event() + + async def _initialize_zmq(self) -> None: + self._zmq_context = zmq.asyncio.Context() # type: ignore + try: + # Create and configure the ROUTER socket + self._router_socket: zmq.asyncio.Socket = self._zmq_context.socket( + zmq.ROUTER + ) + self._router_socket.curve_secretkey = self._config.server_secret_key + self._router_socket.curve_publickey = self._config.server_public_key + self._router_socket.curve_server = True + + # Attempt to bind the ROUTER socket + logger.info(f"Attempting to bind to tcp://*:{self._config.router_port}") + self._router_socket.bind(f"tcp://*:{self._config.router_port}") + logger.info(f"Successfully bound to tcp://*:{self._config.router_port}") + + except zmq.error.ZMQError as e: + logger.error(f"ZMQ error during initialization: {e}") + raise + + logger.info("ZMQ initialized and ready to handle requests") async def _publisher(self) -> None: while True: event = await self._events_to_send.get() - await asyncio.gather( - *[client.send(event_to_json(event)) for client in self._clients], - return_exceptions=True, - ) + for identity in self._clients_connected: + await self._router_socket.send_multipart( + [identity, b"", event_to_json(event).encode()] + ) self._events_to_send.task_done() async def _append_message(self, snapshot_update_event: EnsembleSnapshot) -> None: @@ -205,43 +221,41 @@ async def _failed_handler(self, events: Sequence[EnsembleFailed]) -> None: def ensemble(self) -> Ensemble: return self._ensemble - @contextmanager - def store_client(self, websocket: ServerConnection) -> Generator[None, None, None]: - self._clients.add(websocket) - yield - self._clients.remove(websocket) - - async def handle_client(self, websocket: ServerConnection) -> None: - with self.store_client(websocket): - current_snapshot_dict = self._ensemble.snapshot.to_dict() - event: Event = EESnapshot( - snapshot=current_snapshot_dict, ensemble=self.ensemble.id_ - ) - await websocket.send(event_to_json(event)) - - async for raw_msg in websocket: - event = event_from_json(raw_msg) - logger.debug(f"got message from client: {event}") - if type(event) is EEUserCancel: - logger.debug(f"Client {websocket.remote_address} asked to cancel.") - self._signal_cancel() - - elif type(event) is EEUserDone: - logger.debug(f"Client {websocket.remote_address} signalled done.") - self.stop() - - @asynccontextmanager - async def count_dispatcher(self) -> AsyncIterator[None]: - await self._dispatchers_connected.put(None) - yield - await self._dispatchers_connected.get() - self._dispatchers_connected.task_done() - - async def handle_dispatch(self, websocket: ServerConnection) -> None: - async with self.count_dispatcher(): + async def listen_for_messages(self) -> None: + while True: try: - async for raw_msg in websocket: - try: + dealer, _, *frames = await self._router_socket.recv_multipart() + sender = dealer.decode("utf-8") + if sender.startswith("client"): + for frame in frames: + raw_msg = frame.decode("utf-8") + if raw_msg == "CONNECT": + self._clients_connected.add(dealer) + self._clients_empty.clear() + current_snapshot_dict = self._ensemble.snapshot.to_dict() + event: Event = EESnapshot( + snapshot=current_snapshot_dict, + ensemble=self.ensemble.id_, + ) + await self._router_socket.send_multipart( + [dealer, b"", event_to_json(event).encode()] + ) + elif raw_msg == "DISCONNECT": + self._clients_connected.remove(dealer) + if not self._clients_connected: + self._clients_empty.set() + else: + event = event_from_json(raw_msg) + if type(event) is EEUserCancel: + logger.debug("Client asked to cancel.") + self._signal_cancel() + elif type(event) is EEUserDone: + logger.debug("Client signalled done.") + self.stop() + elif sender.startswith("dispatch"): + await self._router_socket.send_multipart([dealer, b"", b"ACK"]) + for frame in frames: + raw_msg = frame.decode("utf-8") event = dispatch_event_from_json(raw_msg) if event.ensemble != self.ensemble.id_: logger.info( @@ -254,90 +268,35 @@ async def handle_dispatch(self, websocket: ServerConnection) -> None: await self.forward_checksum(event) else: await self._events.put(event) - except ValidationError as ex: - logger.warning( - "cannot handle event - " - f"closing connection to dispatcher: {ex}" - ) - await websocket.close( - code=1011, reason=f"failed handling message {raw_msg!r}" - ) - return - - if type(event) in [EnsembleSucceeded, EnsembleFailed]: - return - except ConnectionClosedError as connection_error: - # Dispatchers may close the connection abruptly in the case of - # * flaky network (then the dispatcher will try to reconnect) - # * job being killed due to MAX_RUNTIME - # * job being killed by user - logger.error( - f"a dispatcher abruptly closed a websocket: {connection_error!s}" - ) + # if type(event) in [EnsembleSucceeded, EnsembleFailed]: + # return + else: + logger.info(f"Connection attempt to unknown sender: {sender}.") + except zmq.error.ZMQError as e: + if e.errno == zmq.ENOTSOCK: + logger.warning( + "Evaluator receiver closed, no new messages are received" + ) + else: + logger.error(f"Unexpected error when listening to messages: {e}") + except asyncio.CancelledError: + return async def forward_checksum(self, event: Event) -> None: # clients still need to receive events via ws await self._events_to_send.put(event) await self._manifest_queue.put(event) - async def connection_handler(self, websocket: ServerConnection) -> None: - if websocket.request is not None: - path = websocket.request.path - elements = path.split("/") - if elements[1] == "client": - await self.handle_client(websocket) - elif elements[1] == "dispatch": - await self.handle_dispatch(websocket) - else: - logger.info(f"Connection attempt to unknown path: {path}.") - else: - logger.info("No request to handle.") - - async def process_request( - self, connection: ServerConnection, request: Request - ) -> Optional[Response]: - if request.headers.get("token") != self._config.token: - return connection.respond(HTTPStatus.UNAUTHORIZED, "") - if request.path == "/healthcheck": - return connection.respond(HTTPStatus.OK, "") - return None - async def _server(self) -> None: - async with serve( - self.connection_handler, - sock=self._config.get_socket(), - ssl=self._config.get_server_ssl_context(), - process_request=self.process_request, - max_size=2**26, - ping_timeout=60, - ping_interval=60, - close_timeout=60, - ) as server: - self._server_started.set() - await self._server_done.wait() - server.close(close_connections=False) - if self._dispatchers_connected is not None: - logger.debug( - f"Got done signal. {self._dispatchers_connected.qsize()} " - "dispatchers to disconnect..." - ) - try: # Wait for dispatchers to disconnect - await asyncio.wait_for( - self._dispatchers_connected.join(), timeout=20 - ) - except asyncio.TimeoutError: - logger.debug("Timed out waiting for dispatchers to disconnect") - else: - logger.debug("Got done signal. No dispatchers connected") - - logger.debug("Sending termination-message to clients...") - - await self._events.join() - await self._complete_batch.wait() - await self._batch_processing_queue.join() - event = EETerminated(ensemble=self._ensemble.id_) - await self._events_to_send.put(event) - await self._events_to_send.join() + await self._server_done.wait() + await self._events.join() + await self._complete_batch.wait() + await self._batch_processing_queue.join() + event = EETerminated(ensemble=self._ensemble.id_) + await self._events_to_send.put(event) + await self._events_to_send.join() + await self._clients_empty.wait() + self._router_socket.close() logger.debug("Async server exiting.") def stop(self) -> None: @@ -364,6 +323,7 @@ async def _start_running(self) -> None: if not self._config: raise ValueError("no config for evaluator") self._loop = asyncio.get_running_loop() + await self._initialize_zmq() self._ee_tasks = [ asyncio.create_task(self._server(), name="server_task"), asyncio.create_task( @@ -371,10 +331,8 @@ async def _start_running(self) -> None: ), asyncio.create_task(self._process_event_buffer(), name="processing_task"), asyncio.create_task(self._publisher(), name="publisher_task"), + asyncio.create_task(self.listen_for_messages(), name="listener_task"), ] - # now we wait for the server to actually start - await self._server_started.wait() - self._ee_tasks.append( asyncio.create_task( self._ensemble.evaluate( @@ -425,7 +383,10 @@ async def _monitor_and_handle_tasks(self) -> None: if stop_timeout_task: stop_timeout_task.cancel() return - elif task.get_name() == "ensemble_task": + elif task.get_name() in [ + "ensemble_task", + "listener_task", + ]: stop_timeout_task = asyncio.create_task( self._wait_for_stopped_server() ) diff --git a/src/ert/ensemble_evaluator/evaluator_connection_info.py b/src/ert/ensemble_evaluator/evaluator_connection_info.py index bd48e08e4a1..1bd5f3ac1bb 100644 --- a/src/ert/ensemble_evaluator/evaluator_connection_info.py +++ b/src/ert/ensemble_evaluator/evaluator_connection_info.py @@ -6,18 +6,6 @@ class EvaluatorConnectionInfo: """Read only server-info""" - url: str + router_uri: str cert: Optional[Union[str, bytes]] = None token: Optional[str] = None - - @property - def dispatch_uri(self) -> str: - return f"{self.url}/dispatch" - - @property - def client_uri(self) -> str: - return f"{self.url}/client" - - @property - def result_uri(self) -> str: - return f"{self.url}/result" diff --git a/src/ert/ensemble_evaluator/monitor.py b/src/ert/ensemble_evaluator/monitor.py index 55449bc620b..6bb7765afc7 100644 --- a/src/ert/ensemble_evaluator/monitor.py +++ b/src/ert/ensemble_evaluator/monitor.py @@ -1,12 +1,12 @@ +from __future__ import annotations + import asyncio import logging import ssl import uuid from typing import TYPE_CHECKING, Any, AsyncGenerator, Final, Optional, Union -from aiohttp import ClientError -from websockets import ConnectionClosed, Headers -from websockets.asyncio.client import ClientConnection, connect +import zmq.asyncio from _ert.events import ( EETerminated, @@ -16,7 +16,6 @@ event_from_json, event_to_json, ) -from ert.ensemble_evaluator._wait_for_evaluator import wait_for_evaluator if TYPE_CHECKING: from ert.ensemble_evaluator.evaluator_connection_info import EvaluatorConnectionInfo @@ -36,11 +35,11 @@ def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None: self._ee_con_info = ee_con_info self._id = str(uuid.uuid1()).split("-", maxsplit=1)[0] self._event_queue: asyncio.Queue[Union[Event, EventSentinel]] = asyncio.Queue() - self._connection: Optional[ClientConnection] = None self._receiver_task: Optional[asyncio.Task[None]] = None self._connected: asyncio.Event = asyncio.Event() self._connection_timeout: float = 120.0 self._receiver_timeout: float = 60.0 + self._zmq_context = zmq.asyncio.Context() # type: ignore async def __aenter__(self) -> "Monitor": self._receiver_task = asyncio.create_task(self._receiver()) @@ -57,6 +56,9 @@ async def __aenter__(self) -> "Monitor": async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: if self._receiver_task: + if self._socket: + await self._socket.send_multipart([b"", b"DISCONNECT"]) + self._socket.close() if not self._receiver_task.done(): self._receiver_task.cancel() # we are done and not interested in errors when cancelling @@ -65,27 +67,24 @@ async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None return_exceptions=True, ) - if self._connection: - await self._connection.close() - async def signal_cancel(self) -> None: - if not self._connection: - return await self._event_queue.put(Monitor._sentinel) logger.debug(f"monitor-{self._id} asking server to cancel...") cancel_event = EEUserCancel(monitor=self._id) - await self._connection.send(event_to_json(cancel_event)) + await self._socket.send_multipart( + [b"", event_to_json(cancel_event).encode("utf-8")] + ) logger.debug(f"monitor-{self._id} asked server to cancel") async def signal_done(self) -> None: - if not self._connection: - return await self._event_queue.put(Monitor._sentinel) logger.debug(f"monitor-{self._id} informing server monitor is done...") done_event = EEUserDone(monitor=self._id) - await self._connection.send(event_to_json(done_event)) + await self._socket.send_multipart( + [b"", event_to_json(done_event).encode("utf-8")] + ) logger.debug(f"monitor-{self._id} informed server monitor is done") async def track( @@ -124,36 +123,29 @@ async def _receiver(self) -> None: if self._ee_con_info.cert: tls = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) tls.load_verify_locations(cadata=self._ee_con_info.cert) - headers = Headers() - if self._ee_con_info.token: - headers["token"] = self._ee_con_info.token - - await wait_for_evaluator( - base_url=self._ee_con_info.url, - token=self._ee_con_info.token, - cert=self._ee_con_info.cert, - timeout=5, - ) - async for conn in connect( - self._ee_con_info.client_uri, - ssl=tls, - additional_headers=headers, - max_size=2**26, - max_queue=500, - open_timeout=5, - ping_timeout=60, - ping_interval=60, - close_timeout=60, - ): + + self._socket = self._zmq_context.socket(zmq.DEALER) + + if self._ee_con_info.token is not None: + client_public, client_secret = zmq.curve_keypair() + self._socket.curve_secretkey = client_secret + self._socket.curve_publickey = client_public + self._socket.curve_serverkey = self._ee_con_info.token.encode("utf-8") + + client_id = f"client-{uuid.uuid4().hex[:8]}" + self._socket.setsockopt_string(zmq.IDENTITY, client_id) + self._socket.connect(self._ee_con_info.router_uri) + await self._socket.send_multipart([b"", b"CONNECT"]) + self._connected.set() + + while True: try: - self._connection = conn - self._connected.set() - async for raw_msg in self._connection: - event = event_from_json(raw_msg) - await self._event_queue.put(event) - except (ConnectionRefusedError, ConnectionClosed, ClientError) as exc: - self._connection = None - self._connected.clear() + _, raw_msg = await self._socket.recv_multipart() + event = event_from_json(raw_msg.decode("utf-8")) + await self._event_queue.put(event) + except zmq.ZMQError as exc: + # Handle disconnection or other ZMQ errors (reconnect or log) logger.debug( - f"Monitor connection to EnsembleEvaluator went down, reconnecting: {exc}" + f"ZeroMQ connection to EnsembleEvaluator went down, reconnecting: {exc}" ) + await asyncio.sleep(1) diff --git a/src/ert/run_models/base_run_model.py b/src/ert/run_models/base_run_model.py index 53bc3a9fa9b..06af8006157 100644 --- a/src/ert/run_models/base_run_model.py +++ b/src/ert/run_models/base_run_model.py @@ -26,12 +26,7 @@ import numpy as np -from _ert.events import ( - EESnapshot, - EESnapshotUpdate, - EETerminated, - Event, -) +from _ert.events import EESnapshot, EESnapshotUpdate, EETerminated, Event from ert.analysis import ( AnalysisEvent, AnalysisStatusEvent, @@ -509,7 +504,6 @@ async def run_monitor( event, iteration, ) - if event.snapshot.get(STATUS) in [ ENSEMBLE_STATE_STOPPED, ENSEMBLE_STATE_FAILED, diff --git a/src/ert/shared/net_utils.py b/src/ert/shared/net_utils.py index 66c12aef6c9..2cf467481ac 100644 --- a/src/ert/shared/net_utils.py +++ b/src/ert/shared/net_utils.py @@ -111,6 +111,7 @@ def _bind_socket( if will_close_then_reopen_socket: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) else: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 0)