diff --git a/mixnet/node.py b/mixnet/node.py index 054b85ba..defd5b36 100644 --- a/mixnet/node.py +++ b/mixnet/node.py @@ -27,8 +27,10 @@ # 32-byte that represents an IP address and a port of a mix node. NodeAddress: TypeAlias = bytes -InboundSocket: TypeAlias = "queue.Queue[SphinxPacket]" -OutboundSocket: TypeAlias = "queue.Queue[Tuple[NodeAddress, SphinxPacket | Payload]]" +PacketQueue: TypeAlias = "queue.Queue[Tuple[NodeAddress, SphinxPacket]]" +PacketPayloadQueue: TypeAlias = ( + "queue.Queue[Tuple[NodeAddress, SphinxPacket | Payload]]" +) @dataclass @@ -49,8 +51,8 @@ def sphinx_node(self) -> Node: def start( self, delay_rate_per_min: int, - inbound_socket: InboundSocket, - outbound_socket: OutboundSocket, + inbound_socket: PacketQueue, + outbound_socket: PacketPayloadQueue, ) -> MixNodeRunner: thread = MixNodeRunner( self.encryption_private_key, @@ -74,8 +76,8 @@ def __init__( self, encryption_private_key: X25519PrivateKey, delay_rate_per_min: int, # Poisson rate parameter: mu - inbound_socket: InboundSocket, - outbound_socket: OutboundSocket, + inbound_socket: PacketQueue, + outbound_socket: PacketPayloadQueue, ): super().__init__() self.encryption_private_key = encryption_private_key @@ -89,7 +91,7 @@ def run(self) -> None: # In the real implementation, consider implementing this in asynchronous if possible, # to approximate a M/M/inf queue while True: - packet = self.inbound_socket.get() + _, packet = self.inbound_socket.get() thread = MixNodePacketProcessor( packet, self.encryption_private_key, @@ -124,7 +126,7 @@ def __init__( packet: SphinxPacket, encryption_private_key: X25519PrivateKey, delay_rate_per_min: int, # Poisson rate parameter: mu - outbound_socket: OutboundSocket, + outbound_socket: PacketPayloadQueue, num_processing: AtomicInt, ): super().__init__() diff --git a/mixnet/test_node.py b/mixnet/test_node.py index 6cebb4da..8783e9d7 100644 --- a/mixnet/test_node.py +++ b/mixnet/test_node.py @@ -12,7 +12,7 @@ from mixnet.bls import generate_bls from mixnet.mixnet import Mixnet, MixnetTopology -from mixnet.node import InboundSocket, MixNode, OutboundSocket +from mixnet.node import MixNode, NodeAddress, PacketPayloadQueue, PacketQueue from mixnet.packet import PacketBuilder from mixnet.poisson import poisson_interval_sec, poisson_mean_interval_sec from mixnet.utils import random_bytes @@ -29,8 +29,8 @@ def test_mixnode_runner_emission_rate(self): the rate of outputs should be `lambda`. """ mixnet, topology = self.init() - inbound_socket: InboundSocket = queue.Queue() - outbound_socket: OutboundSocket = queue.Queue() + inbound_socket: PacketQueue = queue.Queue() + outbound_socket: PacketPayloadQueue = queue.Queue() packet, route = PacketBuilder.real(b"msg", mixnet, topology).next() @@ -43,7 +43,13 @@ def test_mixnode_runner_emission_rate(self): emission_rate_per_min = 120 # lambda (= 2msg/sec) sender = threading.Thread( target=self.send_packets, - args=(inbound_socket, packet, packet_count, emission_rate_per_min), + args=( + inbound_socket, + packet, + route[0].addr, + packet_count, + emission_rate_per_min, + ), ) sender.daemon = True sender.start() @@ -82,14 +88,15 @@ def test_mixnode_runner_emission_rate(self): @staticmethod def send_packets( - inbound_socket: InboundSocket, + inbound_socket: PacketQueue, packet: SphinxPacket, + node_addr: NodeAddress, cnt: int, rate_per_min: int, ): for _ in range(cnt): time.sleep(poisson_interval_sec(rate_per_min)) - inbound_socket.put(packet) + inbound_socket.put((node_addr, packet)) @staticmethod def init() -> Tuple[Mixnet, MixnetTopology]: