Skip to content

Commit

Permalink
better validation
Browse files Browse the repository at this point in the history
  • Loading branch information
TheTechromancer committed Nov 1, 2024
1 parent 0558041 commit b83071c
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 3 deletions.
21 changes: 21 additions & 0 deletions radixtarget/helpers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import ipaddress


Expand Down Expand Up @@ -28,6 +29,26 @@ def is_ip(host):
return ipaddress._IPAddressBase in host.__class__.__mro__


dns_name_regex = re.compile(r"^[\w]+[\w.-]*$", re.I)


def is_dns_name(host):
"""
Check if the given host is a valid DNS name.
This function uses a regular expression to determine if the provided
host string matches the pattern of a valid DNS name. The pattern allows
alphanumeric characters, underscores, hyphens, and periods, and is case-insensitive.
Args:
host (str): The host string to check.
Returns:
bool: True if the host is a valid DNS name, False otherwise.
"""
return bool(dns_name_regex.match(host))


def make_ip(host):
"""Convert a host to an IP network or return it as a lowercase string.
Expand Down
10 changes: 7 additions & 3 deletions radixtarget/radixtarget.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from radixtarget.tree.ip import IPRadixTree
from radixtarget.tree.dns import DNSRadixTree
from radixtarget.helpers import is_ip, make_ip, host_size_key
from radixtarget.helpers import is_ip, is_dns_name, make_ip, host_size_key


sentinel = object()
Expand Down Expand Up @@ -66,8 +66,10 @@ def get(self, host, raise_error=False):
host = make_ip(host)
if is_ip(host):
return self.ip_tree.search(host, raise_error=raise_error)
else:
elif is_dns_name(host):
return self.dns_tree.search(host, raise_error=raise_error)
else:
raise ValueError(f"Invalid host: '{host}'")

def search(self, host, raise_error=False):
"""
Expand Down Expand Up @@ -128,8 +130,10 @@ def _add_host(self, host, data=None):
self._hosts.add(host)
if is_ip(host):
return self.ip_tree.insert(host, data=data)
else:
elif is_dns_name(host):
return self.dns_tree.insert(host, data=data)
else:
raise ValueError(f"Invalid host: '{host}'")

@property
def hosts(self):
Expand Down
11 changes: 11 additions & 0 deletions radixtarget/test/test_radixtarget.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,3 +349,14 @@ def test_radixtarget():
assert "www.evilcorp.co.uk" in target
assert not "api.evilcorp.co.uk" in target
assert not "api.www.evilcorp.co.uk" in target

# test with invalid inputs
with pytest.raises(ValueError, match=".*Invalid host: 'http://example.com'.*"):
target = RadixTarget("http://example.com")
target = RadixTarget("example.com")
with pytest.raises(ValueError, match=".*Invalid host: 'evilcorp.com:80'.*"):
target.get("evilcorp.com:80")
with pytest.raises(ValueError, match=".*Invalid host: 'www.evilcorp.com:80'.*"):
target.add("www.evilcorp.com:80")
with pytest.raises(ValueError, match=".*Invalid host: 'evilcorp.com:80'.*"):
"evilcorp.com:80" in target

0 comments on commit b83071c

Please sign in to comment.