From 4589bb027f4fd3137fef299da4c64acb15367b99 Mon Sep 17 00:00:00 2001 From: Jialun Cai Date: Mon, 7 Oct 2024 08:01:52 +0800 Subject: [PATCH] Optimize CIDR aggregation to improve performance and reduce memory usage (#20) --- src/aggregate.rs | 339 +++++++++++++++++++-------------------------- src/cidr.rs | 86 +++++++++++- tests/aggregate.rs | 31 +++++ 3 files changed, 253 insertions(+), 203 deletions(-) diff --git a/src/aggregate.rs b/src/aggregate.rs index 3cf3084..22f947d 100644 --- a/src/aggregate.rs +++ b/src/aggregate.rs @@ -1,7 +1,3 @@ -use core::fmt; -use core::net::{Ipv4Addr, Ipv6Addr}; -use core::ptr::NonNull; - use crate::{Cidr, Ipv4Cidr, Ipv6Cidr}; /// Partitions a slice of `Cidr` into separate vectors of `Ipv4Cidr` and `Ipv6Cidr`. @@ -77,6 +73,53 @@ pub fn aggregate(cidrs: &[Cidr]) -> Vec { v4.chain(v6).collect() } +const fn set_bit_at(mut bytes: [u8; N], i: usize) -> [u8; N] { + bytes[i / 8] |= 1 << (7 - (i % 8)); + bytes +} + +const fn bit_at(bytes: [u8; N], i: usize) -> u8 { + bytes[i / 8] >> (7 - i % 8) & 1 +} + +fn is_adjacent(b1: [u8; N], b2: [u8; N], i: usize) -> bool { + if bit_at(b1, i - 1) == 0 { + b2 == set_bit_at(b1, i - 1) + } else { + b1 == set_bit_at(b2, i - 1) + } +} + +fn merge_adjacent_ipv4(p1: Ipv4Cidr, p2: Ipv4Cidr) -> Option { + if p1.bits() != p2.bits() || p1.bits() == 0 { + return None; + } + let bits = p1.bits(); + let p1_bytes = p1.network_addr().octets(); + let p2_bytes = p2.network_addr().octets(); + + if is_adjacent(p1_bytes, p2_bytes, bits as usize) { + Some((p1_bytes, bits - 1).try_into().unwrap()) + } else { + None + } +} + +fn merge_adjacent_ipv6(p1: Ipv6Cidr, p2: Ipv6Cidr) -> Option { + if p1.bits() != p2.bits() || p1.bits() == 0 { + return None; + } + let bits = p1.bits(); + let p1_bytes = p1.network_addr().octets(); + let p2_bytes = p2.network_addr().octets(); + + if is_adjacent(p1_bytes, p2_bytes, bits as usize) { + Some((p1_bytes, bits - 1).try_into().unwrap()) + } else { + None + } +} + /// Aggregates a list of IPv4 CIDR ranges into a minimal set of non-overlapping ranges. /// /// # Examples @@ -100,13 +143,32 @@ pub fn aggregate_ipv4(cidrs: &[Ipv4Cidr]) -> Vec { return cidrs.to_vec(); } - let mut tree = Tree::::new(); let mut cidrs = cidrs.to_vec(); - cidrs.sort_unstable(); - for cidr in cidrs { - tree.insert(cidr); + cidrs.sort_by_key(|v| v.network_addr()); + + let mut rv = vec![cidrs[0]]; + + for cidr in cidrs.into_iter().skip(1) { + if rv[rv.len() - 1].contains_cidr(&cidr) { + continue; + } + rv.push(cidr); + + while rv.len() >= 2 { + let p1 = rv[rv.len() - 1]; + let p2 = rv[rv.len() - 2]; + match merge_adjacent_ipv4(p1, p2) { + Some(p) => { + rv.pop().unwrap(); + rv.pop().unwrap(); + rv.push(p); + } + None => break, + } + } } - tree.list() + + rv } /// Aggregates a list of IPv6 CIDR ranges into a minimal set of non-overlapping ranges. @@ -131,212 +193,91 @@ pub fn aggregate_ipv6(cidrs: &[Ipv6Cidr]) -> Vec { return cidrs.to_vec(); } - let mut tree = Tree::::new(); let mut cidrs = cidrs.to_vec(); - cidrs.sort_unstable(); - for cidr in cidrs { - tree.insert(cidr); - } - tree.list() -} + cidrs.sort_by_key(|v| v.network_addr()); -struct Node { - cidr: T, - is_masked: bool, - parent: Option>>, - left: Option>>, - right: Option>>, -} + let mut rv = vec![cidrs[0]]; -impl Node { - #[inline] - fn new(parent: Option>>, cidr: T) -> NonNull { - let boxed = Box::new(Self { - parent, - cidr, - is_masked: false, - left: None, - right: None, - }); - - let ptr = Box::into_raw(boxed); - NonNull::new(ptr).unwrap() - } - - #[inline] - fn get_or_new_left_child(&mut self, f: F) -> NonNull - where - F: FnOnce() -> NonNull, - { - *self.left.get_or_insert_with(f) - } - - #[inline] - fn get_or_new_right_child(&mut self, f: F) -> NonNull - where - F: FnOnce() -> NonNull, - { - *self.right.get_or_insert_with(f) - } - - #[inline] - fn clear_children(&mut self) { - if let Some(left) = self.left.take() { - let _ = unsafe { Box::from_raw(left.as_ptr()) }; + for cidr in cidrs.into_iter().skip(1) { + if rv[rv.len() - 1].contains_cidr(&cidr) { + continue; } - if let Some(right) = self.right.take() { - let _ = unsafe { Box::from_raw(right.as_ptr()) }; - } - } -} - -impl Drop for Node { - fn drop(&mut self) { - self.clear_children(); - } -} - -struct Tree { - root: NonNull>, -} - -impl Drop for Tree { - fn drop(&mut self) { - unsafe { - let _ = Box::from_raw(self.root.as_ptr()); - } - } -} - -impl Tree -where - T: Copy + fmt::Debug, -{ - fn pruning(node: NonNull>) { - let mut parent = { - let p = unsafe { node.as_ref() }; - p.parent - }; - - while let Some(mut node) = parent { - let p = unsafe { node.as_mut() }; - let mut masked = 0; - if let Some(left) = p.left { - let l = unsafe { left.as_ref() }; - if l.is_masked { - masked += 1; + rv.push(cidr); + + while rv.len() >= 2 { + let p1 = rv[rv.len() - 1]; + let p2 = rv[rv.len() - 2]; + match merge_adjacent_ipv6(p1, p2) { + Some(p) => { + rv.pop().unwrap(); + rv.pop().unwrap(); + rv.push(p); } + None => break, } - if let Some(right) = p.right { - let r = unsafe { right.as_ref() }; - if r.is_masked { - masked += 1; - } - } - - if masked < 2 { - break; - } - p.is_masked = true; - parent = p.parent; } } - pub fn list(&self) -> Vec { - use std::collections::VecDeque; - - let mut rv = vec![]; - let mut q = VecDeque::new(); - - q.push_back(self.root); - - while let Some(node) = q.pop_front() { - let p = unsafe { node.as_ref() }; - if p.is_masked { - rv.push(p.cidr); - continue; - } - if let Some(left) = p.left { - q.push_back(left); - } - if let Some(right) = p.right { - q.push_back(right); - } - } - rv - } + rv } -impl Tree { - #[inline] - pub fn new() -> Self { - Self { - root: Node::new(None, Ipv4Cidr::from_ip(Ipv4Addr::UNSPECIFIED, 0).unwrap()), - } +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_set_bit_at() { + assert_eq!(set_bit_at([0x00, 0x00], 0), [0b1000_0000, 0x00]); + assert_eq!(set_bit_at([0x00, 0x00], 1), [0b0100_0000, 0x00]); + assert_eq!(set_bit_at([0x00, 0x00], 2), [0b0010_0000, 0x00]); + assert_eq!(set_bit_at([0x00, 0x00], 3), [0b0001_0000, 0x00]); + assert_eq!(set_bit_at([0x00, 0x00], 4), [0b0000_1000, 0x00]); + assert_eq!(set_bit_at([0x00, 0x00], 5), [0b0000_0100, 0x00]); + assert_eq!(set_bit_at([0x00, 0x00], 6), [0b0000_0010, 0x00]); + assert_eq!(set_bit_at([0x00, 0x00], 7), [0b0000_0001, 0x00]); + assert_eq!(set_bit_at([0x00, 0x00], 8), [0x00, 0b1000_0000]); + assert_eq!(set_bit_at([0x00, 0x00], 9), [0x00, 0b0100_0000]); + assert_eq!(set_bit_at([0x00, 0x00], 10), [0x00, 0b0010_0000]); + assert_eq!(set_bit_at([0x00, 0x00], 11), [0x00, 0b0001_0000]); + assert_eq!(set_bit_at([0x00, 0x00], 12), [0x00, 0b0000_1000]); + assert_eq!(set_bit_at([0x00, 0x00], 13), [0x00, 0b0000_0100]); + assert_eq!(set_bit_at([0x00, 0x00], 14), [0x00, 0b0000_0010]); + assert_eq!(set_bit_at([0x00, 0x00], 15), [0x00, 0b0000_0001]); } - pub fn insert(&mut self, cidr: Ipv4Cidr) { - let bytes = u32::from_be_bytes(cidr.octets()); - - let mut node = self.root; - for i in 0..cidr.bits() { - let p = unsafe { node.as_mut() }; - - if p.is_masked { - break; - } - - let bit = (bytes >> (31 - i)) & 1; - let f = || Node::new(Some(node), Ipv4Cidr::new(cidr.octets(), i + 1).unwrap()); - node = if bit == 0 { - p.get_or_new_left_child(f) - } else { - p.get_or_new_right_child(f) - } - } - - let p = unsafe { node.as_mut() }; - p.is_masked = true; - p.clear_children(); - Self::pruning(node); - } -} - -impl Tree { - #[inline] - pub fn new() -> Self { - Self { - root: Node::new(None, Ipv6Cidr::from_ip(Ipv6Addr::UNSPECIFIED, 0).unwrap()), - } + #[test] + fn test_bit_at() { + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 0), 0); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 1), 1); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 2), 0); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 3), 1); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 4), 0); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 5), 1); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 6), 0); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 7), 1); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 8), 1); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 9), 0); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 10), 1); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 11), 0); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 12), 1); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 13), 0); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 14), 1); + assert_eq!(bit_at([0b0101_0101, 0b1010_1010], 15), 0); } - pub fn insert(&mut self, cidr: Ipv6Cidr) { - let bytes = u128::from_be_bytes(cidr.octets()); + #[test] + fn test_is_adjacent() { + assert!(is_adjacent([0b1010_1010], [0b1010_1011], 8)); + assert!(is_adjacent([0b1010_1011], [0b1010_1010], 8)); - let mut node = self.root; - for i in 0..cidr.bits() { - let p = unsafe { node.as_mut() }; - if p.is_masked { - break; - } - - let bit = (bytes >> (31 - i)) & 1; - let f = || { - Node::new( - Some(node), - Ipv6Cidr::from_ip(cidr.network_addr(), i + 1).unwrap(), - ) - }; - node = if bit == 0 { - p.get_or_new_left_child(f) - } else { - p.get_or_new_right_child(f) - } - } + assert!(is_adjacent( + [0b1010_1010, 0b1000_0000], + [0b1010_1010, 0x00], + 9 + )); - let p = unsafe { node.as_mut() }; - p.is_masked = true; - p.clear_children(); - Self::pruning(node); + assert!(!is_adjacent([0b1010_1010], [0b1010_1010], 8)); + assert!(!is_adjacent([0b1010_1001], [0b1010_1010], 8)); + assert!(!is_adjacent([0b1010_1000], [0b1010_1010], 8)); } } diff --git a/src/cidr.rs b/src/cidr.rs index 14aea3f..8bb12b4 100644 --- a/src/cidr.rs +++ b/src/cidr.rs @@ -280,6 +280,29 @@ impl Ipv4Cidr { addr & mask == cidr } + /// Returns [`true`] if the CIDR block contains the given CIDR block. + /// + /// # Example + /// + /// ``` + /// use cidrs::Ipv4Cidr; + /// + /// let cidr1 = Ipv4Cidr::new([192, 168, 0, 0], 24).unwrap(); + /// let cidr2 = Ipv4Cidr::new([192, 168, 1, 0], 24).unwrap(); + /// let cidr3 = Ipv4Cidr::new([192, 168, 0, 0], 16).unwrap(); + /// + /// assert!(!cidr1.contains_cidr(&cidr2)); + /// assert!(!cidr2.contains_cidr(&cidr1)); + /// assert!(!cidr1.contains_cidr(&cidr3)); + /// assert!(!cidr2.contains_cidr(&cidr3)); + /// assert!(cidr3.contains_cidr(&cidr1)); + /// assert!(cidr3.contains_cidr(&cidr2)); + /// ``` + #[inline] + pub const fn contains_cidr(&self, other: &Self) -> bool { + self.overlaps(other) && self.bits() <= other.bits + } + /// Returns [`true`] if the CIDR block overlaps with the given CIDR block. /// /// # Examples @@ -766,7 +789,7 @@ impl Ipv6Cidr { self.bits } - /// Returns `true` if the CIDR block contains the given IPv6 address. + /// Returns [`true`] if the CIDR block contains the given IPv6 address. /// /// # Examples /// @@ -788,7 +811,30 @@ impl Ipv6Cidr { addr & mask == cidr } - /// Returns `true` if the CIDR block overlaps with the given CIDR block. + /// Returns [`true`] if the CIDR block contains the given CIDR block. + /// + /// # Examples + /// + /// ``` + /// use cidrs::Ipv6Cidr; + /// + /// let cidr1 = Ipv6Cidr::new([0x2001, 0xdb8, 0, 0, 0, 0, 0, 0], 64).unwrap(); + /// let cidr2 = Ipv6Cidr::new([0x2001, 0xdb8, 1, 0, 0, 0, 0, 0], 64).unwrap(); + /// let cidr3 = Ipv6Cidr::new([0x2001, 0xdb8, 0, 0, 0, 0, 0, 0], 32).unwrap(); + /// + /// assert!(!cidr1.contains_cidr(&cidr2)); + /// assert!(!cidr2.contains_cidr(&cidr1)); + /// assert!(!cidr1.contains_cidr(&cidr3)); + /// assert!(!cidr2.contains_cidr(&cidr3)); + /// assert!(cidr3.contains_cidr(&cidr1)); + /// assert!(cidr3.contains_cidr(&cidr2)); + /// ``` + #[inline] + pub const fn contains_cidr(&self, other: &Self) -> bool { + self.overlaps(other) && self.bits() <= other.bits + } + + /// Returns [`true`] if the CIDR block overlaps with the given CIDR block. /// /// # Examples /// @@ -1152,7 +1198,7 @@ impl Cidr { } } - /// Returns `true` if the CIDR block contains the given IP address. + /// Returns [`true`] if the CIDR block contains the given IP address. /// /// # Examples /// @@ -1177,7 +1223,39 @@ impl Cidr { } } - /// Returns `true` if the CIDR block overlaps with the given CIDR block. + /// Returns [`true`] if the CIDR block contains the given CIDR block. + /// + /// # Examples + /// + /// ``` + /// use core::net::IpAddr; + /// use cidrs::Cidr; + /// + /// let ipv4_cidr1 = Cidr::new(IpAddr::V4([192, 168, 1, 0].into()), 24).unwrap(); + /// let ipv4_cidr2 = Cidr::new(IpAddr::V4([192, 168, 0, 0].into()), 16).unwrap(); + /// + /// let ipv6_cidr1 = Cidr::new(IpAddr::V6([0x2001, 0xdb8, 0x1, 0, 0, 0, 0, 0].into()), 48).unwrap(); + /// let ipv6_cidr2 = Cidr::new(IpAddr::V6([0x2001, 0xdb8, 0, 0, 0, 0, 0, 0].into()), 32).unwrap(); + /// + /// assert!(!ipv4_cidr1.contains_cidr(&ipv4_cidr2)); + /// assert!(ipv4_cidr2.contains_cidr(&ipv4_cidr1)); + /// + /// assert!(!ipv6_cidr1.contains_cidr(&ipv6_cidr2)); + /// assert!(ipv6_cidr2.contains_cidr(&ipv6_cidr1)); + /// + /// assert!(!ipv4_cidr1.contains_cidr(&ipv6_cidr1)); + /// assert!(!ipv6_cidr1.contains_cidr(&ipv4_cidr1)); + /// ``` + #[inline] + pub const fn contains_cidr(&self, other: &Cidr) -> bool { + match (self, other) { + (Cidr::V4(lh), Cidr::V4(rh)) => lh.contains_cidr(rh), + (Cidr::V6(lh), Cidr::V6(rh)) => lh.contains_cidr(rh), + _ => false, + } + } + + /// Returns [`true`] if the CIDR block overlaps with the given CIDR block. /// /// # Examples /// diff --git a/tests/aggregate.rs b/tests/aggregate.rs index 618233b..c2fc195 100644 --- a/tests/aggregate.rs +++ b/tests/aggregate.rs @@ -56,3 +56,34 @@ fn ipv4_basic() { assert_eq!(actual, expected, "input: {input:?}"); } } + +#[test] +fn ipv4_full() { + { + let cidrs: Vec = (0..=255) + .flat_map(|i| { + (0..=255) + .map(|j| Ipv4Cidr::new([i, j, 0, 0], 16).unwrap()) + .collect::>() + }) + .collect(); + + let expected = vec![Ipv4Cidr::new([0, 0, 0, 0], 0).unwrap()]; + let actual = aggregate_ipv4(&cidrs); + assert_eq!(actual, expected); + } + + { + let cidrs: Vec = (0..=255) + .flat_map(|i| { + (0..=255) + .map(|j| Ipv4Cidr::new([192, 168, i, j], 32).unwrap()) + .collect::>() + }) + .collect(); + + let expected = vec![Ipv4Cidr::new([192, 168, 0, 0], 16).unwrap()]; + let actual = aggregate_ipv4(&cidrs); + assert_eq!(actual, expected); + } +}