diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..a6f0573 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +select = F,E722 +ignore = F403,F405,F541 +per-file-ignores = + */__init__.py:F401,F403 diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 91f60e2..5ae4031 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,12 +22,15 @@ jobs: test: needs: lint runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v3 - name: Set up Python uses: actions/setup-python@v4 with: - python-version: "3.x" + python-version: ${{ matrix.python-version }} - name: Install dependencies run: | pip install poetry diff --git a/README.md b/README.md index 4f50330..05df13a 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,16 @@ [![Python Version](https://img.shields.io/badge/python-3.9+-blue)](https://www.python.org) [![License](https://img.shields.io/badge/license-GPLv3-blue.svg)](https://github.com/blacklanternsecurity/radixtarget/blob/master/LICENSE) [![Black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![tests](https://github.com/blacklanternsecurity/radixtarget/actions/workflows/tests.yml/badge.svg)](https://github.com/blacklanternsecurity/radixtarget/actions/workflows/tests.yml) [![Codecov](https://codecov.io/gh/blacklanternsecurity/radixtarget/graph/badge.svg?token=7IPWMYMTGZ)](https://codecov.io/gh/blacklanternsecurity/radixtarget) -RadixTarget is a performant radix implementation designed for quick lookups of IP addresses/networks and DNS hostnames. Written in pure python and capable of roughly 100,000 lookups per second regardless of the size of the database. +RadixTarget is a performant radix implementation designed for quick lookups of IP addresses/networks and DNS hostnames. -Used by: -- [BBOT (Bighuge BLS OSINT Tool)](https://github.com/blacklanternsecurity/bbot) -- [cloudcheck](https://github.com/blacklanternsecurity/cloudcheck) +RadixTarget is: +- Written in pure python +- Capable of ~100,000 lookups per second regardless of database size +- 100% test coverage +- Used by: + - [BBOT](https://github.com/blacklanternsecurity/bbot) + - [cloudcheck](https://github.com/blacklanternsecurity/cloudcheck) +Written in pure python and capable of roughly 100,000 lookups per second regardless of database size, it's perfect for production . ### Installation ([PyPi](https://pypi.org/project/radixtarget/)) @@ -22,26 +27,26 @@ from radixtarget import RadixTarget rt = RadixTarget() # IPv4 -rt.insert("192.168.1.0/24") -rt.search("192.168.1.10") # IPv4Network("192.168.1.0/24") -rt.search("192.168.2.10") # None +rt.add("192.168.1.0/24") +rt.get("192.168.1.10") # IPv4Network("192.168.1.0/24") +rt.get("192.168.2.10") # None -# ipv6 -rt.insert("dead::/64") -rt.search("dead::beef") # IPv6Network("dead::/64") -rt.search("dead:cafe::beef") # None +# IPv6 +rt.add("dead::/64") +rt.get("dead::beef") # IPv6Network("dead::/64") +rt.get("dead:cafe::beef") # None # DNS -rt.insert("net") -rt.insert("www.example.com") -rt.insert("test.www.example.com") -rt.search("net") # "net" -rt.search("evilcorp.net") # "net" -rt.search("www.example.com") # "www.example.com" -rt.search("asdf.test.www.example.com") # "test.www.example.com" -rt.search("example.com") # None +rt.add("net") +rt.add("www.example.com") +rt.add("test.www.example.com") +rt.get("net") # "net" +rt.get("evilcorp.net") # "net" +rt.get("www.example.com") # "www.example.com" +rt.get("asdf.test.www.example.com") # "test.www.example.com" +rt.get("example.com") # None # Custom data nodes -rt.insert("evilcorp.co.uk", "custom_data") -rt.search("www.evilcorp.co.uk") # "custom_data" +rt.add("evilcorp.co.uk", "custom_data") +rt.get("www.evilcorp.co.uk") # "custom_data" ``` \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 6ef5c28..6b1092f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "radixtarget" -version = "1.1.0" +version = "2.0.0" description = "Check whether an IP address belongs to a cloud provider" authors = ["TheTechromancer"] license = "GPL-3.0" diff --git a/radixtarget/__init__.py b/radixtarget/__init__.py index 119dacb..8928687 100644 --- a/radixtarget/__init__.py +++ b/radixtarget/__init__.py @@ -1 +1 @@ -from .radixtarget import RadixTarget +from .radixtarget import Target diff --git a/radixtarget/helpers.py b/radixtarget/helpers.py new file mode 100644 index 0000000..82cafb0 --- /dev/null +++ b/radixtarget/helpers.py @@ -0,0 +1,62 @@ +import ipaddress + + +def host_size_key(host): + """ + Used for sorting by host size, so that parent dns names / ip subnets always come first + + Notes: + - we have to use str(host) to break the tie between two hosts of the same length, e.g. evilcorp.com and evilcorp.net + """ + host = make_ip(host) + if is_ip(host): + try: + # bigger IP subnets should come first + return (-host.num_addresses, str(host)) + except AttributeError: + # IP addresses default to 1 + return (1, str(host)) + # smaller domains should come first + return (len(host), str(host)) + + +def is_ip(host): + """Check if the given host is an instance of an IP address. + + Args: + host (Any): The host to check. + + Returns: + bool: True if the host is an instance of an IP address, False otherwise. + """ + return ipaddress._IPAddressBase in host.__class__.__mro__ + + +def make_ip(host): + """Convert a host to an IP network or return it as a lowercase string. + + This function checks if the provided host is a string or an IP address. + If it is not a string and not an IP address, a ValueError is raised. + If the host is a valid IP address or network, it is converted to an + ipaddress.IPv4Network or ipaddress.IPv6Network object. If the host + cannot be converted, it is returned as a lowercase string. + + Args: + host (str or ipaddress): The host to convert. + + Raises: + ValueError: If the host is not of str or ipaddress type. + + Returns: + ipaddress.IPv4Network or ipaddress.IPv6Network or str: The converted + IP network or the lowercase string representation of the host. + """ + if not isinstance(host, str): + if not is_ip(host): + raise ValueError( + f'Host "{host}" must be of str or ipaddress type, not "{type(host)}"' + ) + try: + return ipaddress.ip_network(host, strict=False) + except Exception: + return host.lower() diff --git a/radixtarget/radixtarget.py b/radixtarget/radixtarget.py index 26b621d..ff877a3 100644 --- a/radixtarget/radixtarget.py +++ b/radixtarget/radixtarget.py @@ -1,38 +1,221 @@ -import ipaddress +import copy +from hashlib import sha1 -from .tree.ip import IPRadixTree -from .tree.dns import DNSRadixTree +from radixtarget.tree.ip import IPRadixTree +from radixtarget.tree.dns import DNSRadixTree +from radixtarget.helpers import is_ip, make_ip, host_size_key -class RadixTarget: - def __init__(self): +sentinel = object() + + +class Target: + """ + A class representing a target. Can contain an unlimited number of hosts, IPs, or IP ranges. + + Attributes: + strict_scope (bool): Flag indicating whether to consider child domains in-scope. + If set to True, only the exact hosts specified and not their children are considered part of the target. + + Examples: + Basic usage + >>> target = SimpleTarget("evilcorp.com", "1.2.3.0/24") + >>> len(target) + 257 + >>> "www.evilcorp.com" in target + True + >>> "1.2.3.4" in target + True + >>> "4.3.2.1" in target + False + + Target comparison + >>> target2 = SimpleTarget("www.evilcorp.com") + >>> target2 == target + False + >>> target2 in target + True + >>> target in target2 + False + + Notes: + - If you do not want to include child subdomains, use `strict_scope=True` + """ + + def __init__(self, *targets, strict_dns_scope=False, acl_mode=False): + """ + Initialize a Target object. + + Args: + *targets: One or more targets (e.g., domain names, IP ranges) to be included in this Target. + strict_scope (bool): Whether to consider subdomains of target domains in-scope + acl_mode (bool): If a host is already in the target, don't add it unnecessarily (more efficient) + + Notes: + - The strict_scope flag can be set to restrict scope calculation to only exactly-matching hosts and not their child subdomains. + """ + self._hash = None + self.strict_dns_scope = strict_dns_scope + self.acl_mode = acl_mode self.ip_tree = IPRadixTree() - self.dns_tree = DNSRadixTree() + self.dns_tree = DNSRadixTree(strict_scope=strict_dns_scope) + self._hosts = set() + self.add(targets) + + def get(self, host, raise_error=False): + host = make_ip(host) + if is_ip(host): + return self.ip_tree.search(host, raise_error=raise_error) + else: + return self.dns_tree.search(host, raise_error=raise_error) + + def search(self, host, raise_error=False): + return self.get(host, raise_error=raise_error) + + def insert(self, t, data=None): + self.add(t, data=data) + + def add(self, t, data=None): + """ + Add a target or merge hosts from another Target object into this Target. + + Args: + t: The target to be added. It can be either a string, ipaddress object, or another Target object. + + Examples: + >>> target.add('example.com') + """ + if isinstance(t, self.__class__): + t = t.hosts + if not isinstance(t, (list, tuple, set)): + t = [t] + for single_target in sorted(t, key=host_size_key): + host = make_ip(single_target) + self.add_host(host, data=data) + + def add_host(self, host, data=None): + host = make_ip(host) + try: + result = self.search(host, raise_error=True) + except KeyError: + result = sentinel + # if we're in acl mode, we skip adding hosts that are already in the target + if self.acl_mode and result is not sentinel: + return + self._add_host(host, data=data) - def insert(self, host, data=None): - host = self.make_ip(host) - if self.is_ip(host): + def _add_host(self, host, data=None): + self._hash = None + self._hosts.add(host) + if is_ip(host): self.ip_tree.insert(host, data=data) else: self.dns_tree.insert(host, data=data) - def search(self, host): - host = self.make_ip(host) - if self.is_ip(host): - return self.ip_tree.search(host) - else: - return self.dns_tree.search(host) - - def make_ip(self, host): - if not isinstance(host, str): - if not self.is_ip(host): - raise ValueError( - f'Host "{host}" must be of str or ipaddress type, not "{type(host)}"' - ) + @property + def hosts(self): + """ + Returns all hosts in the target. + """ + return self._hosts + + @property + def sorted_hosts(self): + return sorted(self._hosts, key=host_size_key) + + def copy(self): + """ + Creates and returns a copy of the Target object, including a shallow copy of the `_events` attributes. + + Returns: + Target: A new Target object with the same attributes as the original. + A shallow copy of the `_events` dictionary is made. + + Examples: + >>> original_target = SimpleTarget("example.com") + >>> copied_target = original_target.copy() + >>> copied_target is original_target + False + >>> copied_target == original_target + True + >>> copied_target in original_target + True + >>> original_target in copied_target + True + + Notes: + - The `scan` object reference is kept intact in the copied Target object. + """ + self_copy = self.__class__( + strict_dns_scope=self.strict_dns_scope, acl_mode=self.acl_mode + ) + self_copy._hosts = set(self._hosts) + self_copy.ip_tree = copy.copy(self.ip_tree) + self_copy.dns_tree = copy.copy(self.dns_tree) + return self_copy + + def _contains(self, other): try: - return ipaddress.ip_network(host, strict=False) - except Exception: - return host.lower() + self.get(other, raise_error=True) + return True + except KeyError: + return False + + @property + def hash(self): + if self._hash is None: + # Create a new SHA-1 hash object + sha1_hash = sha1() + # Update the SHA-1 object with the hash values of each object + for host in [str(h).encode() for h in self.sorted_hosts]: + sha1_hash.update(host) + if self.strict_dns_scope: + sha1_hash.update(b"\x00") + self._hash = sha1_hash.digest() + return self._hash + + def __str__(self): + return ",".join([str(h) for h in self.hosts][:5]) + + def __iter__(self): + yield from self.hosts + + def __contains__(self, other): + # if "other" is a Target, iterate over its hosts and check if they are in self + if isinstance(other, self.__class__): + for h in other.hosts: + if not self._contains(h): + return False + return True + else: + return self._contains(other) + + def __bool__(self): + return bool(self._hosts) + + def __eq__(self, other): + return self.hash == other.hash + + def __len__(self): + """ + Calculates and returns the total number of hosts within this target, not counting duplicate hosts. + + Returns: + int: The total number of unique hosts present within the target's `_hosts`. + + Examples: + >>> target = SimpleTarget("evilcorp.com", "1.2.3.0/24") + >>> len(target) + 257 - def is_ip(self, host): - return ipaddress._IPAddressBase in host.__class__.__mro__ + Notes: + - If a host is represented as an IP network, all individual IP addresses in that network are counted. + - For other types of hosts, each unique host is counted as one. + """ + num_hosts = 0 + for host in self._hosts: + if is_ip(host): + num_hosts += host.num_addresses + else: + num_hosts += 1 + return num_hosts diff --git a/radixtarget/test/test_pyradix.py b/radixtarget/test/test_pyradix.py deleted file mode 100644 index a67ae7c..0000000 --- a/radixtarget/test/test_pyradix.py +++ /dev/null @@ -1,95 +0,0 @@ -import time -import pytest -import random -import logging -import ipaddress -from pathlib import Path - -log = logging.getLogger("radixtarget.test") - -cidr_list_path = Path(__file__).parent / "cidrs.txt" - -from radixtarget import RadixTarget - - -def test_radixtarget(): - rt = RadixTarget() - - for _ in range(2): - - # ipv4 - rt.insert("192.168.1.0/24") - assert rt.search("192.168.1.10") == ipaddress.ip_network("192.168.1.0/24") - assert rt.search("192.168.2.10") is None - rt.insert(ipaddress.ip_network("10.0.0.0/8")) - assert rt.search("10.255.255.255") == ipaddress.ip_network("10.0.0.0/8") - rt.insert(ipaddress.ip_network("172.16.12.1")) - assert rt.search("172.16.12.1") == ipaddress.ip_network("172.16.12.1/32") - rt.insert("8.8.8.0/24", "custom_data_8") - assert rt.search("8.8.8.8") == "custom_data_8" - - # ipv6 - rt.insert("dead::/64") - assert rt.search("dead::beef") == ipaddress.ip_network("dead::/64") - assert rt.search("dead:cafe::beef") == None - rt.insert("cafe::babe") - assert rt.search("cafe::babe") == ipaddress.ip_network("cafe::babe/128") - rt.insert("beef::/120", "custom_beef") - assert rt.search("beef::bb") == "custom_beef" - - # networks - rt.insert("192.168.128.0/24") - assert rt.search("192.168.128.0/28") == ipaddress.ip_network("192.168.128.0/24") - assert rt.search("192.168.128.0/23") == None - rt.insert("babe::/64") - assert rt.search("babe::/96") == ipaddress.ip_network("babe::/64") - assert rt.search("babe::/63") == None - - # ipv4 / ipv6 confusion - rand_int = random.randint(0, 2**32 - 1) - ipv4_address = ipaddress.IPv4Address(rand_int) - ipv6_address = ipaddress.IPv6Address(rand_int << (128 - 32)) - ipv6_network = ipaddress.IPv6Network(f"{ipv6_address}/32") - rt.insert(ipv4_address) - assert rt.search(ipv4_address) - assert not rt.search(ipv6_address) - assert not rt.search(ipv6_network) - - # dns - rt.insert("net") - rt.insert("www.example.com") - rt.insert("test.www.example.com") - assert rt.search("net") == "net" - assert rt.search("evilcorp.net") == "net" - assert rt.search("www.example.com") == "www.example.com" - assert rt.search("asdf.test.www.example.com") == "test.www.example.com" - assert rt.search("example.com") is None - rt.insert("evilcorp.co.uk", "custom_data") - assert rt.search("www.evilcorp.co.uk") == "custom_data" - - with pytest.raises(ValueError, match=".*must be of str or ipaddress type.*"): - rt.insert(b"asdf") - - with pytest.raises(ValueError, match=".*must be of str or ipaddress type.*"): - rt.search(b"asdf") - - assert "net" in rt.dns_tree.root.children - assert "com" in rt.dns_tree.root.children - - # speed benchmark - cidrs = open(cidr_list_path).read().splitlines() - log.critical(len(cidrs)) - for c in cidrs: - rt.insert(c) - - iterations = 10000 - - start = time.time() - for i in range(iterations): - random_ip = ipaddress.ip_address(random.randint(0, 2**32 - 1)) - rt.search(random_ip) - end = time.time() - elapsed = end - start - log.critical( - f"{iterations:,} iterations in {elapsed:.4f} seconds ({int(iterations/elapsed)}/s)" - ) diff --git a/radixtarget/test/test_radixtarget.py b/radixtarget/test/test_radixtarget.py new file mode 100644 index 0000000..aa9ea32 --- /dev/null +++ b/radixtarget/test/test_radixtarget.py @@ -0,0 +1,337 @@ +import time +import pytest +import random +import logging +import ipaddress +from pathlib import Path + +from radixtarget.tree.ip import IPRadixTree +from radixtarget.tree.dns import DNSRadixTree + +log = logging.getLogger("radixtarget.test") + +cidr_list_path = Path(__file__).parent / "cidrs.txt" + +from radixtarget import Target + + +def test_radixtarget(): + """ + Tests various functionalities of the Target library, including: + + - Initialization and comparison of Target objects with different IPs and domains. + - Checking membership of IPs, subnets, and domains within Target objects. + - Hashing and equality checks for Target objects. + - Sorting of network and domain events based on size. + - Insertion and search operations in IP and DNS radix trees. + - Handling of strict DNS scope and error-raising during searches. + - Performance benchmarking for IP search operations. + - Ensuring correct handling of ACL mode for subnets and domains. + - Verifying strict DNS scope doesn't interfere with ACL operations. + """ + + target1 = Target("api.publicapis.org", "8.8.8.8/30", "2001:4860:4860::8888/126") + target2 = Target("8.8.8.8/29", "publicapis.org", "2001:4860:4860::8888/125") + target3 = Target("8.8.8.8/29", "publicapis.org", "2001:4860:4860::8888/125") + target4 = Target("8.8.8.8/29") + target5 = Target() + assert not target5 + assert len(target1) == 9 + assert len(target4) == 8 + assert "8.8.8.9" in target1 + assert "8.8.8.12" not in target1 + assert "8.8.8.8/31" in target1 + assert "8.8.8.8/30" in target1 + assert "8.8.8.8/29" not in target1 + assert "2001:4860:4860::8889" in target1 + assert "2001:4860:4860::888c" not in target1 + assert "www.api.publicapis.org" in target1 + assert "api.publicapis.org" in target1 + assert "publicapis.org" not in target1 + assert target1 in target2 + assert target2 not in target1 + assert target3 in target2 + assert target2 == target3 + assert target4 != target1 + + assert ipaddress.ip_network("8.8.8.8/30") in target1 + assert ipaddress.ip_network("8.8.8.8/30") in target1.hosts + + assert not target5 + assert len(target1) == 9 + assert len(target4) == 8 + assert "8.8.8.9" in target1 + assert "8.8.8.12" not in target1 + assert "8.8.8.8/31" in target1 + assert "8.8.8.8/30" in target1 + assert "8.8.8.8/29" not in target1 + assert "2001:4860:4860::8889" in target1 + assert "2001:4860:4860::888c" not in target1 + assert "www.api.publicapis.org" in target1 + assert "api.publicapis.org" in target1 + assert "publicapis.org" not in target1 + + assert str(target1.get("8.8.8.9")) == "8.8.8.8/30" + assert target1.get("8.8.8.12") is None + assert str(target1.get("2001:4860:4860::8889")) == "2001:4860:4860::8888/126" + assert target1.get("2001:4860:4860::888c") is None + assert str(target1.get("www.api.publicapis.org")) == "api.publicapis.org" + assert target1.get("publicapis.org") is None + + target = Target("evilcorp.com") + assert not "com" in target + assert "evilcorp.com" in target + assert "www.evilcorp.com" in target + strict_target = Target("evilcorp.com", strict_dns_scope=True) + assert not "com" in strict_target + assert "evilcorp.com" in strict_target + assert not "www.evilcorp.com" in strict_target + + target = Target() + target.add("evilcorp.com") + assert not "com" in target + assert "evilcorp.com" in target + assert "www.evilcorp.com" in target + strict_target = Target(strict_dns_scope=True) + strict_target.add("evilcorp.com") + assert not "com" in strict_target + assert "evilcorp.com" in strict_target + assert not "www.evilcorp.com" in strict_target + + # test target hashing + + target1 = Target() + target1.add("evilcorp.com") + target1.add("1.2.3.4/24") + target1.add("evilcorp.net") + assert ( + target1.hash == b"\xf7N\x89-\x7f(\xb3\xbe\n\xb9\xc5\xc3\x96\xee;\xecJ\xeb\xa8u" + ) + + target2 = Target() + target2.add("evilcorp.org") + target2.add("evilcorp.com") + target2.add("1.2.3.4/24") + target2.add("evilcorp.net") + assert ( + target2.hash == b"\xbe\xcf\xf3\x06\xcb`\xc9\xd17\x14\x1c\r\xc18\x95{4\xcb9\x8a" + ) + + target3 = Target(*list(target1)) + assert ( + target3.hash == b"\xf7N\x89-\x7f(\xb3\xbe\n\xb9\xc5\xc3\x96\xee;\xecJ\xeb\xa8u" + ) + + target4 = Target(*list(target1), strict_dns_scope=True) + assert target4.hash == b"stC\xd6\xd7\xa7\xf8\xfc\\4\xbd\x81NT\x17\xc6Nn'B" + + # make sure it's a sha1 hash + assert isinstance(target1.hash, bytes) + assert len(target1.hash) == 20 + + # hashes shouldn't match yet + assert target1.hash != target2.hash + # add missing host + target1.add("evilcorp.org") + # now they should match + assert target1.hash == target2.hash + + # test target sorting + from radixtarget.helpers import host_size_key + + big_subnet = "1.2.3.4/24" + medium_subnet = "1.2.3.4/28" + small_subnet = "1.2.3.4/30" + ip_event = "1.2.3.4" + parent_domain = "evilcorp.com" + grandparent_domain = "www.evilcorp.com" + greatgrandparent_domain = "api.www.evilcorp.com" + target = Target() + assert host_size_key(big_subnet) == (-256, "1.2.3.0/24") + assert host_size_key(medium_subnet) == (-16, "1.2.3.0/28") + assert host_size_key(small_subnet) == (-4, "1.2.3.4/30") + assert host_size_key(ip_event) == (-1, "1.2.3.4/32") + assert host_size_key(parent_domain) == (12, "evilcorp.com") + assert host_size_key(grandparent_domain) == (16, "www.evilcorp.com") + assert host_size_key(greatgrandparent_domain) == (20, "api.www.evilcorp.com") + events = [ + big_subnet, + medium_subnet, + small_subnet, + ip_event, + parent_domain, + grandparent_domain, + greatgrandparent_domain, + ] + random.shuffle(events) + assert sorted(events, key=host_size_key) == [ + big_subnet, + medium_subnet, + small_subnet, + ip_event, + parent_domain, + grandparent_domain, + greatgrandparent_domain, + ] + + # merging targets + target1 = Target("1.2.3.4/24", "evilcorp.net") + target2 = Target("evilcorp.com", "evilcorp.net") + assert sorted([str(h) for h in target1]) == ["1.2.3.0/24", "evilcorp.net"] + assert sorted([str(h) for h in target2]) == ["evilcorp.com", "evilcorp.net"] + target1.add(target2) + assert sorted([str(h) for h in target1]) == [ + "1.2.3.0/24", + "evilcorp.com", + "evilcorp.net", + ] + + # copying + target3 = target1.copy() + assert target3 == target1 + assert target3 is not target1 + assert sorted([str(h) for h in target3]) == [ + "1.2.3.0/24", + "evilcorp.com", + "evilcorp.net", + ] + + rt = Target() + + for _ in range(2): + + # ipv4 + rt.insert("192.168.1.0/24") + assert rt.search("192.168.1.10") == ipaddress.ip_network("192.168.1.0/24") + assert rt.search("192.168.2.10") is None + rt.insert(ipaddress.ip_network("10.0.0.0/8")) + assert rt.search("10.255.255.255") == ipaddress.ip_network("10.0.0.0/8") + rt.insert(ipaddress.ip_network("172.16.12.1")) + assert rt.search("172.16.12.1") == ipaddress.ip_network("172.16.12.1/32") + rt.insert("8.8.8.0/24", "custom_data_8") + assert rt.search("8.8.8.8") == "custom_data_8" + + # ipv6 + rt.insert("dead::/64") + assert rt.search("dead::beef") == ipaddress.ip_network("dead::/64") + assert rt.search("dead:cafe::beef") == None + rt.insert("cafe::babe") + assert rt.search("cafe::babe") == ipaddress.ip_network("cafe::babe/128") + rt.insert("beef::/120", "custom_beef") + assert rt.search("beef::bb") == "custom_beef" + + # networks + rt.insert("192.168.128.0/24") + assert rt.search("192.168.128.0/28") == ipaddress.ip_network("192.168.128.0/24") + assert rt.search("192.168.128.0/23") == None + rt.insert("babe::/64") + assert rt.search("babe::/96") == ipaddress.ip_network("babe::/64") + assert rt.search("babe::/63") == None + + # ipv4 / ipv6 confusion + rand_int = random.randint(0, 2**32 - 1) + ipv4_address = ipaddress.IPv4Address(rand_int) + ipv6_address = ipaddress.IPv6Address(rand_int << (128 - 32)) + ipv6_network = ipaddress.IPv6Network(f"{ipv6_address}/32") + rt.insert(ipv4_address) + assert rt.search(ipv4_address) + assert not rt.search(ipv6_address) + assert not rt.search(ipv6_network) + + # dns + rt.insert("net") + rt.insert("www.example.com") + rt.insert("test.www.example.com") + assert rt.search("net") == "net" + assert rt.search("evilcorp.net") == "net" + assert rt.search("www.example.com") == "www.example.com" + assert rt.search("asdf.test.www.example.com") == "test.www.example.com" + assert rt.search("example.com") is None + rt.insert("evilcorp.co.uk", "custom_data") + assert rt.search("www.evilcorp.co.uk") == "custom_data" + + with pytest.raises(ValueError, match=".*must be of str or ipaddress type.*"): + rt.insert(b"asdf") + + with pytest.raises(ValueError, match=".*must be of str or ipaddress type.*"): + rt.search(b"asdf") + + assert "net" in rt.dns_tree.root.children + assert "com" in rt.dns_tree.root.children + + # Tests for strict_scope parameter + dns_rt_strict_scope = DNSRadixTree(strict_scope=True) + dns_rt_strict_scope.insert("example.com") + assert dns_rt_strict_scope.search("example.com") == "example.com" + assert dns_rt_strict_scope.search("com") is None + assert dns_rt_strict_scope.search("www.example.com") is None + assert dns_rt_strict_scope.search("nonexistent.com") is None + assert ( + dns_rt_strict_scope.search("test.www.example.com", raise_error=False) + is None + ) + with pytest.raises(KeyError): + dns_rt_strict_scope.search("test.www.example.com", raise_error=True) + + # Tests for raise_error parameter + dns_rt = DNSRadixTree() + dns_rt.insert("example.com") + assert dns_rt.search("example.com") == "example.com" + assert dns_rt.search("nonexistent.com") is None + assert dns_rt.search("nonexistent.com", raise_error=False) is None + with pytest.raises(KeyError): + dns_rt.search("nonexistent.com", raise_error=True) + ip_rt = IPRadixTree() + ip_rt.insert("192.168.0.0/16") + assert ip_rt.search("192.168.1.1") == ipaddress.ip_network("192.168.0.0/16") + assert ip_rt.search("10.0.0.1") is None + assert ip_rt.search("10.0.0.1", raise_error=False) is None + with pytest.raises(KeyError): + ip_rt.search("10.0.0.1", raise_error=True) + + # speed benchmark + cidrs = open(cidr_list_path).read().splitlines() + log.critical(len(cidrs)) + for c in cidrs: + rt.insert(c) + + iterations = 10000 + + start = time.time() + for i in range(iterations): + random_ip = ipaddress.ip_address(random.randint(0, 2**32 - 1)) + rt.search(random_ip) + end = time.time() + elapsed = end - start + log.critical( + f"{iterations:,} iterations in {elapsed:.4f} seconds ({int(iterations/elapsed)}/s)" + ) + + # make sure child subnets/IPs don't get added to whitelist/blacklist + target = Target("1.2.3.4/24", "1.2.3.4/28", acl_mode=True) + assert sorted([str(h) for h in target.hosts]) == ["1.2.3.0/24"] + target = Target("1.2.3.4/28", "1.2.3.4/24", acl_mode=True) + assert sorted([str(h) for h in target.hosts]) == ["1.2.3.0/24"] + target = Target("1.2.3.4/28", "1.2.3.4", acl_mode=True) + assert sorted([str(h) for h in target.hosts]) == ["1.2.3.0/28"] + target = Target("1.2.3.4", "1.2.3.4/28", acl_mode=True) + assert sorted([str(h) for h in target.hosts]) == ["1.2.3.0/28"] + + # same but for domains + target = Target("evilcorp.com", "www.evilcorp.com", acl_mode=True) + assert sorted([str(h) for h in target.hosts]) == ["evilcorp.com"] + target = Target("www.evilcorp.com", "evilcorp.com", acl_mode=True) + assert sorted([str(h) for h in target.hosts]) == ["evilcorp.com"] + + # make sure strict_scope doesn't mess us up + target = Target( + "evilcorp.co.uk", "www.evilcorp.co.uk", acl_mode=True, strict_dns_scope=True + ) + assert sorted([str(h) for h in target.hosts]) == [ + "evilcorp.co.uk", + "www.evilcorp.co.uk", + ] + assert "evilcorp.co.uk" in target + assert "www.evilcorp.co.uk" in target + assert not "api.evilcorp.co.uk" in target + assert not "api.www.evilcorp.co.uk" in target diff --git a/radixtarget/tree/base.py b/radixtarget/tree/base.py index 1e72fd1..ff0e6d1 100644 --- a/radixtarget/tree/base.py +++ b/radixtarget/tree/base.py @@ -1,10 +1,13 @@ +sentinel = object() + + class RadixTreeNode: __slots__ = ("children", "host", "data") def __init__(self): self.children = {} self.host = None - self.data = None + self.data = sentinel class BaseRadixTree: diff --git a/radixtarget/tree/dns.py b/radixtarget/tree/dns.py index be19baa..07e5e70 100644 --- a/radixtarget/tree/dns.py +++ b/radixtarget/tree/dns.py @@ -1,8 +1,23 @@ -from .base import BaseRadixTree, RadixTreeNode +from .base import BaseRadixTree, RadixTreeNode, sentinel class DNSRadixTree(BaseRadixTree): + """A radix tree for efficient DNS hostname lookups. + + This tree stores hostnames in reverse order (TLD to subdomain) for hierarchical matching. + """ + + def __init__(self, strict_scope=False): + super().__init__() + self.strict_scope = strict_scope + def insert(self, hostname, data=None): + """Add a hostname to the tree. + + Args: + hostname (str): The hostname to insert. + data: Optional data to associate with the hostname. Defaults to the hostname itself. + """ if data is None: data = hostname parts = hostname.split(".") @@ -14,16 +29,30 @@ def insert(self, hostname, data=None): node = node.children[part] node.data = data - def search(self, hostname): + def search(self, hostname, raise_error=False): + """Find the most specific matching entry for a given hostname. + + Args: + hostname (str): The hostname to search for. + + Returns: + The data associated with the most specific matching hostname, or None if no match. + """ parts = hostname.split(".") node = self.root - matched_data = None + matched_data = sentinel # Search through the tree in the order from TLD to subdomain - for part in reversed(parts): + for i, part in enumerate(reversed(parts)): if part in node.children: node = node.children[part] - if node.data: - matched_data = node.data + # if strict scope is not enabled, every part must match + if self.strict_scope and i + 1 < len(parts): + continue + matched_data = node.data else: break + if matched_data is sentinel: + if raise_error: + raise KeyError(f'Hostname "{hostname}" not found') + return None return matched_data diff --git a/radixtarget/tree/ip.py b/radixtarget/tree/ip.py index 3c86512..dc80edf 100644 --- a/radixtarget/tree/ip.py +++ b/radixtarget/tree/ip.py @@ -2,10 +2,19 @@ from .base import BaseRadixTree, RadixTreeNode +sentinel = object() + class IPRadixTree(BaseRadixTree): + """A radix tree for efficient IP network lookups.""" def insert(self, network, data=None): + """Add an IP network to the tree. + + Args: + network: IP network to insert (string or ipaddress.IPv4Network/IPv6Network). + data: Optional data to associate with the network. Defaults to the network itself. + """ network = ipaddress.ip_network(network, strict=False) if data is None: data = network @@ -19,13 +28,25 @@ def insert(self, network, data=None): node.host = network node.data = data - def search(self, query): + def search(self, query, raise_error=False): + """Find the most specific matching entry for a given IP address or network. + + Args: + query: IP address or network to search for (string or ipaddress object). + raise_error: If True, raise KeyError when no match is found. Defaults to False. + + Returns: + The data associated with the most specific matching network, or None if no match. + + Raises: + KeyError: If raise_error is True and no match is found. + """ query_network = ipaddress.ip_network(query, strict=False) ip_value = int(query_network.network_address) query_prefixlen = query_network.prefixlen node = self.root - matched_data = None + matched_data = sentinel for i in range(query_prefixlen): current_bit = (ip_value >> (query_network.max_prefixlen - 1 - i)) & 1 if current_bit in node.children: @@ -35,4 +56,9 @@ def search(self, query): matched_data = node.data else: break + + if matched_data is sentinel: + if raise_error: + raise KeyError(f'IP "{query}" not found') + return None return matched_data