From 3668bb1b91780466f8baf57457926752c50566b9 Mon Sep 17 00:00:00 2001 From: Furkan Date: Sat, 19 Oct 2024 19:40:26 +0300 Subject: [PATCH] refactor: no longer reference DAQJob in messages to make it serializable - we now also use `dataclasses_json` for serialization, instead of `pickle` --- src/daq/alert/alert_slack.py | 2 +- src/daq/alert/base.py | 4 ++-- src/daq/base.py | 23 ++++++++++++++++++++++- src/daq/jobs/caen/n1081b.py | 2 +- src/daq/jobs/handle_alerts.py | 4 ++-- src/daq/jobs/handle_stats.py | 2 +- src/daq/jobs/healthcheck.py | 2 +- src/daq/jobs/remote.py | 30 ++++++++++++++++++++++++++---- src/daq/models.py | 2 +- src/daq/store/models.py | 15 +++------------ src/tests/test_handle_alerts.py | 4 ++-- src/tests/test_main.py | 2 +- src/tests/test_remote.py | 23 ++++++++++++++++------- src/tests/test_slack.py | 8 ++++---- 14 files changed, 83 insertions(+), 40 deletions(-) diff --git a/src/daq/alert/alert_slack.py b/src/daq/alert/alert_slack.py index e1e5554..1c76e19 100644 --- a/src/daq/alert/alert_slack.py +++ b/src/daq/alert/alert_slack.py @@ -39,7 +39,7 @@ def send_webhook(self, alert: DAQJobMessageAlert): { "fallback": alert.alert_info.message, "color": ALERT_SEVERITY_TO_SLACK_COLOR[alert.alert_info.severity], - "author_name": type(alert.daq_job).__name__, + "author_name": alert.daq_job_info.daq_job_class_name, "title": "Alert!", "fields": [ { diff --git a/src/daq/alert/base.py b/src/daq/alert/base.py index 233a720..0427291 100644 --- a/src/daq/alert/base.py +++ b/src/daq/alert/base.py @@ -6,7 +6,7 @@ from dataclasses_json import DataClassJsonMixin -from daq.base import DAQJob +from daq.base import DAQJob, DAQJobInfo from daq.models import DAQJobMessage @@ -24,7 +24,7 @@ class DAQAlertInfo(DataClassJsonMixin): @dataclass class DAQJobMessageAlert(DAQJobMessage): - daq_job: DAQJob + daq_job_info: DAQJobInfo date: datetime alert_info: DAQAlertInfo diff --git a/src/daq/base.py b/src/daq/base.py index 00f7818..f7391e0 100644 --- a/src/daq/base.py +++ b/src/daq/base.py @@ -1,10 +1,11 @@ import logging import threading +import uuid from dataclasses import dataclass from queue import Empty, Queue from typing import Any -from daq.models import DAQJobMessage, DAQJobMessageStop, DAQJobStopError +from daq.models import DAQJobConfig, DAQJobMessage, DAQJobMessageStop, DAQJobStopError daq_job_instance_id = 0 daq_job_instance_id_lock = threading.Lock() @@ -17,6 +18,7 @@ class DAQJob: message_in: Queue[DAQJobMessage] message_out: Queue[DAQJobMessage] instance_id: int + unique_id: str _logger: logging.Logger @@ -33,6 +35,7 @@ def __init__(self, config: Any): self.message_out = Queue() self._should_stop = False + self.unique_id = str(uuid.uuid4()) def consume(self): # consume messages from the queue @@ -61,6 +64,16 @@ def handle_message(self, message: "DAQJobMessage") -> bool: def start(self): raise NotImplementedError + def get_info(self) -> "DAQJobInfo": + return DAQJobInfo( + daq_job_type=self.config.daq_job_type + if isinstance(self.config, DAQJobConfig) + else self.config["daq_job_type"], + daq_job_class_name=type(self).__name__, + unique_id=self.unique_id, + instance_id=self.instance_id, + ) + def __del__(self): self._logger.info("DAQ job is being deleted") @@ -69,3 +82,11 @@ def __del__(self): class DAQJobThread: daq_job: DAQJob thread: threading.Thread + + +@dataclass +class DAQJobInfo: + daq_job_type: str + daq_job_class_name: str # has type(self).__name__ + unique_id: str + instance_id: int diff --git a/src/daq/jobs/caen/n1081b.py b/src/daq/jobs/caen/n1081b.py index fa1cde9..141dcc9 100644 --- a/src/daq/jobs/caen/n1081b.py +++ b/src/daq/jobs/caen/n1081b.py @@ -105,7 +105,7 @@ def _send_store_message(self, data: dict, section): self.message_out.put( DAQJobMessageStore( store_config=self.config.store_config, - daq_job=self, + daq_job_info=self.get_info(), prefix=section, keys=keys, data=[values], diff --git a/src/daq/jobs/handle_alerts.py b/src/daq/jobs/handle_alerts.py index 1e503c2..8f6b287 100644 --- a/src/daq/jobs/handle_alerts.py +++ b/src/daq/jobs/handle_alerts.py @@ -35,7 +35,7 @@ def handle_message(self, message: DAQJobMessageAlert) -> bool: data_to_send = [ [ get_unix_timestamp_ms(message.date), - type(message.daq_job).__name__, + message.daq_job_info.daq_job_class_name, message.alert_info.severity, message.alert_info.message, ] @@ -44,7 +44,7 @@ def handle_message(self, message: DAQJobMessageAlert) -> bool: self.message_out.put( DAQJobMessageStore( store_config=self.config.store_config, - daq_job=self, + daq_job_info=self.get_info(), keys=keys, data=data_to_send, ) diff --git a/src/daq/jobs/handle_stats.py b/src/daq/jobs/handle_stats.py index fed41d5..0dd3e19 100644 --- a/src/daq/jobs/handle_stats.py +++ b/src/daq/jobs/handle_stats.py @@ -70,7 +70,7 @@ def unpack_record(record: DAQJobStatsRecord): self.message_out.put( DAQJobMessageStore( store_config=self.config.store_config, - daq_job=self, + daq_job_info=self.get_info(), keys=keys, data=data_to_send, ) diff --git a/src/daq/jobs/healthcheck.py b/src/daq/jobs/healthcheck.py index 6d12067..e1786da 100644 --- a/src/daq/jobs/healthcheck.py +++ b/src/daq/jobs/healthcheck.py @@ -167,7 +167,7 @@ def handle_checks(self): def send_alert(self, item: HealthcheckItem): self.message_out.put( DAQJobMessageAlert( - daq_job=self, + daq_job_info=self.get_info(), date=datetime.now(), alert_info=item.alert_info, ) diff --git a/src/daq/jobs/remote.py b/src/daq/jobs/remote.py index e7314c0..928cb3d 100644 --- a/src/daq/jobs/remote.py +++ b/src/daq/jobs/remote.py @@ -1,4 +1,4 @@ -import pickle +import json import threading import time from dataclasses import dataclass @@ -27,6 +27,9 @@ class DAQJobRemote(DAQJob): allowed_message_in_types = [DAQJobMessage] # accept all message types config_type = DAQJobRemoteConfig config: DAQJobRemoteConfig + _zmq_local: zmq.Socket + _zmq_remote: zmq.Socket + _message_class_cache: dict def __init__(self, config: DAQJobRemoteConfig): super().__init__(config) @@ -35,21 +38,24 @@ def __init__(self, config: DAQJobRemoteConfig): self._zmq_remote = self._zmq_context.socket(zmq.PULL) self._zmq_local.connect(config.zmq_local_url) self._zmq_remote.connect(config.zmq_remote_url) + self._message_class_cache = {} self._receive_thread = threading.Thread( target=self._start_receive_thread, daemon=True ) + self._message_class_cache = { + x.__name__: x for x in DAQJobMessage.__subclasses__() + } def handle_message(self, message: DAQJobMessage) -> bool: - print(type(message)) - self._zmq_local.send(pickle.dumps(message)) + self._zmq_local.send(self._pack_message(message)) return True def _start_receive_thread(self): while True: message = self._zmq_remote.recv() # remote message_in -> message_out - self.message_out.put(pickle.loads(message)) + self.message_out.put(self._unpack_message(message)) def start(self): self._receive_thread.start() @@ -60,3 +66,19 @@ def start(self): # message_in -> remote message_out self.consume() time.sleep(0.1) + + def _pack_message(self, message: DAQJobMessage) -> bytes: + message_type = type(message).__name__ + return json.dumps([message_type, message.to_dict()]).encode("utf-8") + + def _unpack_message(self, message: bytes) -> DAQJobMessage: + message_type, data = json.loads(message.decode("utf-8")) + if message_type in self._message_class_cache: + message_class = self._message_class_cache[message_type] + else: + message_class = globals()[message_type] + self._message_class_cache[message_type] = message_class + + if not issubclass(message_class, DAQJobMessage): + raise Exception(f"Invalid message type: {message_type}") + return message_class.from_dict(data) diff --git a/src/daq/models.py b/src/daq/models.py index d953798..6949dfa 100644 --- a/src/daq/models.py +++ b/src/daq/models.py @@ -11,7 +11,7 @@ class DAQJobConfig(DataClassJsonMixin): @dataclass -class DAQJobMessage: +class DAQJobMessage(DataClassJsonMixin): pass diff --git a/src/daq/store/models.py b/src/daq/store/models.py index 2ec634d..9907347 100644 --- a/src/daq/store/models.py +++ b/src/daq/store/models.py @@ -1,9 +1,9 @@ from dataclasses import dataclass -from typing import Any, Optional +from typing import Any from dataclasses_json import DataClassJsonMixin -from daq.base import DAQJob +from daq.base import DAQJobInfo from daq.models import DAQJobConfig, DAQJobMessage @@ -19,20 +19,11 @@ class DAQJobStoreConfig(DataClassJsonMixin): @dataclass class DAQJobMessageStore(DAQJobMessage): store_config: dict | DAQJobStoreConfig - daq_job: Optional[DAQJob] + daq_job_info: DAQJobInfo keys: list[str] data: list[list[Any]] prefix: str | None = None - def __getstate__(self): - state = self.__dict__.copy() - del state["daq_job"] - return state - - def __setstate__(self, state): - self.__dict__.update(state) - self.daq_job = None # type: ignore - @dataclass class StorableDAQJobConfig(DAQJobConfig): diff --git a/src/tests/test_handle_alerts.py b/src/tests/test_handle_alerts.py index 77050d6..aea6029 100644 --- a/src/tests/test_handle_alerts.py +++ b/src/tests/test_handle_alerts.py @@ -22,7 +22,7 @@ def test_handle_message(self, mock_get_unix_timestamp_ms): message = MagicMock(spec=DAQJobMessageAlert) message.date = date - message.daq_job = MagicMock() + message.daq_job_info = MagicMock() message.alert_info = MagicMock() message.alert_info.severity = "high" message.alert_info.message = "Test alert message" @@ -42,7 +42,7 @@ def test_handle_message(self, mock_get_unix_timestamp_ms): [ [ get_unix_timestamp_ms(date), - type(message.daq_job).__name__, + message.daq_job_info.daq_job_class_name, "high", "Test alert message", ] diff --git a/src/tests/test_main.py b/src/tests/test_main.py index ebfb84f..4c73517 100644 --- a/src/tests/test_main.py +++ b/src/tests/test_main.py @@ -89,7 +89,7 @@ def test_send_messages_to_daq_jobs(self, mock_parse_store_config): mock_thread.daq_job.allowed_message_in_types = [DAQJobMessageStore] mock_thread.daq_job.message_in = Queue() mock_message = DAQJobMessageStore( - store_config={}, keys=[], data=[], daq_job=MagicMock() + store_config={}, keys=[], data=[], daq_job_info=MagicMock() ) daq_job_threads = [mock_thread] daq_job_threads: list[DAQJobThread] = daq_job_threads diff --git a/src/tests/test_remote.py b/src/tests/test_remote.py index b6d3b6d..5784463 100644 --- a/src/tests/test_remote.py +++ b/src/tests/test_remote.py @@ -1,4 +1,3 @@ -import pickle import threading import time import unittest @@ -26,10 +25,15 @@ def setUp(self, MockZmqContext): def test_handle_message(self): message = DAQJobMessageStore( - store_config={}, data=[], keys=[], daq_job=DAQJobTest({}) + store_config={}, + data=[], + keys=[], + daq_job_info=DAQJobTest({"daq_job_type": "test"}).get_info(), ) self.daq_job_remote.handle_message(message) - self.mock_sender.send.assert_called_once_with(pickle.dumps(message)) + self.mock_sender.send.assert_called_once_with( + self.daq_job_remote._pack_message(message) + ) def test_start(self): mock_receive_thread = MagicMock() @@ -45,7 +49,12 @@ def stop_receive_thread(): self.daq_job_remote.start() def test_receive_thread(self): - message = DAQJobMessageStore(store_config={}, data=[], keys=[], daq_job=None) # type: ignore + message = DAQJobMessageStore( + store_config={}, + data=[], + keys=[], + daq_job_info=DAQJobTest({"daq_job_type": "test"}).get_info(), + ) self.daq_job_remote.message_out = MagicMock() call_count = 0 @@ -54,12 +63,12 @@ def side_effect(): nonlocal call_count call_count += 1 if call_count >= 2: - raise Exception("Stop receive thread") - return pickle.dumps(message) + raise RuntimeError("Stop receive thread") + return self.daq_job_remote._pack_message(message) self.mock_receiver.recv.side_effect = side_effect - with self.assertRaises(Exception): + with self.assertRaises(RuntimeError): self.daq_job_remote._start_receive_thread() self.daq_job_remote.message_out.put.assert_called_once_with(message) self.assertEqual(self.daq_job_remote.message_out.put.call_count, 1) diff --git a/src/tests/test_slack.py b/src/tests/test_slack.py index d8d7069..0e10fdc 100644 --- a/src/tests/test_slack.py +++ b/src/tests/test_slack.py @@ -21,7 +21,7 @@ def test_init(self): def test_send_alert(self): alert = DAQJobMessageAlert( - daq_job=MagicMock(), + daq_job_info=MagicMock(), alert_info=DAQAlertInfo( severity=DAQAlertSeverity.ERROR, message="Test error message", @@ -34,7 +34,7 @@ def test_send_alert(self): { "fallback": "Test error message", "color": "danger", - "author_name": type(alert.daq_job).__name__, + "author_name": alert.daq_job_info.daq_job_class_name, "title": "Alert!", "fields": [ { @@ -59,7 +59,7 @@ def test_send_alert(self): def test_alert_loop(self): alert1 = DAQJobMessageAlert( - daq_job=MagicMock(), + daq_job_info=MagicMock(), alert_info=DAQAlertInfo( severity=DAQAlertSeverity.INFO, message="Test info message", @@ -67,7 +67,7 @@ def test_alert_loop(self): date=datetime(2023, 10, 1, 12, 0, 0), ) alert2 = DAQJobMessageAlert( - daq_job=MagicMock(), + daq_job_info=MagicMock(), alert_info=DAQAlertInfo( severity=DAQAlertSeverity.WARNING, message="Test warning message",