Skip to content

Commit

Permalink
Mixnet: Sphinx packet builder for mix clients (#47)
Browse files Browse the repository at this point in the history
  • Loading branch information
youngjoon-lee authored Jan 15, 2024
1 parent 3caf8d8 commit 78d02e8
Show file tree
Hide file tree
Showing 17 changed files with 313 additions and 933 deletions.
9 changes: 7 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
name: Spec tests

on: [pull_request]
on:
pull_request:
branches:
- "*"
push:
branches: [master]

jobs:
build:
Expand All @@ -12,7 +17,7 @@ jobs:
uses: actions/setup-python@v5
with:
# Semantic version range syntax or exact version of a Python version
python-version: '3.11'
python-version: '3.x'
- name: Install dependencies
run: pip install -r requirements.txt
- name: Run tests
Expand Down
4 changes: 4 additions & 0 deletions mixnet/mixnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
X25519PrivateKey,
X25519PublicKey,
)
from pysphinx.node import Node

from mixnet.bls import BlsPrivateKey, BlsPublicKey
from mixnet.fisheryates import FisherYates
Expand Down Expand Up @@ -60,6 +61,9 @@ def identity_public_key(self) -> BlsPublicKey:
def encryption_public_key(self) -> X25519PublicKey:
return self.encryption_private_key.public_key()

def sphinx_node(self) -> Node:
return Node(self.encryption_private_key, self.addr)


@dataclass
class MixnetTopology:
Expand Down
205 changes: 205 additions & 0 deletions mixnet/packet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
from __future__ import annotations

import uuid
from dataclasses import dataclass
from enum import Enum
from itertools import batched
from typing import Dict, Iterator, List, Self, Tuple, TypeAlias

from pysphinx.payload import Payload
from pysphinx.sphinx import SphinxPacket

from mixnet.mixnet import Mixnet, MixnetTopology, MixNode


class MessageFlag(Enum):
MESSAGE_FLAG_REAL = b"\x00"
MESSAGE_FLAG_DROP_COVER = b"\x01"

def bytes(self) -> bytes:
return bytes(self.value)


class PacketBuilder:
iter: Iterator[Tuple[SphinxPacket, List[MixNode]]]

def __init__(
self,
flag: MessageFlag,
message: bytes,
mixnet: Mixnet,
topology: MixnetTopology,
):
destination = mixnet.choose_mixnode()

msg_with_flag = flag.bytes() + message
# NOTE: We don't encrypt msg_with_flag for destination.
# If encryption is needed, a shared secret must be appended in front of the message along with the MessageFlag.
fragment_set = FragmentSet(msg_with_flag)

packets_and_routes = []
for fragment in fragment_set.fragments:
route = topology.generate_route()
packet = SphinxPacket.build(
fragment.bytes(),
[mixnode.sphinx_node() for mixnode in route],
destination.sphinx_node(),
)
packets_and_routes.append((packet, route))

self.iter = iter(packets_and_routes)

@classmethod
def real(cls, message: bytes, mixnet: Mixnet, topology: MixnetTopology) -> Self:
return cls(MessageFlag.MESSAGE_FLAG_REAL, message, mixnet, topology)

@classmethod
def drop_cover(
cls, message: bytes, mixnet: Mixnet, topology: MixnetTopology
) -> Self:
return cls(MessageFlag.MESSAGE_FLAG_DROP_COVER, message, mixnet, topology)

def next(self) -> Tuple[SphinxPacket, List[MixNode]]:
return next(self.iter)

@staticmethod
def parse_msg_and_flag(data: bytes) -> Tuple[MessageFlag, bytes]:
"""Remove a MessageFlag from data"""
if len(data) < 1:
raise ValueError("data is too short")

return (MessageFlag(data[0:1]), data[1:])


# Unlikely, Nym uses i32 for FragmentSetId, which may cause more collisions.
# We will use UUID until figuring out why Nym uses i32.
FragmentSetId: TypeAlias = bytes # 128bit UUID v4
FragmentId: TypeAlias = int # unsigned 8bit int in big endian

FRAGMENT_SET_ID_LENGTH: int = 16
FRAGMENT_ID_LENGTH: int = 1


@dataclass
class FragmentHeader:
"""
Contain all information for reconstructing a message that was fragmented into the same FragmentSet.
"""

set_id: FragmentSetId
total_fragments: FragmentId
fragment_id: FragmentId

SIZE: int = FRAGMENT_SET_ID_LENGTH + FRAGMENT_ID_LENGTH * 2

@staticmethod
def max_total_fragments() -> int:
return 256 # because total_fragment is u8

def bytes(self) -> bytes:
return (
self.set_id
+ self.total_fragments.to_bytes(1)
+ self.fragment_id.to_bytes(1)
)

@classmethod
def from_bytes(cls, data: bytes) -> Self:
if len(data) != cls.SIZE:
raise ValueError("Invalid data length", len(data))

return cls(data[:16], int.from_bytes(data[16:17]), int.from_bytes(data[17:18]))


@dataclass
class FragmentSet:
"""
Represent a set of Fragments that can be reconstructed to a single original message.
Note that the maximum number of fragments in a FragmentSet is limited for now.
"""

fragments: List[Fragment]

MAX_FRAGMENTS: int = FragmentHeader.max_total_fragments()

def __init__(self, message: bytes):
"""
Build a FragmentSet by chunking a message into Fragments.
"""
chunked_messages = chunks(message, Fragment.MAX_PAYLOAD_SIZE)
# For now, we don't support more than max_fragments() fragments.
# If needed, we can devise the FragmentSet chaining to support larger messages, like Nym.
if len(chunked_messages) > self.MAX_FRAGMENTS:
raise ValueError(f"Too long message: {len(chunked_messages)} chunks")

set_id = uuid.uuid4().bytes
self.fragments = [
Fragment(FragmentHeader(set_id, len(chunked_messages), i), chunk)
for i, chunk in enumerate(chunked_messages)
]


@dataclass
class Fragment:
"""Represent a piece of data that can be transformed to a single SphinxPacket"""

header: FragmentHeader
body: bytes

MAX_PAYLOAD_SIZE: int = Payload.max_plain_payload_size() - FragmentHeader.SIZE

def bytes(self) -> bytes:
return self.header.bytes() + self.body

@classmethod
def from_bytes(cls, data: bytes) -> Self:
header = FragmentHeader.from_bytes(data[: FragmentHeader.SIZE])
body = data[FragmentHeader.SIZE :]
return cls(header, body)


@dataclass
class MessageReconstructor:
fragmentSets: Dict[FragmentSetId, FragmentSetReconstructor]

def __init__(self):
self.fragmentSets = {}

def add(self, fragment: Fragment) -> bytes | None:
if fragment.header.set_id not in self.fragmentSets:
self.fragmentSets[fragment.header.set_id] = FragmentSetReconstructor(
fragment.header.total_fragments
)

msg = self.fragmentSets[fragment.header.set_id].add(fragment)
if msg is not None:
del self.fragmentSets[fragment.header.set_id]
return msg


@dataclass
class FragmentSetReconstructor:
total_fragments: FragmentId
fragments: Dict[FragmentId, Fragment]

def __init__(self, total_fragments: FragmentId):
self.total_fragments = total_fragments
self.fragments = {}

def add(self, fragment: Fragment) -> bytes | None:
self.fragments[fragment.header.fragment_id] = fragment
if len(self.fragments) == self.total_fragments:
return self.build_message()
else:
return None

def build_message(self) -> bytes:
message = b""
for i in range(self.total_fragments):
message += self.fragments[FragmentId(i)].body
return message


def chunks(data: bytes, size: int) -> List[bytes]:
return list(map(bytes, batched(data, size)))
Empty file removed mixnet/sphinx/__init__.py
Empty file.
22 changes: 0 additions & 22 deletions mixnet/sphinx/const.py

This file was deleted.

23 changes: 0 additions & 23 deletions mixnet/sphinx/crypto.py

This file was deleted.

Empty file removed mixnet/sphinx/header/__init__.py
Empty file.
Loading

0 comments on commit 78d02e8

Please sign in to comment.