Skip to content

Commit

Permalink
refactor: no longer reference DAQJob in messages to make it serializable
Browse files Browse the repository at this point in the history
- we now also use `dataclasses_json` for serialization, instead of `pickle`
  • Loading branch information
furkan-bilgin committed Oct 19, 2024
1 parent 8d3eb7c commit 3668bb1
Show file tree
Hide file tree
Showing 14 changed files with 83 additions and 40 deletions.
2 changes: 1 addition & 1 deletion src/daq/alert/alert_slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def send_webhook(self, alert: DAQJobMessageAlert):
{
"fallback": alert.alert_info.message,
"color": ALERT_SEVERITY_TO_SLACK_COLOR[alert.alert_info.severity],
"author_name": type(alert.daq_job).__name__,
"author_name": alert.daq_job_info.daq_job_class_name,
"title": "Alert!",
"fields": [
{
Expand Down
4 changes: 2 additions & 2 deletions src/daq/alert/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dataclasses_json import DataClassJsonMixin

from daq.base import DAQJob
from daq.base import DAQJob, DAQJobInfo
from daq.models import DAQJobMessage


Expand All @@ -24,7 +24,7 @@ class DAQAlertInfo(DataClassJsonMixin):

@dataclass
class DAQJobMessageAlert(DAQJobMessage):
daq_job: DAQJob
daq_job_info: DAQJobInfo
date: datetime
alert_info: DAQAlertInfo

Expand Down
23 changes: 22 additions & 1 deletion src/daq/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
import threading
import uuid
from dataclasses import dataclass
from queue import Empty, Queue
from typing import Any

from daq.models import DAQJobMessage, DAQJobMessageStop, DAQJobStopError
from daq.models import DAQJobConfig, DAQJobMessage, DAQJobMessageStop, DAQJobStopError

daq_job_instance_id = 0
daq_job_instance_id_lock = threading.Lock()
Expand All @@ -17,6 +18,7 @@ class DAQJob:
message_in: Queue[DAQJobMessage]
message_out: Queue[DAQJobMessage]
instance_id: int
unique_id: str

_logger: logging.Logger

Expand All @@ -33,6 +35,7 @@ def __init__(self, config: Any):
self.message_out = Queue()

self._should_stop = False
self.unique_id = str(uuid.uuid4())

def consume(self):
# consume messages from the queue
Expand Down Expand Up @@ -61,6 +64,16 @@ def handle_message(self, message: "DAQJobMessage") -> bool:
def start(self):
raise NotImplementedError

def get_info(self) -> "DAQJobInfo":
return DAQJobInfo(
daq_job_type=self.config.daq_job_type
if isinstance(self.config, DAQJobConfig)
else self.config["daq_job_type"],
daq_job_class_name=type(self).__name__,
unique_id=self.unique_id,
instance_id=self.instance_id,
)

def __del__(self):
self._logger.info("DAQ job is being deleted")

Expand All @@ -69,3 +82,11 @@ def __del__(self):
class DAQJobThread:
daq_job: DAQJob
thread: threading.Thread


@dataclass
class DAQJobInfo:
daq_job_type: str
daq_job_class_name: str # has type(self).__name__
unique_id: str
instance_id: int
2 changes: 1 addition & 1 deletion src/daq/jobs/caen/n1081b.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _send_store_message(self, data: dict, section):
self.message_out.put(
DAQJobMessageStore(
store_config=self.config.store_config,
daq_job=self,
daq_job_info=self.get_info(),
prefix=section,
keys=keys,
data=[values],
Expand Down
4 changes: 2 additions & 2 deletions src/daq/jobs/handle_alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def handle_message(self, message: DAQJobMessageAlert) -> bool:
data_to_send = [
[
get_unix_timestamp_ms(message.date),
type(message.daq_job).__name__,
message.daq_job_info.daq_job_class_name,
message.alert_info.severity,
message.alert_info.message,
]
Expand All @@ -44,7 +44,7 @@ def handle_message(self, message: DAQJobMessageAlert) -> bool:
self.message_out.put(
DAQJobMessageStore(
store_config=self.config.store_config,
daq_job=self,
daq_job_info=self.get_info(),
keys=keys,
data=data_to_send,
)
Expand Down
2 changes: 1 addition & 1 deletion src/daq/jobs/handle_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def unpack_record(record: DAQJobStatsRecord):
self.message_out.put(
DAQJobMessageStore(
store_config=self.config.store_config,
daq_job=self,
daq_job_info=self.get_info(),
keys=keys,
data=data_to_send,
)
Expand Down
2 changes: 1 addition & 1 deletion src/daq/jobs/healthcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def handle_checks(self):
def send_alert(self, item: HealthcheckItem):
self.message_out.put(
DAQJobMessageAlert(
daq_job=self,
daq_job_info=self.get_info(),
date=datetime.now(),
alert_info=item.alert_info,
)
Expand Down
30 changes: 26 additions & 4 deletions src/daq/jobs/remote.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import pickle
import json
import threading
import time
from dataclasses import dataclass
Expand Down Expand Up @@ -27,6 +27,9 @@ class DAQJobRemote(DAQJob):
allowed_message_in_types = [DAQJobMessage] # accept all message types
config_type = DAQJobRemoteConfig
config: DAQJobRemoteConfig
_zmq_local: zmq.Socket
_zmq_remote: zmq.Socket
_message_class_cache: dict

def __init__(self, config: DAQJobRemoteConfig):
super().__init__(config)
Expand All @@ -35,21 +38,24 @@ def __init__(self, config: DAQJobRemoteConfig):
self._zmq_remote = self._zmq_context.socket(zmq.PULL)
self._zmq_local.connect(config.zmq_local_url)
self._zmq_remote.connect(config.zmq_remote_url)
self._message_class_cache = {}

self._receive_thread = threading.Thread(
target=self._start_receive_thread, daemon=True
)
self._message_class_cache = {
x.__name__: x for x in DAQJobMessage.__subclasses__()
}

def handle_message(self, message: DAQJobMessage) -> bool:
print(type(message))
self._zmq_local.send(pickle.dumps(message))
self._zmq_local.send(self._pack_message(message))
return True

def _start_receive_thread(self):
while True:
message = self._zmq_remote.recv()
# remote message_in -> message_out
self.message_out.put(pickle.loads(message))
self.message_out.put(self._unpack_message(message))

def start(self):
self._receive_thread.start()
Expand All @@ -60,3 +66,19 @@ def start(self):
# message_in -> remote message_out
self.consume()
time.sleep(0.1)

def _pack_message(self, message: DAQJobMessage) -> bytes:
message_type = type(message).__name__
return json.dumps([message_type, message.to_dict()]).encode("utf-8")

def _unpack_message(self, message: bytes) -> DAQJobMessage:
message_type, data = json.loads(message.decode("utf-8"))
if message_type in self._message_class_cache:
message_class = self._message_class_cache[message_type]
else:
message_class = globals()[message_type]
self._message_class_cache[message_type] = message_class

if not issubclass(message_class, DAQJobMessage):
raise Exception(f"Invalid message type: {message_type}")
return message_class.from_dict(data)
2 changes: 1 addition & 1 deletion src/daq/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class DAQJobConfig(DataClassJsonMixin):


@dataclass
class DAQJobMessage:
class DAQJobMessage(DataClassJsonMixin):
pass


Expand Down
15 changes: 3 additions & 12 deletions src/daq/store/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from dataclasses import dataclass
from typing import Any, Optional
from typing import Any

from dataclasses_json import DataClassJsonMixin

from daq.base import DAQJob
from daq.base import DAQJobInfo
from daq.models import DAQJobConfig, DAQJobMessage


Expand All @@ -19,20 +19,11 @@ class DAQJobStoreConfig(DataClassJsonMixin):
@dataclass
class DAQJobMessageStore(DAQJobMessage):
store_config: dict | DAQJobStoreConfig
daq_job: Optional[DAQJob]
daq_job_info: DAQJobInfo
keys: list[str]
data: list[list[Any]]
prefix: str | None = None

def __getstate__(self):
state = self.__dict__.copy()
del state["daq_job"]
return state

def __setstate__(self, state):
self.__dict__.update(state)
self.daq_job = None # type: ignore


@dataclass
class StorableDAQJobConfig(DAQJobConfig):
Expand Down
4 changes: 2 additions & 2 deletions src/tests/test_handle_alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_handle_message(self, mock_get_unix_timestamp_ms):

message = MagicMock(spec=DAQJobMessageAlert)
message.date = date
message.daq_job = MagicMock()
message.daq_job_info = MagicMock()
message.alert_info = MagicMock()
message.alert_info.severity = "high"
message.alert_info.message = "Test alert message"
Expand All @@ -42,7 +42,7 @@ def test_handle_message(self, mock_get_unix_timestamp_ms):
[
[
get_unix_timestamp_ms(date),
type(message.daq_job).__name__,
message.daq_job_info.daq_job_class_name,
"high",
"Test alert message",
]
Expand Down
2 changes: 1 addition & 1 deletion src/tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_send_messages_to_daq_jobs(self, mock_parse_store_config):
mock_thread.daq_job.allowed_message_in_types = [DAQJobMessageStore]
mock_thread.daq_job.message_in = Queue()
mock_message = DAQJobMessageStore(
store_config={}, keys=[], data=[], daq_job=MagicMock()
store_config={}, keys=[], data=[], daq_job_info=MagicMock()
)
daq_job_threads = [mock_thread]
daq_job_threads: list[DAQJobThread] = daq_job_threads
Expand Down
23 changes: 16 additions & 7 deletions src/tests/test_remote.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pickle
import threading
import time
import unittest
Expand Down Expand Up @@ -26,10 +25,15 @@ def setUp(self, MockZmqContext):

def test_handle_message(self):
message = DAQJobMessageStore(
store_config={}, data=[], keys=[], daq_job=DAQJobTest({})
store_config={},
data=[],
keys=[],
daq_job_info=DAQJobTest({"daq_job_type": "test"}).get_info(),
)
self.daq_job_remote.handle_message(message)
self.mock_sender.send.assert_called_once_with(pickle.dumps(message))
self.mock_sender.send.assert_called_once_with(
self.daq_job_remote._pack_message(message)
)

def test_start(self):
mock_receive_thread = MagicMock()
Expand All @@ -45,7 +49,12 @@ def stop_receive_thread():
self.daq_job_remote.start()

def test_receive_thread(self):
message = DAQJobMessageStore(store_config={}, data=[], keys=[], daq_job=None) # type: ignore
message = DAQJobMessageStore(
store_config={},
data=[],
keys=[],
daq_job_info=DAQJobTest({"daq_job_type": "test"}).get_info(),
)
self.daq_job_remote.message_out = MagicMock()

call_count = 0
Expand All @@ -54,12 +63,12 @@ def side_effect():
nonlocal call_count
call_count += 1
if call_count >= 2:
raise Exception("Stop receive thread")
return pickle.dumps(message)
raise RuntimeError("Stop receive thread")
return self.daq_job_remote._pack_message(message)

self.mock_receiver.recv.side_effect = side_effect

with self.assertRaises(Exception):
with self.assertRaises(RuntimeError):
self.daq_job_remote._start_receive_thread()
self.daq_job_remote.message_out.put.assert_called_once_with(message)
self.assertEqual(self.daq_job_remote.message_out.put.call_count, 1)
Expand Down
8 changes: 4 additions & 4 deletions src/tests/test_slack.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def test_init(self):

def test_send_alert(self):
alert = DAQJobMessageAlert(
daq_job=MagicMock(),
daq_job_info=MagicMock(),
alert_info=DAQAlertInfo(
severity=DAQAlertSeverity.ERROR,
message="Test error message",
Expand All @@ -34,7 +34,7 @@ def test_send_alert(self):
{
"fallback": "Test error message",
"color": "danger",
"author_name": type(alert.daq_job).__name__,
"author_name": alert.daq_job_info.daq_job_class_name,
"title": "Alert!",
"fields": [
{
Expand All @@ -59,15 +59,15 @@ def test_send_alert(self):

def test_alert_loop(self):
alert1 = DAQJobMessageAlert(
daq_job=MagicMock(),
daq_job_info=MagicMock(),
alert_info=DAQAlertInfo(
severity=DAQAlertSeverity.INFO,
message="Test info message",
),
date=datetime(2023, 10, 1, 12, 0, 0),
)
alert2 = DAQJobMessageAlert(
daq_job=MagicMock(),
daq_job_info=MagicMock(),
alert_info=DAQAlertInfo(
severity=DAQAlertSeverity.WARNING,
message="Test warning message",
Expand Down

0 comments on commit 3668bb1

Please sign in to comment.