diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 19cf8503..77c6d41a 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,11 @@ name: Spec tests -on: [pull_request] +on: + pull_request: + branches: + - "*" + push: + branches: [master] jobs: build: @@ -12,7 +17,7 @@ jobs: uses: actions/setup-python@v5 with: # Semantic version range syntax or exact version of a Python version - python-version: '3.11' + python-version: '3.x' - name: Install dependencies run: pip install -r requirements.txt - name: Run tests diff --git a/mixnet/mixnet.py b/mixnet/mixnet.py index e71ea5a6..fbc1ece9 100644 --- a/mixnet/mixnet.py +++ b/mixnet/mixnet.py @@ -8,6 +8,7 @@ X25519PrivateKey, X25519PublicKey, ) +from pysphinx.node import Node from mixnet.bls import BlsPrivateKey, BlsPublicKey from mixnet.fisheryates import FisherYates @@ -60,6 +61,9 @@ def identity_public_key(self) -> BlsPublicKey: def encryption_public_key(self) -> X25519PublicKey: return self.encryption_private_key.public_key() + def sphinx_node(self) -> Node: + return Node(self.encryption_private_key, self.addr) + @dataclass class MixnetTopology: diff --git a/mixnet/packet.py b/mixnet/packet.py new file mode 100644 index 00000000..50b721b8 --- /dev/null +++ b/mixnet/packet.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import uuid +from dataclasses import dataclass +from enum import Enum +from itertools import batched +from typing import Dict, Iterator, List, Self, Tuple, TypeAlias + +from pysphinx.payload import Payload +from pysphinx.sphinx import SphinxPacket + +from mixnet.mixnet import Mixnet, MixnetTopology, MixNode + + +class MessageFlag(Enum): + MESSAGE_FLAG_REAL = b"\x00" + MESSAGE_FLAG_DROP_COVER = b"\x01" + + def bytes(self) -> bytes: + return bytes(self.value) + + +class PacketBuilder: + iter: Iterator[Tuple[SphinxPacket, List[MixNode]]] + + def __init__( + self, + flag: MessageFlag, + message: bytes, + mixnet: Mixnet, + topology: MixnetTopology, + ): + destination = mixnet.choose_mixnode() + + msg_with_flag = flag.bytes() + message + # NOTE: We don't encrypt msg_with_flag for destination. + # If encryption is needed, a shared secret must be appended in front of the message along with the MessageFlag. + fragment_set = FragmentSet(msg_with_flag) + + packets_and_routes = [] + for fragment in fragment_set.fragments: + route = topology.generate_route() + packet = SphinxPacket.build( + fragment.bytes(), + [mixnode.sphinx_node() for mixnode in route], + destination.sphinx_node(), + ) + packets_and_routes.append((packet, route)) + + self.iter = iter(packets_and_routes) + + @classmethod + def real(cls, message: bytes, mixnet: Mixnet, topology: MixnetTopology) -> Self: + return cls(MessageFlag.MESSAGE_FLAG_REAL, message, mixnet, topology) + + @classmethod + def drop_cover( + cls, message: bytes, mixnet: Mixnet, topology: MixnetTopology + ) -> Self: + return cls(MessageFlag.MESSAGE_FLAG_DROP_COVER, message, mixnet, topology) + + def next(self) -> Tuple[SphinxPacket, List[MixNode]]: + return next(self.iter) + + @staticmethod + def parse_msg_and_flag(data: bytes) -> Tuple[MessageFlag, bytes]: + """Remove a MessageFlag from data""" + if len(data) < 1: + raise ValueError("data is too short") + + return (MessageFlag(data[0:1]), data[1:]) + + +# Unlikely, Nym uses i32 for FragmentSetId, which may cause more collisions. +# We will use UUID until figuring out why Nym uses i32. +FragmentSetId: TypeAlias = bytes # 128bit UUID v4 +FragmentId: TypeAlias = int # unsigned 8bit int in big endian + +FRAGMENT_SET_ID_LENGTH: int = 16 +FRAGMENT_ID_LENGTH: int = 1 + + +@dataclass +class FragmentHeader: + """ + Contain all information for reconstructing a message that was fragmented into the same FragmentSet. + """ + + set_id: FragmentSetId + total_fragments: FragmentId + fragment_id: FragmentId + + SIZE: int = FRAGMENT_SET_ID_LENGTH + FRAGMENT_ID_LENGTH * 2 + + @staticmethod + def max_total_fragments() -> int: + return 256 # because total_fragment is u8 + + def bytes(self) -> bytes: + return ( + self.set_id + + self.total_fragments.to_bytes(1) + + self.fragment_id.to_bytes(1) + ) + + @classmethod + def from_bytes(cls, data: bytes) -> Self: + if len(data) != cls.SIZE: + raise ValueError("Invalid data length", len(data)) + + return cls(data[:16], int.from_bytes(data[16:17]), int.from_bytes(data[17:18])) + + +@dataclass +class FragmentSet: + """ + Represent a set of Fragments that can be reconstructed to a single original message. + + Note that the maximum number of fragments in a FragmentSet is limited for now. + """ + + fragments: List[Fragment] + + MAX_FRAGMENTS: int = FragmentHeader.max_total_fragments() + + def __init__(self, message: bytes): + """ + Build a FragmentSet by chunking a message into Fragments. + """ + chunked_messages = chunks(message, Fragment.MAX_PAYLOAD_SIZE) + # For now, we don't support more than max_fragments() fragments. + # If needed, we can devise the FragmentSet chaining to support larger messages, like Nym. + if len(chunked_messages) > self.MAX_FRAGMENTS: + raise ValueError(f"Too long message: {len(chunked_messages)} chunks") + + set_id = uuid.uuid4().bytes + self.fragments = [ + Fragment(FragmentHeader(set_id, len(chunked_messages), i), chunk) + for i, chunk in enumerate(chunked_messages) + ] + + +@dataclass +class Fragment: + """Represent a piece of data that can be transformed to a single SphinxPacket""" + + header: FragmentHeader + body: bytes + + MAX_PAYLOAD_SIZE: int = Payload.max_plain_payload_size() - FragmentHeader.SIZE + + def bytes(self) -> bytes: + return self.header.bytes() + self.body + + @classmethod + def from_bytes(cls, data: bytes) -> Self: + header = FragmentHeader.from_bytes(data[: FragmentHeader.SIZE]) + body = data[FragmentHeader.SIZE :] + return cls(header, body) + + +@dataclass +class MessageReconstructor: + fragmentSets: Dict[FragmentSetId, FragmentSetReconstructor] + + def __init__(self): + self.fragmentSets = {} + + def add(self, fragment: Fragment) -> bytes | None: + if fragment.header.set_id not in self.fragmentSets: + self.fragmentSets[fragment.header.set_id] = FragmentSetReconstructor( + fragment.header.total_fragments + ) + + msg = self.fragmentSets[fragment.header.set_id].add(fragment) + if msg is not None: + del self.fragmentSets[fragment.header.set_id] + return msg + + +@dataclass +class FragmentSetReconstructor: + total_fragments: FragmentId + fragments: Dict[FragmentId, Fragment] + + def __init__(self, total_fragments: FragmentId): + self.total_fragments = total_fragments + self.fragments = {} + + def add(self, fragment: Fragment) -> bytes | None: + self.fragments[fragment.header.fragment_id] = fragment + if len(self.fragments) == self.total_fragments: + return self.build_message() + else: + return None + + def build_message(self) -> bytes: + message = b"" + for i in range(self.total_fragments): + message += self.fragments[FragmentId(i)].body + return message + + +def chunks(data: bytes, size: int) -> List[bytes]: + return list(map(bytes, batched(data, size))) diff --git a/mixnet/sphinx/__init__.py b/mixnet/sphinx/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/mixnet/sphinx/const.py b/mixnet/sphinx/const.py deleted file mode 100644 index 15d143f2..00000000 --- a/mixnet/sphinx/const.py +++ /dev/null @@ -1,22 +0,0 @@ -# k in the Sphinx paper -SECURITY_PARAMETER = 16 -# r in the Sphinx paper -# In this specification, the max number of mix nodes in a route is limited to this value. -MAX_PATH_LENGTH = 5 -# The length of node address which contains an IP address and a port. -NODE_ADDRESS_LENGTH = 2 * SECURITY_PARAMETER -# The length of flag that represents the type of routing information (forward-hop or final-hop) -FLAG_LENGTH = 1 - -VERSION_LENGTH = 3 -VERSION = b"\x00\x00\x00" - -# In our architecture, SURB is not used. -# But, for the consistency with Nym's Sphinx implementation, keep this field in the Sphinx header. -SURB_IDENTIFIER_LENGTH = SECURITY_PARAMETER -SURB_IDENTIFIER = b"\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01\x01" - -# In our architecture, delays are determined by each mix node (not by a packet sender). -# But, for the consistency with Nym's Sphinx implementation, keep the delay field in the Sphinx header. -DELAY_LENGTH = 8 -DELAY = b"\x00\x00\x00\x00\x00\x00\x00\x00" diff --git a/mixnet/sphinx/crypto.py b/mixnet/sphinx/crypto.py deleted file mode 100644 index 1194b1cf..00000000 --- a/mixnet/sphinx/crypto.py +++ /dev/null @@ -1,23 +0,0 @@ -from cryptography.hazmat.primitives import hashes, hmac -from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes - - -def aes128ctr(data: bytes, key: bytes, nonce: bytes) -> bytes: - encryptor = Cipher(algorithms.AES128(key), modes.CTR(nonce)).encryptor() - return encryptor.update(data) + encryptor.finalize() - - -def compute_hmac_sha256(data: bytes, key: bytes) -> bytes: - h = hmac.HMAC(key, hashes.SHA256()) - h.update(data) - return h.finalize() - - -def lioness_encrypt(data: bytes, key: bytes) -> bytes: - # TODO: Couldn't find a lioness package that works with the latest Python. Implement it. - return data - - -def lioness_decrypt(data: bytes, key: bytes) -> bytes: - # TODO: Couldn't find a lioness package that works with the latest Python. Implement it. - return data diff --git a/mixnet/sphinx/header/__init__.py b/mixnet/sphinx/header/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/mixnet/sphinx/header/header.py b/mixnet/sphinx/header/header.py deleted file mode 100644 index b4a86610..00000000 --- a/mixnet/sphinx/header/header.py +++ /dev/null @@ -1,116 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import List, Self, Tuple - -from cryptography.hazmat.primitives.asymmetric.x25519 import ( - X25519PrivateKey, - X25519PublicKey, -) - -from mixnet.mixnet import MixNode, NodeAddress -from mixnet.sphinx.header.keys import KeyMaterial, RoutingKeys -from mixnet.sphinx.header.routing import EncapsulatedRoutingInformation, Filler - - -@dataclass -class SphinxHeader: - """ - A Sphinx header contains an encapsulated routing information - and a shared secret that can be used to unwrap one layer of the encapsulated routing information. - """ - - shared_pubkey: X25519PublicKey - routing_info: EncapsulatedRoutingInformation - - @classmethod - def build( - cls, - initial_ephemeral_privkey: X25519PrivateKey, - route: List[MixNode], - destination: MixNode, - ) -> Tuple[Self, List[bytes]]: - """ - Construct a SphinxHeader by encapsulating all routing information - and keys that can be used to encrypt a payload. - """ - key_material = KeyMaterial.derive(initial_ephemeral_privkey, route) - filler = Filler.build(key_material.routing_keys) - routing_info = EncapsulatedRoutingInformation.build( - route, destination, key_material.routing_keys, filler - ) - payload_keys = [ - routing_key.payload_key for routing_key in key_material.routing_keys - ] - return (cls(key_material.initial_ephemeral_pubkey, routing_info), payload_keys) - - def process( - self, private_key: X25519PrivateKey - ) -> ProcessedForwardHopHeader | ProcessedFinalHopHeader: - """ - Unwrap one layer of encapsulated routing information using private_key. - - If there are other encapsulated layers left after being unwrapped, this method returns ProcessedForwardHopHeader. - If not, this returns ProcessedFinalHopHeader. - """ - routing_keys = self.compute_routing_keys(self.shared_pubkey, private_key) - - assert self.routing_info.integrity_mac.verify( - self.routing_info.encrypted_routing_info.value, - routing_keys.header_integrity_hmac_key, - ) - - routing_info_and_addr = self.routing_info.encrypted_routing_info.unwrap( - routing_keys.stream_cipher_key - ) - encapsulated_routing_info = routing_info_and_addr[0] - next_node_address = routing_info_and_addr[1] - - if encapsulated_routing_info is not None: - new_shared_pubkey = KeyMaterial.blind_shared_pubkey( - self.shared_pubkey, routing_keys.blinding_factor - ) - return ProcessedForwardHopHeader( - SphinxHeader(new_shared_pubkey, encapsulated_routing_info), - next_node_address, - routing_keys.payload_key, - ) - else: - return ProcessedFinalHopHeader(next_node_address, routing_keys.payload_key) - - @staticmethod - def compute_routing_keys( - shared_pubkey: X25519PublicKey, private_key: X25519PrivateKey - ) -> RoutingKeys: - """ - Derive RoutingKeys from a shared key created by Diffie-Hellman key exchange between shared_pubkey and private_key. - """ - dh_shared_key = private_key.exchange(shared_pubkey) - return RoutingKeys.derive(dh_shared_key) - - -@dataclass -class ProcessedForwardHopHeader: - """ - A forward-hop header unwrapped from SphinxHeader - - This class contains another SphinxHeader to be forwarded to the next mix node, - and a payload key for the current mix node to decrypt one layer of payload encryption. - """ - - next_header: SphinxHeader - next_node_address: NodeAddress - payload_key: bytes - - -@dataclass -class ProcessedFinalHopHeader: - """ - A final-hop header unwrapped from SphinxHeader - - This class contains a payload key for the current mix node to decrypt the last layer of payload encryption, - and a destination address to which the decrypted payload will be delivered. - """ - - destination_address: NodeAddress - payload_key: bytes diff --git a/mixnet/sphinx/header/keys.py b/mixnet/sphinx/header/keys.py deleted file mode 100644 index f32065b4..00000000 --- a/mixnet/sphinx/header/keys.py +++ /dev/null @@ -1,99 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import List, Self - -from cryptography.hazmat.primitives import hashes -from cryptography.hazmat.primitives.asymmetric.x25519 import ( - X25519PrivateKey, - X25519PublicKey, -) -from cryptography.hazmat.primitives.kdf.hkdf import HKDF - -from mixnet.mixnet import MixNode - - -@dataclass -class KeyMaterial: - """ - Contain a list of RoutingKeys for all mix nodes in the route, - and a shared secret that will be contained in a SphinxHeader for the first mix node in the route. - """ - - initial_ephemeral_pubkey: X25519PublicKey - routing_keys: List[RoutingKeys] - - @classmethod - def derive( - cls, initial_ephemeral_privkey: X25519PrivateKey, route: List[MixNode] - ) -> Self: - """ - Derive KeyMaterial for route using initial_ephemeral_privkey provided. - """ - initial_ephemeral_pubkey = initial_ephemeral_privkey.public_key() - - routing_keys = [] - accumulated_privkey = initial_ephemeral_privkey - for node in route: - dh_shared_key = accumulated_privkey.exchange(node.encryption_public_key()) - node_routing_keys = RoutingKeys.derive(dh_shared_key) - - # TODO: find a proper library for Ristretto operations - # https://github.com/nymtech/sphinx/blob/ca107d94360cdf8bbfbdb12fe5320ed74f80e40c/src/header/keys.rs#L128-L128 - # blinding_factor_scalar = Scalar.from_bytes_mod_order(node_routing_keys.blinding_factor) - # accumulated_privkey = product(accumulated_privkey, blinding_factor_scalar) - - routing_keys.append(node_routing_keys) - - return cls(initial_ephemeral_pubkey, routing_keys) - - @staticmethod - def blind_shared_pubkey( - shared_pubkey: X25519PublicKey, blinding_factor: bytes - ) -> X25519PublicKey: - """ - Blind shared_pubkey to derive a next public key. - """ - # TODO: find a proper library for Ristretto operations - # https://github.com/nymtech/sphinx/blob/ca107d94360cdf8bbfbdb12fe5320ed74f80e40c/src/header/mod.rs#L236-L236 - # For now, we're skipping blinding because we don't accumulate a private key using blinding factor - # when deriving RoutingKeys. - return shared_pubkey - - -# Adopted from https://github.com/nymtech/sphinx/blob/ca107d94360cdf8bbfbdb12fe5320ed74f80e40c/src/constants.rs#L26-L26 -HKDF_INPUT_SEED = b"Dwste mou enan moxlo arketa makru kai ena upomoxlio gia na ton topothetisw kai tha kinisw thn gh." - - -@dataclass -class RoutingKeys: - """ - Contain all keys for a mix node in the route. - """ - - # For Sphinx header encryption (AES-128) - stream_cipher_key: bytes - # For HMAC integrity authentication - header_integrity_hmac_key: bytes - # For payload encryption (ChaCha20) - payload_key: bytes - # For deriving a shared key for a next mix node, combining with the previous ephemeral private key - blinding_factor: bytes - - @classmethod - def derive(cls, dh_shared_key: bytes) -> Self: - """ - Derive all keys from dh_shared_key using HKDF-SHA256. - """ - derived_key = HKDF( - algorithm=hashes.SHA256(), length=256, salt=None, info=HKDF_INPUT_SEED - ).derive(dh_shared_key) - assert len(derived_key) == 256 - - stream_cipher_key = derived_key[0:16] # 16bytes == 128bits - header_integrity_hmac_key = derived_key[16:32] # 16bytes - payload_key = derived_key[32:224] # 192bytes - blinding_factor = derived_key[224:] # 32bytes - return cls( - stream_cipher_key, header_integrity_hmac_key, payload_key, blinding_factor - ) diff --git a/mixnet/sphinx/header/mac.py b/mixnet/sphinx/header/mac.py deleted file mode 100644 index 859fd83f..00000000 --- a/mixnet/sphinx/header/mac.py +++ /dev/null @@ -1,36 +0,0 @@ -from dataclasses import dataclass -from typing import Self - -from mixnet.sphinx.const import SECURITY_PARAMETER -from mixnet.sphinx.crypto import compute_hmac_sha256 - - -@dataclass -class IntegrityHmac: - """ - This class represents a HMAC-SHA256 that can be used for integrity authentication. - """ - - value: bytes - - def __init__(self, value: bytes): - """Override the default constructor to assert the size of value""" - assert len(value) == IntegrityHmac.size() - self.value = value - - @staticmethod - def size() -> int: - return SECURITY_PARAMETER - - @classmethod - def compute(cls, data: bytes, key: bytes) -> Self: - """ - Build IntegrityHmac using data and key. - """ - return cls(compute_hmac_sha256(data, key)[: cls.size()]) - - def verify(self, data: bytes, key: bytes) -> bool: - """ - Verify a HMAC computed from data and key matches with the expected HMAC. - """ - return self.value == compute_hmac_sha256(data, key)[: self.size()] diff --git a/mixnet/sphinx/header/routing.py b/mixnet/sphinx/header/routing.py deleted file mode 100644 index 3fdb6b4a..00000000 --- a/mixnet/sphinx/header/routing.py +++ /dev/null @@ -1,418 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import List, Optional, Self, Tuple, TypeAlias - -from mixnet.mixnet import MixNode, NodeAddress -from mixnet.sphinx.const import ( - DELAY, - DELAY_LENGTH, - FLAG_LENGTH, - MAX_PATH_LENGTH, - NODE_ADDRESS_LENGTH, - SURB_IDENTIFIER, - SURB_IDENTIFIER_LENGTH, - VERSION, - VERSION_LENGTH, -) -from mixnet.sphinx.crypto import aes128ctr -from mixnet.sphinx.header.keys import RoutingKeys -from mixnet.sphinx.header.mac import IntegrityHmac -from mixnet.sphinx.utils import zero_bytes -from mixnet.utils import random_bytes - - -@dataclass -class EncapsulatedRoutingInformation: - """ - An encapsulated routing information that can be unwrapped by a certain mix node in the route. - """ - - # An encrypted routing information that can be decrypted by a certain mix node in the route. - encrypted_routing_info: EncryptedRoutingInformation - # For integrity authentication - integrity_mac: IntegrityHmac - - @classmethod - def build( - cls, - route: List[MixNode], - destination: MixNode, - routing_keys: List[RoutingKeys], - filler: Filler, - ) -> Self: - """ - Build EncapsulatedRoutingInformation by building sub-EncapsulatedRoutingInformation recursively. - """ - assert len(route) > 0 - assert len(route) == len(routing_keys) - - final_keys = routing_keys[-1] - encapsulated_destination_routing_info = cls.for_final_hop( - destination, final_keys, filler, len(route) - ) - - return cls.for_forward_hops( - encapsulated_destination_routing_info, route, routing_keys - ) - - @classmethod - def for_final_hop( - cls, - destination: MixNode, - routing_keys: RoutingKeys, - filler: Filler, - route_len: int, - ) -> Self: - """ - Build EncapsulatedRoutingInformation for the final mix node in the route that will forward payload to the destination. - - filler is used for the undistinguishability between forward-hop headers and a final-hop header. - For more details, please see Filler. - """ - encrypted_routing_info = ( - FinalRoutingInformation.build(destination.addr) - .add_padding(route_len) - .encrypt(routing_keys.stream_cipher_key) - .combine_with_filler(filler) - ) - integrity_mac = IntegrityHmac.compute( - encrypted_routing_info.value, routing_keys.header_integrity_hmac_key - ) - return cls(encrypted_routing_info, integrity_mac) - - @classmethod - def for_forward_hops( - cls, - encapsulated_destination_routing_info: Self, - route: List[MixNode], - routing_keys: List[RoutingKeys], - ) -> Self: - """ - Build EncapsulatedRoutingInformation for all mix nodes except the final mix node in the route. - """ - next_encapsulated_routing_info = encapsulated_destination_routing_info - - # skip the first mixnodes because the sender will forward the packet to the first mixnode directly - for i in reversed(range(1, len(route))): - node = route[i] - routing_key = routing_keys[i - 1] - - routing_info = RoutingInformation.build( - node.addr, next_encapsulated_routing_info - ) - encrypted_routing_info = routing_info.encrypt(routing_key.stream_cipher_key) - integrity_mac = IntegrityHmac.compute( - encrypted_routing_info.value, routing_key.header_integrity_hmac_key - ) - next_encapsulated_routing_info = cls(encrypted_routing_info, integrity_mac) - - return next_encapsulated_routing_info - - def bytes(self) -> bytes: - return self.integrity_mac.value + self.encrypted_routing_info.value - - @classmethod - def from_bytes(cls, data: bytes) -> Self: - return cls( - EncryptedRoutingInformation(data[IntegrityHmac.size() :]), - IntegrityHmac(data[: IntegrityHmac.size()]), - ) - - -RoutingFlag: TypeAlias = bytes # 1byte - -ROUTING_FLAG_FORWARD_HOP: RoutingFlag = b"\x01" -ROUTING_FLAG_FINAL_HOP: RoutingFlag = b"\x02" - - -@dataclass -class EncryptedRoutingInformation: - "An encrypted routing information using a private key of a certain mix node." - - value: bytes - - def __init__(self, value: bytes): - """Override the default constructor to assert the size of value.""" - assert len(value) == EncryptedRoutingInformation.size() - self.value = value - - @staticmethod - def size() -> int: - """ - To make the size of Sphinx header constant, the size of this class is constant. - """ - return RoutingInformation.meta_size() * MAX_PATH_LENGTH - - def truncate(self) -> TruncatedRoutingInformation: - """ - Truncate the encrypted routing information as much as the size of a single filler. - - This method can be used when this routing information is about to be encapsulated once more. - For more details, please see Filler. - """ - return TruncatedRoutingInformation( - self.value[: len(self.value) - Filler.one_step_size()] - ) - - def unwrap( - self, stream_cipher_key: bytes - ) -> Tuple[Optional[EncapsulatedRoutingInformation], NodeAddress]: - """ - Decrypt the routing information and return a next node address and a next EncapsulatedRoutingInformation if exists. - """ - # Since this EncryptedRoutingInformation has been truncated when being encapsulated, - # add zero padding as much as the truncated bytes, before decrypting it. - padding = zero_bytes(Filler.one_step_size()) - decrypted = decrypt(self.value + padding, stream_cipher_key) - - flag = RoutingFlag(decrypted[0:FLAG_LENGTH]) - if flag == ROUTING_FLAG_FORWARD_HOP: - i = FLAG_LENGTH + VERSION_LENGTH - node_address = decrypted[i : i + NODE_ADDRESS_LENGTH] - i += NODE_ADDRESS_LENGTH + DELAY_LENGTH - next_hop_integrity_mac = IntegrityHmac( - decrypted[i : i + IntegrityHmac.size()] - ) - i += IntegrityHmac.size() - encrypted_next_routing_info = EncryptedRoutingInformation(decrypted[i:]) - return ( - EncapsulatedRoutingInformation( - encrypted_next_routing_info, next_hop_integrity_mac - ), - node_address, - ) - elif flag == ROUTING_FLAG_FINAL_HOP: - i = FLAG_LENGTH + VERSION_LENGTH - destination_address = decrypted[i : i + NODE_ADDRESS_LENGTH] - i += NODE_ADDRESS_LENGTH - _ = decrypted[i : i + SURB_IDENTIFIER_LENGTH] - return (None, destination_address) - else: - assert False # Unknown flag - - -@dataclass -class RoutingInformation: - """ - Represent a forward-hop routing information not encrypted and not encapsulated - """ - - flag: RoutingFlag - node_address: NodeAddress - header_integrity_mac: bytes - next_routing_info: TruncatedRoutingInformation - - @staticmethod - def meta_size() -> int: - # 60 bytes in total - return ( - FLAG_LENGTH - + VERSION_LENGTH - + NODE_ADDRESS_LENGTH - + DELAY_LENGTH - + IntegrityHmac.size() - ) - - @classmethod - def build( - cls, - node: NodeAddress, - next_encapsulated_routing_info: EncapsulatedRoutingInformation, - ) -> Self: - return cls( - ROUTING_FLAG_FORWARD_HOP, - node, - next_encapsulated_routing_info.integrity_mac.value, - next_encapsulated_routing_info.encrypted_routing_info.truncate(), - ) - - def encrypt(self, key: bytes) -> EncryptedRoutingInformation: - body = ( - self.flag - + VERSION - + self.node_address - + DELAY - + self.header_integrity_mac - + self.next_routing_info.value - ) - return EncryptedRoutingInformation(encrypt(body, key)) - - -@dataclass -class TruncatedRoutingInformation: - """ - Represent an encrypted routing information truncated as much as a single filler. - """ - - value: bytes - - def __init__(self, value: bytes): - """Override the default constructor to assert the size of value.""" - assert len(value) == TruncatedRoutingInformation.size() - self.value = value - - @staticmethod - def size() -> int: - return EncryptedRoutingInformation.size() - Filler.one_step_size() - - -@dataclass -class FinalRoutingInformation: - """ - Represent a forward-hop routing information not encrypted and not encapsulated - """ - - flag: RoutingFlag - destination_address: NodeAddress - - @classmethod - def build(cls, destination: NodeAddress) -> Self: - return cls(ROUTING_FLAG_FINAL_HOP, destination) - - @staticmethod - def size() -> int: - # 52 bytes in total - return ( - FLAG_LENGTH + VERSION_LENGTH + NODE_ADDRESS_LENGTH + SURB_IDENTIFIER_LENGTH - ) - - def add_padding(self, route_len: int) -> PaddedFinalRoutingInformation: - """ - To make the final encrypted routing information (that will contain this routing information) - have the same size as upper-layer encrypted routing information, - add random-byte padding to the tail of FinalRoutingInformation. - """ - padding = random_bytes(PaddedFinalRoutingInformation.padding_size(route_len)) - return PaddedFinalRoutingInformation( - self.flag + VERSION + self.destination_address + SURB_IDENTIFIER + padding - ) - - -@dataclass -class PaddedFinalRoutingInformation: - """ - A random-byte padded FinalRoutingInformation - """ - - value: bytes - - @staticmethod - def padding_size(route_len: int) -> int: - """ - The point of this padding is making the size of EncryptedRoutingInformation - (that will contain this final routing information) - the same as other EncryptedRoutingInformations that contain RoutingInformation. - """ - return ( - EncryptedRoutingInformation.size() - - Filler.size(route_len) - - FinalRoutingInformation.size() - ) - - def encrypt(self, key: bytes) -> EncryptedPaddedFinalRoutingInformation: - return EncryptedPaddedFinalRoutingInformation(encrypt(self.value, key)) - - -@dataclass -class EncryptedPaddedFinalRoutingInformation: - value: bytes - - def combine_with_filler(self, filler: Filler) -> EncryptedRoutingInformation: - """ - Because the size of this class is smaller than EncryptedRoutingInformation, - add fillers to create EncryptedRoutingInformation from this value. - """ - return EncryptedRoutingInformation(self.value + filler.value) - - -@dataclass -class Filler: - """ - This class represents a set of multiple fillers, 1 less than the length of mix route. - A single filler has the same size as a single RoutingInformation. - - A single filler is used to make the routing information that has been unwrapped once - have the same size as the routing information before unwrapped. - - For the same purpose, a set of multiple fillers (this class) is meant to be - appended to a EncryptedPaddedFinalRoutingInformation. - """ - - value: bytes - - def __init__(self, value: bytes): - """Override the default constructor to assert the size of value.""" - assert len(value) % self.one_step_size() == 0 - self.value = value - - @staticmethod - def size(route_len: int) -> int: - # Note that this is not one_step_size * route_len - # because the information of the first mix node in the route doesn't need to be - # encapsulated in a Sphinx packet. - # A packet sender always know the address of the first mix node. - return Filler.one_step_size() * (route_len - 1) - - @staticmethod - def one_step_size() -> int: - """A size of a single filler, which is the same as the size of RoutingInformation""" - return RoutingInformation.meta_size() - - @classmethod - def build(cls, routing_keys: List[RoutingKeys]) -> Self: - assert len(routing_keys) <= MAX_PATH_LENGTH - - filler = b"" - # except the last key - for routing_key in routing_keys[: len(routing_keys) - 1]: - filler += zero_bytes(Filler.one_step_size()) - - # This process is the same as encrypting RoutingInformation to create EncryptedRoutingInformation, - # so that a single filler can be easily reproduced and appended to the EncapsulatedRoutingInformation - # when it is unwrapped. - # - # The implementation of the regular encryption can be found at the end of this file. - rand = pseudo_random(routing_key.stream_cipher_key) - assert len(filler) <= len(rand) - # XOR with the last len(filler) bytes of rand - filler = xor(filler, rand[len(rand) - len(filler) :]) - - assert len(filler) == Filler.size(len(routing_keys)) - return cls(filler) - - -AES128CTR_NONCE = zero_bytes(16) - - -def pseudo_random(key: bytes) -> bytes: - """ - Return a pseudo-random bytes with length EncryptedRoutingInformation + a single filler - generated using AES128-CTR with a constant nonce. - """ - return aes128ctr( - zero_bytes(EncryptedRoutingInformation.size() + Filler.one_step_size()), - key, - AES128CTR_NONCE, - ) - - -def encrypt(data: bytes, key: bytes) -> bytes: - """ - data is encrypted by XOR with a pseudo-random bytes generated using key, - so that it can be decrypted later by XOR with the same pseudo-random bytes from the same key. - """ - rand = pseudo_random(key) - assert len(data) <= len(rand) - return xor(data, rand[: len(data)]) # XOR with truncating rand - - -def decrypt(data: bytes, key: bytes) -> bytes: - # Decryption is the same as encryption - # because a common pseudo random value is used for XOR - return encrypt(data, key) - - -def xor(ba1: bytes, ba2: bytes) -> bytes: - """Bitwise XOR operation""" - return bytes([_a ^ _b for _a, _b in zip(ba1, ba2)]) diff --git a/mixnet/sphinx/payload.py b/mixnet/sphinx/payload.py deleted file mode 100644 index 0d1a7ecf..00000000 --- a/mixnet/sphinx/payload.py +++ /dev/null @@ -1,70 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import List, Self - -from mixnet.sphinx.const import SECURITY_PARAMETER -from mixnet.sphinx.crypto import lioness_decrypt, lioness_encrypt -from mixnet.sphinx.utils import zero_bytes - -# For the packet indistinguishability, the size of payload (padded) is a constant. -DEFAULT_PAYLOAD_SIZE = 1024 -PAYLOAD_TRAILING_PADDING_INDICATOR = b"\x01" - - -@dataclass -class Payload: - data: bytes - - @classmethod - def build(cls, plain_payload: bytes, payload_keys: List[bytes]) -> Self: - payload = cls.add_padding(plain_payload) - for payload_key in reversed(payload_keys): - payload = lioness_encrypt(payload, payload_key) - return cls(payload) - - @staticmethod - def add_padding(plain_payload: bytes) -> bytes: - """ - Add leading and trailing padding to a plain payload - - This padding mechanism is the same as Nym's Sphinx implementation. - """ - assert len(plain_payload) <= Payload.max_plain_payload_size() - - padded = ( - zero_bytes(SECURITY_PARAMETER) - + plain_payload - + PAYLOAD_TRAILING_PADDING_INDICATOR - + zero_bytes( - DEFAULT_PAYLOAD_SIZE - - SECURITY_PARAMETER - - len(plain_payload) - - len(PAYLOAD_TRAILING_PADDING_INDICATOR) - ) - ) - assert len(padded) == DEFAULT_PAYLOAD_SIZE - return padded - - @staticmethod - def max_plain_payload_size() -> int: - return ( - DEFAULT_PAYLOAD_SIZE - - SECURITY_PARAMETER - - len(PAYLOAD_TRAILING_PADDING_INDICATOR) - ) - - def unwrap(self, payload_key: bytes) -> Payload: - """Unwrap a single layer of encryption""" - return Payload(lioness_decrypt(self.data, payload_key)) - - def recover_plain_playload(self) -> bytes: - """ - After Payload has been unwrapped required number of times, - this method must be called to parse the unwrapped payload into - the original payload by removing leading/trailing paddings. - """ - assert self.data.startswith(zero_bytes(SECURITY_PARAMETER)) - indicator_idx = self.data.rfind(PAYLOAD_TRAILING_PADDING_INDICATOR) - assert indicator_idx != -1 - return self.data[SECURITY_PARAMETER:indicator_idx] diff --git a/mixnet/sphinx/sphinx.py b/mixnet/sphinx/sphinx.py deleted file mode 100644 index 5851add6..00000000 --- a/mixnet/sphinx/sphinx.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import List, Self - -from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey - -from mixnet.mixnet import MixNode, NodeAddress -from mixnet.sphinx.header.header import ( - ProcessedFinalHopHeader, - ProcessedForwardHopHeader, - SphinxHeader, -) -from mixnet.sphinx.payload import Payload - - -@dataclass -class SphinxPacket: - """ - A Sphinx packet that will be sent directly through network sockets - """ - - header: SphinxHeader - payload: Payload - - @classmethod - def build( - cls, - message: bytes, - route: List[MixNode], - destination: MixNode, - ) -> Self: - """ - This method is a constructor for packet senders. - - A packet sender has to determine a mix route and a mix destination. - - A message must fit into the capacity of a single Sphinx packet. - For details, please see Payload. - """ - header_and_payload_keys = SphinxHeader.build( - X25519PrivateKey.generate(), route, destination - ) - header = header_and_payload_keys[0] - payload_keys = header_and_payload_keys[1] - - payload = Payload.build(message, payload_keys) - - return cls(header, payload) - - def process( - self, private_key: X25519PrivateKey - ) -> ProcessedForwardHopPacket | ProcessedFinalHopPacket: - """ - Unwrap one layer of encapsulated routing information in the Sphinx packet using private_key. - - If there are other encapsulated layers left after being unwrapped, this method returns ProcessedForwardHopPacket. - If not, this returns ProcessedFinalHopPacket. - """ - processed_header = self.header.process(private_key) - if isinstance(processed_header, ProcessedForwardHopHeader): - return ProcessedForwardHopPacket( - SphinxPacket( - processed_header.next_header, - self.payload.unwrap(processed_header.payload_key), - ), - processed_header.next_node_address, - ) - elif isinstance(processed_header, ProcessedFinalHopHeader): - return ProcessedFinalHopPacket( - processed_header.destination_address, - self.payload.unwrap(processed_header.payload_key), - ) - else: - assert False # unknown type of processed header - - -@dataclass -class ProcessedForwardHopPacket: - next_packet: SphinxPacket - next_node_address: NodeAddress - - -@dataclass -class ProcessedFinalHopPacket: - destination_node_address: NodeAddress - payload: Payload diff --git a/mixnet/sphinx/test_sphinx.py b/mixnet/sphinx/test_sphinx.py deleted file mode 100644 index 290d3d46..00000000 --- a/mixnet/sphinx/test_sphinx.py +++ /dev/null @@ -1,54 +0,0 @@ -from unittest import TestCase - -from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey - -from mixnet.bls import generate_bls -from mixnet.mixnet import Mixnet, MixNode -from mixnet.utils import random_bytes -from mixnet.sphinx.sphinx import ( - ProcessedFinalHopPacket, - ProcessedForwardHopPacket, - SphinxPacket, -) - - -class TestSphinx(TestCase): - def test_sphinx(self): - mixnet = Mixnet( - [ - MixNode(generate_bls(), X25519PrivateKey.generate(), random_bytes(32)) - for _ in range(12) - ] - ) - topology = mixnet.build_topology(b"entropy", 3, 3) - - msg = random_bytes(500) - route = topology.generate_route() - destination = mixnet.choose_mixnode() - - packet = SphinxPacket.build(msg, route, destination) - - # Process packet with the first mix node in the route - processed_packet = packet.process(route[0].encryption_private_key) - if not isinstance(processed_packet, ProcessedForwardHopPacket): - self.fail() - self.assertEqual(processed_packet.next_node_address, route[1].addr) - - # Process packet with the second mix node in the route - processed_packet = processed_packet.next_packet.process( - route[1].encryption_private_key - ) - if not isinstance(processed_packet, ProcessedForwardHopPacket): - self.fail() - self.assertEqual(processed_packet.next_node_address, route[2].addr) - - # Process packet with the third mix node in the route - processed_packet = processed_packet.next_packet.process( - route[2].encryption_private_key - ) - if not isinstance(processed_packet, ProcessedFinalHopPacket): - self.fail() - self.assertEqual(processed_packet.destination_node_address, destination.addr) - - # Verify message as a destination - self.assertEqual(processed_packet.payload.recover_plain_playload(), msg) diff --git a/mixnet/sphinx/utils.py b/mixnet/sphinx/utils.py deleted file mode 100644 index b57e3a24..00000000 --- a/mixnet/sphinx/utils.py +++ /dev/null @@ -1,3 +0,0 @@ -def zero_bytes(size: int) -> bytes: - assert size >= 0 - return bytes([0 for _ in range(size)]) diff --git a/mixnet/test_packet.py b/mixnet/test_packet.py new file mode 100644 index 00000000..676da7e1 --- /dev/null +++ b/mixnet/test_packet.py @@ -0,0 +1,91 @@ +from typing import List, Tuple +from unittest import TestCase + +from cryptography.hazmat.primitives.asymmetric.x25519 import X25519PrivateKey +from pysphinx.sphinx import ProcessedFinalHopPacket, SphinxPacket + +from mixnet.bls import generate_bls +from mixnet.mixnet import Mixnet, MixnetTopology, MixNode +from mixnet.packet import ( + Fragment, + MessageFlag, + MessageReconstructor, + PacketBuilder, +) +from mixnet.utils import random_bytes + + +class TestPacket(TestCase): + def test_real_packet(self): + mixnet, topology = self.init() + + msg = random_bytes(3500) + builder = PacketBuilder.real(msg, mixnet, topology) + packet0, route0 = builder.next() + packet1, route1 = builder.next() + packet2, route2 = builder.next() + packet3, route3 = builder.next() + self.assertRaises(StopIteration, builder.next) + + reconstructor = MessageReconstructor() + self.assertIsNone( + reconstructor.add(self.process_packet(packet1, route1)), + ) + self.assertIsNone( + reconstructor.add(self.process_packet(packet3, route3)), + ) + self.assertIsNone( + reconstructor.add(self.process_packet(packet2, route2)), + ) + msg_with_flag = reconstructor.add(self.process_packet(packet0, route0)) + assert msg_with_flag is not None + self.assertEqual( + PacketBuilder.parse_msg_and_flag(msg_with_flag), + (MessageFlag.MESSAGE_FLAG_REAL, msg), + ) + + def test_cover_packet(self): + mixnet, topology = self.init() + + msg = b"cover" + builder = PacketBuilder.drop_cover(msg, mixnet, topology) + packet, route = builder.next() + self.assertRaises(StopIteration, builder.next) + + reconstructor = MessageReconstructor() + msg_with_flag = reconstructor.add(self.process_packet(packet, route)) + assert msg_with_flag is not None + self.assertEqual( + PacketBuilder.parse_msg_and_flag(msg_with_flag), + (MessageFlag.MESSAGE_FLAG_DROP_COVER, msg), + ) + + @staticmethod + def init() -> Tuple[Mixnet, MixnetTopology]: + mixnet = Mixnet( + [ + MixNode( + generate_bls(), + X25519PrivateKey.generate(), + random_bytes(32), + ) + for _ in range(12) + ] + ) + topology = mixnet.build_topology(b"entropy", 3, 3) + return mixnet, topology + + @staticmethod + def process_packet(packet: SphinxPacket, route: List[MixNode]) -> Fragment: + processed = packet.process(route[0].encryption_private_key) + if isinstance(processed, ProcessedFinalHopPacket): + return Fragment.from_bytes(processed.payload.recover_plain_playload()) + else: + processed = processed + for node in route[1:]: + p = processed.next_packet.process(node.encryption_private_key) + if isinstance(p, ProcessedFinalHopPacket): + return Fragment.from_bytes(p.payload.recover_plain_playload()) + else: + processed = p + assert False diff --git a/requirements.txt b/requirements.txt index 9f987b63..0013983a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,9 @@ -blspy==1.0.16 +blspy==2.0.2 cffi==1.16.0 cryptography==41.0.7 -numpy==1.26.2 +numpy==1.26.3 pycparser==2.21 -scipy==1.10.1 +pysphinx==0.0.1 +scipy==1.11.4 +setuptools==69.0.3 +wheel==0.42.0