Skip to content

Commit

Permalink
Optimize CIDR aggregation to improve performance and reduce memory us…
Browse files Browse the repository at this point in the history
…age (#20)
  • Loading branch information
zarvd authored Oct 7, 2024
1 parent e67f49b commit 4589bb0
Show file tree
Hide file tree
Showing 3 changed files with 253 additions and 203 deletions.
339 changes: 140 additions & 199 deletions src/aggregate.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
use core::fmt;
use core::net::{Ipv4Addr, Ipv6Addr};
use core::ptr::NonNull;

use crate::{Cidr, Ipv4Cidr, Ipv6Cidr};

/// Partitions a slice of `Cidr` into separate vectors of `Ipv4Cidr` and `Ipv6Cidr`.
Expand Down Expand Up @@ -77,6 +73,53 @@ pub fn aggregate(cidrs: &[Cidr]) -> Vec<Cidr> {
v4.chain(v6).collect()
}

/// Returns a copy of `bytes` with bit `i` set to 1.
///
/// Bits are numbered from the most significant bit of `bytes[0]`: bit 0 is
/// `0b1000_0000` of the first byte, bit 8 is `0b1000_0000` of the second.
///
/// # Panics
///
/// Panics if `i / 8` is not a valid index into `bytes`.
const fn set_bit_at<const N: usize>(mut bytes: [u8; N], i: usize) -> [u8; N] {
    let byte_index = i / 8;
    // Slide the top bit right to reach the requested position inside the byte.
    let mask = 0x80u8 >> (i % 8);
    bytes[byte_index] = bytes[byte_index] | mask;
    bytes
}

/// Returns bit `i` of `bytes` as `0` or `1`.
///
/// Bits are numbered from the most significant bit of `bytes[0]`.
///
/// # Panics
///
/// Panics if `i / 8` is not a valid index into `bytes`.
const fn bit_at<const N: usize>(bytes: [u8; N], i: usize) -> u8 {
    let byte = bytes[i / 8];
    // Mask selecting the requested position, counted from the top bit.
    let mask = 0x80u8 >> (i % 8);
    if byte & mask != 0 {
        1
    } else {
        0
    }
}

/// Reports whether `b1` and `b2` differ in exactly one bit: bit `i - 1`,
/// counted from the most significant bit of the first byte.
///
/// For two network addresses with the same prefix length `i`, this is the
/// condition for them to be the two halves of one prefix of length `i - 1`.
/// The argument order does not matter.
///
/// Returns `false` when the inputs are identical, when they differ in any
/// other bit, when `i == 0` (there is no bit `i - 1`), or when `i` exceeds the
/// number of bits in the arrays.
fn is_adjacent<const N: usize>(b1: [u8; N], b2: [u8; N], i: usize) -> bool {
    // Guard the `i - 1` below and the byte lookup: without it a zero prefix
    // underflows and an oversized prefix indexes past the end of the array.
    if i == 0 || i > N * 8 {
        return false;
    }

    let pos = i - 1;
    let pivot_byte = pos / 8;
    let pivot_mask = 0x80u8 >> (pos % 8);

    // XOR exposes the differing bits. Requiring the pivot byte to differ in
    // exactly the pivot bit (and every other byte to match) also rejects
    // identical inputs, which a "set the bit and compare" check lets through
    // whenever the pivot bit is already 1 on both sides.
    b1.iter().zip(b2.iter()).enumerate().all(|(k, (x, y))| {
        let want = if k == pivot_byte { pivot_mask } else { 0 };
        (x ^ y) == want
    })
}

/// Tries to combine two IPv4 CIDR blocks into the block one bit shorter.
///
/// Returns `None` when the prefix lengths differ, when the prefix length is
/// zero, or when `is_adjacent` rejects the two network addresses. Otherwise
/// returns a CIDR built from `p1`'s network octets with the prefix length
/// reduced by one.
///
/// # Panics
///
/// Panics if converting `(octets, bits - 1)` into an `Ipv4Cidr` fails.
fn merge_adjacent_ipv4(p1: Ipv4Cidr, p2: Ipv4Cidr) -> Option<Ipv4Cidr> {
    let prefix_len = p1.bits();
    if prefix_len == 0 || prefix_len != p2.bits() {
        return None;
    }

    let lhs = p1.network_addr().octets();
    let rhs = p2.network_addr().octets();
    if !is_adjacent(lhs, rhs, prefix_len as usize) {
        return None;
    }

    // NOTE(review): `lhs` may still carry the bit that is being dropped from
    // the prefix; this presumably relies on the conversion accepting or
    // masking it - confirm against the `Ipv4Cidr` constructor.
    let merged: Ipv4Cidr = (lhs, prefix_len - 1).try_into().unwrap();
    Some(merged)
}

/// Tries to combine two IPv6 CIDR blocks into the block one bit shorter.
///
/// Returns `None` when the prefix lengths differ, when the prefix length is
/// zero, or when `is_adjacent` rejects the two network addresses. Otherwise
/// returns a CIDR built from `p1`'s network octets with the prefix length
/// reduced by one.
///
/// # Panics
///
/// Panics if converting `(octets, bits - 1)` into an `Ipv6Cidr` fails.
fn merge_adjacent_ipv6(p1: Ipv6Cidr, p2: Ipv6Cidr) -> Option<Ipv6Cidr> {
    let prefix_len = p1.bits();
    if prefix_len == 0 || prefix_len != p2.bits() {
        return None;
    }

    let lhs = p1.network_addr().octets();
    let rhs = p2.network_addr().octets();
    if !is_adjacent(lhs, rhs, prefix_len as usize) {
        return None;
    }

    // NOTE(review): `lhs` may still carry the bit that is being dropped from
    // the prefix; this presumably relies on the conversion accepting or
    // masking it - confirm against the `Ipv6Cidr` constructor.
    let merged: Ipv6Cidr = (lhs, prefix_len - 1).try_into().unwrap();
    Some(merged)
}

/// Aggregates a list of IPv4 CIDR ranges into a minimal set of non-overlapping ranges.
///
/// # Examples
Expand All @@ -100,13 +143,32 @@ pub fn aggregate_ipv4(cidrs: &[Ipv4Cidr]) -> Vec<Ipv4Cidr> {
return cidrs.to_vec();
}

let mut tree = Tree::<Ipv4Cidr>::new();
let mut cidrs = cidrs.to_vec();
cidrs.sort_unstable();
for cidr in cidrs {
tree.insert(cidr);
cidrs.sort_by_key(|v| v.network_addr());

let mut rv = vec![cidrs[0]];

for cidr in cidrs.into_iter().skip(1) {
if rv[rv.len() - 1].contains_cidr(&cidr) {
continue;
}
rv.push(cidr);

while rv.len() >= 2 {
let p1 = rv[rv.len() - 1];
let p2 = rv[rv.len() - 2];
match merge_adjacent_ipv4(p1, p2) {
Some(p) => {
rv.pop().unwrap();
rv.pop().unwrap();
rv.push(p);
}
None => break,
}
}
}
tree.list()

rv
}

/// Aggregates a list of IPv6 CIDR ranges into a minimal set of non-overlapping ranges.
Expand All @@ -131,212 +193,91 @@ pub fn aggregate_ipv6(cidrs: &[Ipv6Cidr]) -> Vec<Ipv6Cidr> {
return cidrs.to_vec();
}

let mut tree = Tree::<Ipv6Cidr>::new();
let mut cidrs = cidrs.to_vec();
cidrs.sort_unstable();
for cidr in cidrs {
tree.insert(cidr);
}
tree.list()
}
cidrs.sort_by_key(|v| v.network_addr());

struct Node<T> {
cidr: T,
is_masked: bool,
parent: Option<NonNull<Node<T>>>,
left: Option<NonNull<Node<T>>>,
right: Option<NonNull<Node<T>>>,
}
let mut rv = vec![cidrs[0]];

impl<T> Node<T> {
#[inline]
fn new(parent: Option<NonNull<Node<T>>>, cidr: T) -> NonNull<Self> {
let boxed = Box::new(Self {
parent,
cidr,
is_masked: false,
left: None,
right: None,
});

let ptr = Box::into_raw(boxed);
NonNull::new(ptr).unwrap()
}

#[inline]
fn get_or_new_left_child<F>(&mut self, f: F) -> NonNull<Self>
where
F: FnOnce() -> NonNull<Self>,
{
*self.left.get_or_insert_with(f)
}

#[inline]
fn get_or_new_right_child<F>(&mut self, f: F) -> NonNull<Self>
where
F: FnOnce() -> NonNull<Self>,
{
*self.right.get_or_insert_with(f)
}

#[inline]
fn clear_children(&mut self) {
if let Some(left) = self.left.take() {
let _ = unsafe { Box::from_raw(left.as_ptr()) };
for cidr in cidrs.into_iter().skip(1) {
if rv[rv.len() - 1].contains_cidr(&cidr) {
continue;
}
if let Some(right) = self.right.take() {
let _ = unsafe { Box::from_raw(right.as_ptr()) };
}
}
}

impl<T> Drop for Node<T> {
fn drop(&mut self) {
self.clear_children();
}
}

struct Tree<T> {
root: NonNull<Node<T>>,
}

impl<T> Drop for Tree<T> {
fn drop(&mut self) {
unsafe {
let _ = Box::from_raw(self.root.as_ptr());
}
}
}

impl<T> Tree<T>
where
T: Copy + fmt::Debug,
{
fn pruning(node: NonNull<Node<T>>) {
let mut parent = {
let p = unsafe { node.as_ref() };
p.parent
};

while let Some(mut node) = parent {
let p = unsafe { node.as_mut() };
let mut masked = 0;
if let Some(left) = p.left {
let l = unsafe { left.as_ref() };
if l.is_masked {
masked += 1;
rv.push(cidr);

while rv.len() >= 2 {
let p1 = rv[rv.len() - 1];
let p2 = rv[rv.len() - 2];
match merge_adjacent_ipv6(p1, p2) {
Some(p) => {
rv.pop().unwrap();
rv.pop().unwrap();
rv.push(p);
}
None => break,
}
if let Some(right) = p.right {
let r = unsafe { right.as_ref() };
if r.is_masked {
masked += 1;
}
}

if masked < 2 {
break;
}
p.is_masked = true;
parent = p.parent;
}
}

pub fn list(&self) -> Vec<T> {
use std::collections::VecDeque;

let mut rv = vec![];
let mut q = VecDeque::new();

q.push_back(self.root);

while let Some(node) = q.pop_front() {
let p = unsafe { node.as_ref() };
if p.is_masked {
rv.push(p.cidr);
continue;
}
if let Some(left) = p.left {
q.push_back(left);
}
if let Some(right) = p.right {
q.push_back(right);
}
}
rv
}
rv
}

impl Tree<Ipv4Cidr> {
#[inline]
pub fn new() -> Self {
Self {
root: Node::new(None, Ipv4Cidr::from_ip(Ipv4Addr::UNSPECIFIED, 0).unwrap()),
}
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_set_bit_at() {
    // Walk every bit position of a two-byte array: setting bit `i` on an
    // all-zero input must light up exactly that bit, counted MSB-first.
    for i in 0..16usize {
        let mut expected = [0x00u8, 0x00u8];
        expected[i / 8] = 0b1000_0000 >> (i % 8);
        assert_eq!(set_bit_at([0x00, 0x00], i), expected);
    }
}

pub fn insert(&mut self, cidr: Ipv4Cidr) {
let bytes = u32::from_be_bytes(cidr.octets());

let mut node = self.root;
for i in 0..cidr.bits() {
let p = unsafe { node.as_mut() };

if p.is_masked {
break;
}

let bit = (bytes >> (31 - i)) & 1;
let f = || Node::new(Some(node), Ipv4Cidr::new(cidr.octets(), i + 1).unwrap());
node = if bit == 0 {
p.get_or_new_left_child(f)
} else {
p.get_or_new_right_child(f)
}
}

let p = unsafe { node.as_mut() };
p.is_masked = true;
p.clear_children();
Self::pruning(node);
}
}

impl Tree<Ipv6Cidr> {
#[inline]
pub fn new() -> Self {
Self {
root: Node::new(None, Ipv6Cidr::from_ip(Ipv6Addr::UNSPECIFIED, 0).unwrap()),
}
#[test]
fn test_bit_at() {
    // Alternating patterns: the first byte starts with 0, the second with 1.
    let input = [0b0101_0101u8, 0b1010_1010u8];
    let expected = [0u8, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0];

    for (i, want) in expected.iter().enumerate() {
        assert_eq!(bit_at(input, i), *want);
    }
}

pub fn insert(&mut self, cidr: Ipv6Cidr) {
let bytes = u128::from_be_bytes(cidr.octets());
#[test]
fn test_is_adjacent() {
assert!(is_adjacent([0b1010_1010], [0b1010_1011], 8));
assert!(is_adjacent([0b1010_1011], [0b1010_1010], 8));

let mut node = self.root;
for i in 0..cidr.bits() {
let p = unsafe { node.as_mut() };
if p.is_masked {
break;
}

let bit = (bytes >> (31 - i)) & 1;
let f = || {
Node::new(
Some(node),
Ipv6Cidr::from_ip(cidr.network_addr(), i + 1).unwrap(),
)
};
node = if bit == 0 {
p.get_or_new_left_child(f)
} else {
p.get_or_new_right_child(f)
}
}
assert!(is_adjacent(
[0b1010_1010, 0b1000_0000],
[0b1010_1010, 0x00],
9
));

let p = unsafe { node.as_mut() };
p.is_masked = true;
p.clear_children();
Self::pruning(node);
assert!(!is_adjacent([0b1010_1010], [0b1010_1010], 8));
assert!(!is_adjacent([0b1010_1001], [0b1010_1010], 8));
assert!(!is_adjacent([0b1010_1000], [0b1010_1010], 8));
}
}
Loading

0 comments on commit 4589bb0

Please sign in to comment.