diff --git a/src/__init__.py b/src/__init__.py index 90beb86..e69de29 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,25 +0,0 @@ -CONFIG_SCHEMA = { - "name": {"required": True, "type": "string"}, - "containers": {"required": True, "type": "list"}, - "users": {"required": True, "type": "list"}, - "identityFile": {"required": True, "type": "list"}, - "hosts": {"required": True, "type": "list"}, - "metrics": { - "required": True, - "type": "dict", - "schema": { - "percentage": { - "required": True, - "type": "dict", - "schema": { - "value": {"required": True, "type": "number", "min": 0, "max": 100}, - "trend": { - "type": "string", - "nullable": True, - "regex": "^(?i)(down|equal|up)$", - }, - }, - } - }, - }, -} diff --git a/src/ctf.py b/src/ctf.py new file mode 100644 index 0000000..e0a8ffa --- /dev/null +++ b/src/ctf.py @@ -0,0 +1,154 @@ + +import os +import sys +from typing import List +import yamale +import click +import pathlib + +from ipaddress import ip_network +from yamale import YamaleError +from yamale.validators import DefaultValidators + +sys.path.append(os.getcwd()) +from src.host import Host +from src.log_config import get_logger +from src.utils import Path + + +logger = get_logger("ctf_creator.ctf") + +class CTFCreator(): + def __init__(self, config: str, save_path: str) -> None: + self.config = self._get_config(config) + logger.info(f"Containers: {self.config.get('containers')}") + logger.info(f"Users: {self.config.get('users')}") + logger.info(f"Key: {self.config.get('key')}") + logger.info(f"Hosts: {self.config.get('hosts')}") + logger.info( + f"IP-Address Subnet-base: {self.config.get('subnet')} " + ) + + self.save_path = save_path + self.subnet = ip_network(self.config.get("subnet")) + + def _get_config(self, config: dict) -> dict: + try: + validators = DefaultValidators.copy() # This is a dictionary + validators[Path.tag] = Path + schema = yamale.make_schema(f'{pathlib.Path(__file__).parent.resolve()}/schema.yaml', validators=validators) + # Create a Data object + data = yamale.make_data(content=config) + # Validate data against the schema. Throws a ValueError if data is invalid. + yamale.validate(schema, data) + + logger.info("YAML file loaded successfully.") + + return data[0][0] + except YamaleError as e: + logger.error('Validation failed!\n') + for result in e.results: + logger.error("Error validating data '%s' with '%s'\n\t" % (result.data, result.schema)) + for error in result.errors: + logger.error('\t%s' % error) + exit(1) + + def _get_hosts(self) -> List: + hosts = [] + for host in self.config.get('hosts'): + host_object = Host(host=host, save_path=self.save_path) + hosts.append(host_object) + host_object.clean_up() + return hosts + + def _extract_ovpn_info(self, file_path): + """ + Extracts the host IP address, port number, and subnet from an OpenVPN configuration file. + + Args: + file_path (str): Path to the OpenVPN configuration file. + + Returns: + tuple: A tuple containing: + - host_ip_address (str): The IP address found after the 'remote' keyword. + - port_number (int): The port number found after the IP address on the 'remote' line. + """ + if not os.path.exists(file_path): + logger.error(f"File {file_path} does not exist.") + return None + + host_ip_address = None + port_number = None + + with open(file_path, "r") as file: + lines = file.readlines() + + for line in lines: + if line.startswith("remote "): + parts = line.split() + if len(parts) == 3: + host_ip_address = parts[1] + port_number = int(parts[2]) + + if host_ip_address and port_number: + return host_ip_address, port_number + else: + logger.error("Failed to extract all necessary information.") + return None + + def create_challenge(self): + logger.info("Set up hosts.") + self.hosts = self._get_hosts() + logger.info("Begin set up of challenge.") + + http_port = 40000 + openvpn_port = 50000 + challenge_counter = 1 + next_network = self.subnet + + # TODO Handle issue, when address for OpenVPN or HTTP Port is already in use. + for idx, user in enumerate(self.config.get("users")): + if os.path.exists(f"{self.save_path}/data/{user}"): + logger.info(f"OpenVPN data exists for the user: {user}") + logger.info(f"Data for the user: {user} will NOT be changed. Starting OVPN Docker container with existing data.") + ip, openvpn_port = self._extract_ovpn_info(f"{self.save_path}/data/{user}/client.ovpn") + + logger.info(f"Get host with IP {ip}") + host: Host = [d for d in self.hosts if str(d.ip) == ip][0] + + host.send_and_extract_tar(user=user) + host.start_openvpn(user, openvpn_port, http_port, next_network) + else: + logger.info( + f"For the user: {user}, an OpenVPN configuration file will be generated!" + ) + host: Host = self.hosts[idx % len(self.hosts)] + host.start_openvpn(user, openvpn_port + challenge_counter, http_port + challenge_counter, next_network) + + # TODO handle start containers + # host.start_containers(user, self.config.get("containers")) + + next_network = ip_network((int(next_network.network_address) + next_network.num_addresses), strict=False) + challenge_counter += 1 + + +@click.command() +@click.option( + "--config", + required=True, + help="The path to the .yaml configuration file for the CTF-Creator.", + type=click.File('r', encoding='utf8'), +) +@click.option( + "--save", + required=True, + help="The path where you want to save the user data for the CTF-Creator. E.g. /home/nick/ctf-creator", + type=click.Path(writable=True), +) +def main(config, save): + ctfcreator = CTFCreator(config=config.read(), save_path=save) + ctfcreator.create_challenge() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/docker_env.py b/src/docker_env.py new file mode 100644 index 0000000..85626be --- /dev/null +++ b/src/docker_env.py @@ -0,0 +1,397 @@ +from ipaddress import ip_address +import sys +import os +import time + +from subprocess import run, CalledProcessError +from docker import DockerClient, from_env +from docker.errors import NotFound, APIError, ImageNotFound +from docker.types import EndpointConfig, IPAMPool, IPAMConfig + +sys.path.append(os.getcwd()) +from src.log_config import get_logger + +logger = get_logger("ctf_creator.docker") + +class DownloadError(Exception): + """Custom exception for download errors.""" + + pass + +class RemoteLineNotFoundError(Exception): + """Custom exception raised when no 'remote' line is found in the OpenVPN configuration file.""" + + pass + +class Docker(): + def __init__(self, host: dict) -> None: + self.username = host.get("username") + self.ip = ip_address(host.get("ip")) + self.client = DockerClient(base_url=f"ssh://{self.username}@{self.ip}") + + def prune(self): + try: + self.client.containers.prune() + # Stop all containers and remove them. + for item in self.client.containers.list(ignore_removed=True): + try: + self.client.containers.prune() + item.stop() + item.remove(force=True) + except NotFound: + logger.error( + f"Container {item.id} not found. It might have been removed already. But it is ok." + ) + # Delete all docker networks + self.client.networks.prune() + except APIError as api_err: + raise RuntimeError( + f"An error occurred while fetching the Docker API version on host {self.ip}. " + f"This may indicate that Docker is not installed, not running, or not properly configured on the host. " + f"Please verify the following: " + f"1. Docker is installed and running on the host by using the command: 'docker info'. " + f"2. The Docker daemon is configured to accept remote connections if you are accessing it remotely. " + f"3. Network connectivity to the Docker daemon is not blocked by a firewall or network policy. " + f"4. You can run Docker commands on the remote host without needing `sudo` privileges, or if `sudo` is required, ensure proper permissions are set. " + f"5. Check Docker's listening ports with commands such as 'sudo netstat -lntp | grep dockerd'. " + f"6. Check the CTF-creators README.md for more instructions. " + f"Original error: {api_err.explanation}" + ) + except Exception as e: + raise RuntimeError( + f"An unexpected error occurred while connecting to Docker on host {self.ip}. " + f"This may indicate that Docker is not installed, not running, or not properly configured on the host. " + f"Please verify the following: " + f"1. Docker is installed and running on the host by using the command: 'docker info'. " + f"2. The Docker daemon is configured to accept remote connections if you are accessing it remotely. " + f"3. Network connectivity to the Docker daemon is not blocked by a firewall or network policy. " + f"4. You can run Docker commands on the remote host without needing `sudo` privileges, or if `sudo` is required, ensure proper permissions are set. " + f"5. Check Docker's listening ports with commands such as 'sudo netstat -lntp | grep dockerd'. " + f"6. Check the CTF-creators README.md for more instructions. " + f"Original error: {e}" + ) + + def create_container(self, network_name, name, image): + """ + Create a Docker container with a specific name, image, and static IP address. + + Args: + network_name (str): The name of the network to connect the container to. + name (str): The name of the container to create (must be unique). + image (str): The Docker image to use for the container. + + Returns: + docker.models.containers.Container: The created Docker container. + + Raises: + docker.errors.APIError: If an error occurs during the container creation process. + """ + try: + container = self.client.containers.run( + image, + detach=True, + name=name, + network=network_name, + networking_config={network_name: self.endpoint_config}, + ) + return container + except APIError as e: + logger.error(f"Error creating container: {e}") + raise + + def _curl_client_ovpn( + self, + user: str, + http_port: str, + save_path: str, + max_retries_counter=0, + max_retries=10, + ) -> None: + """ + Downloads a .conf version of the client OpenVPN configuration file from a specified URL with retry logic. + + Args: + user (str): Name of the user. + http_port (str): Name of the user. + save_path (str): Path to the directory where the file will be saved. + max_retries_counter (int, optional): Current retry attempt count. + max_retries (int, optional): Maximum number of retry attempts. + """ + save_directory = f"{save_path}/data/{user}" + url = f"http://{self.ip}:{http_port}/client.ovpn" + + try: + os.makedirs(save_directory, exist_ok=True) + command = f"curl -o {save_directory}/client.ovpn {url}" + + while max_retries_counter < max_retries: + try: + run(command, shell=True, check=True) + logger.info(f"File downloaded successfully to {save_directory}/client.ovpn") + return + except CalledProcessError: + max_retries_counter += 1 + time.sleep(3) + logger.info(f"Retrying... ({max_retries_counter}/{max_retries})") + except Exception as e: + logger.error(f"Unexpected error: {e}") + max_retries_counter += 1 + time.sleep(3) + logger.info(f"Retrying... ({max_retries_counter}/{max_retries})") + + logger.info(f"Download failed after {max_retries} retries.") + raise DownloadError("Max retries exceeded.") + + except Exception as e: + logger.error(f"Error: An unexpected error occurred - {e}") + raise DownloadError(f"An unexpected error occurred - {e}") + + def get_openvpn_config(self, user: str, http_port: int, save_path: str): + logger.info(f"Downloading OpenVPN configuration for {user}...") + container_name = f"{user}_openvpn" + + # Download the folder with data + try: + container = self.client.containers.get(container_name) + except NotFound: + logger.error(f"Error: Container {container_name} not found.") + exit(1) + except Exception as e: + logger.error(f"Error: Something is wrong with {container_name}. {e}") + exit(1) + + try: + logger.info("Executing command in container...") + _, _ = container.exec_run("./genclient.sh", detach=True) + # Delay to give time to run the command in the container + time.sleep(5) + self._curl_client_ovpn(user=user, http_port=str(http_port), save_path=save_path) + except Exception as e: + logger.error(f"Error: Unable to execute command in container. {e}") + exit(1) + + try: + container = self.client.containers.get(container_name) + local_save_path = f"{save_path}/data/{user}" + local_path_to_data = f"{save_path}/data/{user}/dockovpn_data.tar" + os.makedirs(local_save_path, exist_ok=True) + archive, stat = container.get_archive("/opt/Dockovpn_data") + # Save the archive to a local file + with open(local_path_to_data, "wb") as f: + for chunk in archive: + f.write(chunk) + logger.info(f"Container found: {container_name}") + logger.info("And the Dockovpn_data folder is saved on this system") + except NotFound: + logger.info(f"Error: Container {container_name} not found.") + exit(1) + except Exception as e: + logger.info(f"Error: Something is wrong with the saving of the ovpn_data!. {e}") + exit(1) + + def create_openvpn_server( + self, + host_address: str, + network_name: str, + user: str, + openvpn_port: int, + http_port: int, + mount_path: int, + ): + """ + Create an OpenVPN server container with specific configurations. + + Args: + client (docker.DockerClient): An instance of the Docker client. + network_name (str): The name of the network to connect the container to. + name (str): The base name of the OpenVPN server container. + + Returns: + docker.models.containers.Container: The created OpenVPN server container. + + Raises: + docker.errors.APIError: If an error occurs during the container creation process. + """ + self.endpoint_config = EndpointConfig( + version="1.44", ipv4_address=host_address + ) + try: + container = self.client.containers.run( + image="alekslitvinenk/openvpn", + detach=True, + name=f"{user}_openvpn", + network=network_name, + restart_policy={"Name": "always"}, + cap_add=["NET_ADMIN"], + ports={"1194/udp": str(openvpn_port), "8080/tcp": str(http_port)}, + environment={ + "HOST_ADDR": f"{str(self.ip)}", + }, + networking_config={network_name: self.endpoint_config}, + volumes=[f"{mount_path}:/opt/Dockovpn_data"] + ) + + return container + except APIError as e: + logger.error(f"Error creating container: {e}") + raise + + def create_network(self, name, subnet_, gateway_): + """ + Create a Docker network with specific IPAM configuration. + + Args: + name (str): The name of the network to create. + subnet_ (str): The subnet to use for the network (e.g., '192.168.1.0/24'). + gateway_ (str): The gateway to use for the network (e.g., '192.168.1.1'). + + Returns: + docker.models.networks.Network: The created Docker network. + + Raises: + docker.errors.APIError: If an error occurs during the network creation process. + """ + ipam_pool = IPAMPool(subnet=subnet_, gateway=gateway_) + ipam_config = IPAMConfig(pool_configs=[ipam_pool]) + + # Create the network with IPAM configuration + return self.client.networks.create( + name, driver="bridge", ipam=ipam_config, check_duplicate=True + ) + + def check_image_existence(image_name): + """Checks if a Docker image exists in a remote registry using the Docker SDK for Python. + Otherwise it checks if this image can be pulled. + + Args: + image_name (str): The name of the image to check. + + Returns: + bool: True if the image exists, False otherwise. + + Raises: + docker.errors.ImageNotFound: If the image is not found in the remote registry. + """ + + try: + local_client = from_env() + + # Attempt to inspect the image locally + local_client.images.get(f"{image_name}") + logger.info(f"Image {image_name} exists locally.") + return True + except ImageNotFound: + # If the image is not local, try to pull it + try: + logger.warning(f"Try to pull Image {image_name}. Could take some time.") + local_client.images.pull(f"{image_name}") + logger.warning(f"Image {image_name} pulled successfully.") + return True + except ImageNotFound: + raise ImageNotFound( + f"Error: Image {image_name} could not be pulled. Does this Docker Image exist?" + ) + + + # TODO + def modify_ovpn_file(file_path, new_port, new_route_ip): + """ + Modifies an OpenVPN configuration file to update the remote port and add specific route settings. + + Args: + file_path (str): Path to the OpenVPN configuration file. + new_port (int): New port number to replace in the 'remote' line. + new_route_ip (str): New IP address for the route setting. + """ + if not os.path.exists(file_path): + print(f"File {file_path} does not exist.") + return + + modified_lines = [] + + with open(file_path, "r") as file: + lines = file.readlines() + + remote_line_found = False + + for line in lines: + if line.startswith("remote "): + parts = line.split() + if len(parts) == 3: + parts[-1] = str(new_port) + line = " ".join(parts) + "\n" + remote_line_found = True + + modified_lines.append(line) + modified_lines.append("route-nopull\n") + modified_lines.append(f"route {new_route_ip}\n") + modified_lines.append('pull-filter ignore "redirect-gateway"\n') + modified_lines.append('pull-filter ignore "dhcp-option"\n') + modified_lines.append('pull-filter ignore "route"\n') + else: + modified_lines.append(line) + + if not remote_line_found: + raise RemoteLineNotFoundError("No 'remote' line found in the file.") + + with open(file_path, "w") as file: + file.writelines(modified_lines) + + def modify_ovpn_file_change_host(file_path, new_ip, new_port, username): + """ + Changes the IP address and port in the 'remote' line of an OpenVPN configuration file + only if the IP address or port is different from the current values. + + Args: + file_path (str): Path to the OpenVPN configuration file. + new_ip (str): New IP address to replace in the 'remote' line. + new_port (int): New port number to replace in the 'remote' line. + username (str): Username associated with the OpenVPN configuration file. + + Returns: + str: The username if the configuration was changed, otherwise None. + """ + if not os.path.exists(file_path): + print(f"File {file_path} does not exist.") + return None + + modified_lines = [] + + with open(file_path, "r") as file: + lines = file.readlines() + + remote_line_found = False + change_needed = False + + for line in lines: + if line.startswith("remote "): + parts = line.split() + if len(parts) == 3: + current_ip = parts[1] + current_port = parts[2] + + if current_ip != new_ip or current_port != str(new_port): + parts[1] = str(new_ip) + parts[2] = str(new_port) + line = " ".join(parts) + "\n" + change_needed = True + + remote_line_found = True + + modified_lines.append(line) + + if not remote_line_found: + raise RemoteLineNotFoundError("No 'remote' line found in the file.") + + if change_needed: + with open(file_path, "w") as file: + file.writelines(modified_lines) + logger.info( + f"IP address and port in the 'remote' line of {file_path} have been successfully modified." + ) + return username + else: + logger.info( + f"No change needed for {username}. The IP address and port are already correct." + ) + return None diff --git a/src/docker_functions.py b/src/docker_functions.py deleted file mode 100644 index 3871408..0000000 --- a/src/docker_functions.py +++ /dev/null @@ -1,285 +0,0 @@ -""" -This module provides functionalities for managing Docker containers and networks, -as well as setting up and configuring OpenVPN servers. - -The primary functionalities include: -1. Creating Docker containers with specific static IP addresses. -2. Setting up and configuring OpenVPN servers. -3. Generating OpenVPN configuration files for users. -4. Creating Docker networks with IPAM configurations. -5. Checking the existence of Docker images locally or in a remote registry. - -Functions: -- create_container(client, network_name, name, image, static_address): Creates a Docker container with a specified name, image, and static IP address. -- create_openvpn_server(client, network_name, name, static_address, counter, host_address): Creates an OpenVPN server container with specified configurations. -- create_openvpn_server_with_existing_data(client, network_name, name, static_address, port_number, host_address, remote_path_to_mount): Creates an OpenVPN server container using existing data with specified configurations. -- create_openvpn_config(client, user_name, counter, host_address, save_path, new_push_route): Generates an OpenVPN configuration for a specified user and saves it to a specified path. -- create_network(client, name, subnet_, gateway_): Creates a Docker network with a specified name, subnet, and gateway using IPAM configuration. -- check_image_existence(image_name): Checks if a Docker image exists locally or in a remote registry and attempts to pull it if not found locally. -""" - -import docker -import docker.errors -import docker.types -import os -import ovpn_helper_functions as ovpn_func -import time -import docker - - -def create_container(client, network_name, name, image, static_address): - """ - Create a Docker container with a specific name, image, and static IP address. - - Args: - client (docker.DockerClient): An instance of the Docker client. - network_name (str): The name of the network to connect the container to. - name (str): The name of the container to create (must be unique). - image (str): The Docker image to use for the container. - static_address (str): The static IPv4 address to assign to the container. - - Returns: - docker.models.containers.Container: The created Docker container. - - Raises: - docker.errors.APIError: If an error occurs during the container creation process. - """ - endpoint_config = docker.types.EndpointConfig( - version="1.44", ipv4_address=static_address - ) - try: - container = client.containers.run( - image, - detach=True, - name=name, - network=network_name, - networking_config={network_name: endpoint_config}, - ) - return container - except docker.errors.APIError as e: - print(f"Error creating container: {e}") - raise - - -def create_openvpn_server( - client: docker.DockerClient, - network_name, - name, - static_address, - counter, - host_address, -): - """ - Create an OpenVPN server container with specific configurations. - - Args: - client (docker.DockerClient): An instance of the Docker client. - network_name (str): The name of the network to connect the container to. - name (str): The base name of the OpenVPN server container. - static_address (str): The static IPv4 address to assign to the OpenVPN server. - counter (int): Counter for port assignment to ensure uniqueness. - host_address (str): The host address to be used in the OpenVPN configuration. - - Returns: - docker.models.containers.Container: The created OpenVPN server container. - - Raises: - docker.errors.APIError: If an error occurs during the container creation process. - """ - endpoint_config = docker.types.EndpointConfig( - version="1.44", ipv4_address=static_address - ) - try: - container = client.containers.run( - image="alekslitvinenk/openvpn", - detach=True, - name=f"{name}_openvpn", - network=network_name, - restart_policy={"Name": "always"}, - cap_add=["NET_ADMIN"], - ports={"1194/udp": (1194 + counter), "8080/tcp": (80 + counter)}, - environment={ - "HOST_ADDR": f"{host_address}", - }, - networking_config={network_name: endpoint_config}, - ) - return container - except docker.errors.APIError as e: - print(f"Error creating container: {e}") - raise - - -def create_openvpn_server_with_existing_data( - client, - network_name, - name, - static_address, - port_number, - host_address, - remote_path_to_mount, -): - """ - Create an OpenVPN server container with existing data mounted. - - Args: - client (docker.DockerClient): An instance of the Docker client. - network_name (str): The name of the network to connect the container to. - name (str): The base name of the OpenVPN server container. - static_address (str): The static IPv4 address to assign to the OpenVPN server. - counter (int): Counter for port assignment to ensure uniqueness. - host_address (str): The host address to be used in the OpenVPN configuration. - remote_path_to_mount (str): The path to the directory on the host to mount into the container. - - Returns: - docker.models.containers.Container: The created OpenVPN server container with existing data mounted. - - Raises: - docker.errors.APIError: If an error occurs during the container creation process. - """ - - endpoint_config = docker.types.EndpointConfig( - version="1.44", ipv4_address=static_address - ) - try: - container = client.containers.run( - image="alekslitvinenk/openvpn", - detach=True, - name=f"{name}_openvpn", - network=network_name, - restart_policy={"Name": "always"}, - cap_add=["NET_ADMIN"], - ports={"1194/udp": (port_number), "8080/tcp": (port_number)}, - environment={ - "HOST_ADDR": f"{host_address}", - }, - networking_config={network_name: endpoint_config}, - volumes=[f"{remote_path_to_mount}:/opt/Dockovpn_data"], - ) - return container - except docker.errors.APIError as e: - print(f"Error creating container: {e}") - raise - - -def create_openvpn_config( - client, user_name, counter, host_address, save_path, new_push_route -): - """ - Generate an OpenVPN configuration file for a specified user. - - Args: - client (docker.DockerClient): An instance of the Docker client. - user_name (str): The name of the user for whom to create the configuration. - counter (int): Counter value for constructing URLs or commands. - host_address (str): The address of the host for downloading files or executing commands. - save_path (str): The path to save the OpenVPN configuration files. - new_push_route (str): The new route to push to the OpenVPN client. - - Raises: - docker.errors.NotFound: If the OpenVPN container for the user is not found. - """ - container_name = f"{user_name}_openvpn" - print(f"Creating OpenVPN configuration for {user_name}...") - - # Download the folder with data - try: - container = client.containers.get(container_name) - except docker.errors.NotFound: - print(f"Error: Container {container_name} not found.") - exit(1) - except Exception as e: - print(f"Error: Something is wrong with {container_name}. {e}") - exit(1) - - try: - print("Executing command in container...") - exit_code, output = container.exec_run("./genclient.sh", detach=True) - # Delay to give time to run the command in the container - time.sleep(5) - ovpn_func.curl_client_ovpn_file_version( - host_address, user_name, counter, save_path - ) - - except Exception as e: - print(f"Error: Unable to execute command in container. {e}") - exit(1) - try: - container = client.containers.get(container_name) - local_save_path = f"{save_path}/data/{user_name}" - local_path_to_data = f"{save_path}/data/{user_name}/dockovpn_data.tar" - os.makedirs(local_save_path, exist_ok=True) - archive, stat = container.get_archive("/opt/Dockovpn_data") - # Save the archive to a local file - with open(local_path_to_data, "wb") as f: - for chunk in archive: - f.write(chunk) - print( - f"Container found: {container_name}", - "And the Dockovpn_data folder is saved on this system", - ) - except docker.errors.NotFound: - print(f"Error: Container {container_name} not found.") - exit(1) - except Exception as e: - print(f"Error: Something is wrong with the saving of the ovpn_data!. {e}") - exit(1) - - -def create_network(client, name, subnet_, gateway_): - """ - Create a Docker network with specific IPAM configuration. - - Args: - client (docker.DockerClient): An instance of the Docker client. - name (str): The name of the network to create. - subnet_ (str): The subnet to use for the network (e.g., '192.168.1.0/24'). - gateway_ (str): The gateway to use for the network (e.g., '192.168.1.1'). - - Returns: - docker.models.networks.Network: The created Docker network. - - Raises: - docker.errors.APIError: If an error occurs during the network creation process. - """ - ipam_pool = docker.types.IPAMPool(subnet=subnet_, gateway=gateway_) - ipam_config = docker.types.IPAMConfig(pool_configs=[ipam_pool]) - - # Create the network with IPAM configuration - return client.networks.create( - name, driver="bridge", ipam=ipam_config, check_duplicate=True - ) - - -def check_image_existence(image_name): - """Checks if a Docker image exists in a remote registry using the Docker SDK for Python. - Otherwise it checks if this image can be pulled. - - Args: - image_name (str): The name of the image to check. - - Returns: - bool: True if the image exists, False otherwise. - - Raises: - docker.errors.ImageNotFound: If the image is not found in the remote registry. - """ - - try: - local_client = docker.from_env() - # Split image name into name and version if a colon is present - - # Attempt to inspect the image locally - local_client.images.get(f"{image_name}") - print(f"Image {image_name} exists locally.") - return True - except docker.errors.ImageNotFound: - # If the image is not local, try to pull it - try: - print(f"Try to pull Image {image_name}. Could take some time.") - local_client.images.pull(f"{image_name}") - print(f"Image {image_name} pulled successfully.") - return True - except docker.errors.ImageNotFound: - raise docker.errors.ImageNotFound( - f"Error: Image {image_name} could not be pulled. Does this Docker Image exist?" - ) diff --git a/src/host.py b/src/host.py new file mode 100644 index 0000000..70089fb --- /dev/null +++ b/src/host.py @@ -0,0 +1,237 @@ +import sys +import os +from typing import List + +from ipaddress import IPv4Network, IPv6Network +from paramiko import SSHClient, AutoAddPolicy +from ipaddress import ip_address +from subprocess import run, PIPE, TimeoutExpired + +sys.path.append(os.getcwd()) +from src.log_config import get_logger +from src.docker_env import Docker + +logger = get_logger("ctf_creator.host") + +class Host(): + def __init__(self, host: dict, save_path: str) -> None: + self.username = host.get("username") + self.ip = ip_address(host.get("ip")) + logger.info(f"Check connection for {self.username}@{self.ip}") + + self.identify_path = host.get("identity_file") + if not os.path.isfile(self.identify_path): + raise FileNotFoundError(f"Identity file not found: {self.identify_path}.") + + self._check_reachability() + self._check_ssh() + self._add_ssh_identity() + + self.docker = Docker(host=host) + self.save_path = save_path + + def _check_reachability(self): + """ + Checks the reachability of a host using the ping command. + + Args: + host_ips (list of str): List of host IP addresses to check. + + Raises: + SystemExit: If any host is unreachable, prints the unreachable hosts and exits the program. + """ + try: + run( + ["ping", "-c", "1", str(self.ip)], + stdout=PIPE, + stderr=PIPE, + ) + except Exception as e: + logger.error(f"Error pinging host {self.ip}: {e}") + raise + + def _check_ssh(self): + """ + Checks SSH connectivity and credentials for a given host. Helper function. + + Args: + self.ip (str): Host information in the format 'user@host'. + + Returns: + bool: True if SSH connection is successful, False otherwise. + """ + try: + # Attempt to SSH into the host + result = run( + ["ssh", "-i", self.identify_path, f"{self.username}@{self.ip}", "exit"], capture_output=True, text=True, timeout=10 + ) + + if result.returncode != 0: + if "Permission denied" in result.stderr: + logger.error( + f"SSH connection to {self.ip} failed due to incorrect username or password." + ) + else: + logger.error(f"SSH connection to {self.ip} failed: {result.stderr}") + return False + return True + except TimeoutExpired: + logger.error(f"SSH connection to {self.ip} timed out.") + return False + except Exception as e: + logger.error(f"Error attempting SSH connection to {self.ip}: {e}") + return False + + def _execute_ssh_command(self, command): + """ + Executes a command on a remote host via SSH. + + Args: + command (str): The command to execute on the remote host. + + Returns: + tuple: A tuple containing the command's output and error messages. + """ + + # Create an SSH client + ssh = SSHClient() + + # Load SSH host keys + ssh.load_system_host_keys() + + # Add missing host keys + ssh.set_missing_host_key_policy(AutoAddPolicy()) + + try: + # Connect to the remote host using SSH agent for authentication + ssh.connect(str(self.ip), port=22, username=self.username) + # Execute the command + stdin, stdout, stderr = ssh.exec_command(command) + # Read the output and error streams + output = stdout.read().decode() + error = stderr.read().decode() + # Print the output and error (if any) + if output: + logger.info("Output:\n", output) + if error: + logger.error("Error:\n", error) + + return output, error + + except Exception as e: + logger.error(f"An error occurred: {e}") + return None, str(e) + finally: + # Close the SSH connection + ssh.close() + + def _add_ssh_identity(self): + commands = [ + f'eval "$(ssh-agent)" ', + f'ssh-add {self.identify_path}' + ] + + try: + # Run all commands in the list of commands + for command in commands: + result = run(command, shell=True, executable="/bin/bash") + if result.returncode != 0: + logger.error(f"Error executing command: {command}") + break + except Exception as e: + raise RuntimeError(f"An unexpected error occurred: {e}") + + def clean_up(self): + self.docker.prune() + self._execute_ssh_command( + f"sudo test -d /home/{self.username}/ctf-data/ && sudo rm -r /home/{self.username}/ctf-data/" + ) + logger.info("Clean up process on hosts finished!") + + def send_and_extract_tar(self, user: str) -> None: + """ + Sends a tar file to a remote host via SSH and extracts it. + + Raises: + PermissionError: If there is a permission issue on the remote host. + """ + tar_file_path = f"{self.save_path}/data/{user}/dockovpn_data.tar" + remote_path = f"/home/{self.username}/ctf-data/{user}/dock_vpn_data.tar" + + # Create an SSH client + ssh = SSHClient() + + # Load SSH host keys + ssh.load_system_host_keys() + + # Add missing host keys + ssh.set_missing_host_key_policy(AutoAddPolicy()) + + try: + # Connect to the remote host using SSH agent for authentication + logger.info(f"Connecting to {self.ip} as {self.username}...") + ssh.connect(str(self.ip), port=22, username=self.username) + + # Extract the remote directory path + remote_dir = os.path.dirname(f"/home/{self.username}/ctf-data/{user}/Dockovpn_data/") + + # Ensure the remote directory exists + logger.info(f"Ensuring the remote directory {remote_dir} exists...") + mkdir_command = f"mkdir -p {remote_dir}" + _, stdout, stderr = ssh.exec_command(mkdir_command) + mkdir_error = stderr.read().decode().strip() + if mkdir_error: + raise PermissionError( + f"Failed to create directory {remote_dir}: {mkdir_error}" + ) + stdout.channel.recv_exit_status() # Wait for the command to complete + + # Use SFTP to copy the tar file + logger.info(f"Copying {tar_file_path} to {self.ip}:{remote_path}...") + sftp = ssh.open_sftp() + sftp.put(tar_file_path, remote_path) + sftp.close() + + logger.info(f"File {tar_file_path} successfully sent to {self.ip}:{remote_path}") + + # Ensure correct permissions for the remote path and extract the tar file + logger.info(f"Extracting tar file {remote_path} on {self.ip}...") + extract_command = f"tar -xf {remote_path} -C {remote_dir}" + _, stdout, stderr = ssh.exec_command(extract_command) + stdout.channel.recv_exit_status() # Wait for the command to complete + + # Read the output and error streams + output = stdout.read().decode().strip() + error = stderr.read().decode().strip() + + # Print the output and error (if any) + if output: + logger.info("Output:\n", output) + if error: + logger.error("Error:\n", error) + raise PermissionError(f"Failed to extract tar file: {error}") + + except PermissionError as pe: + logger.error(f"A permission error occurred: {pe}") + finally: + # Close the SSH connection + logger.info("Closing the SSH connection...") + ssh.close() + + def start_openvpn(self, user: str, openvpn_port: int, http_port: int, subnet: IPv4Network|IPv6Network ): + self.docker.create_network(name=f"{self.username}_network", subnet_=str(subnet), gateway_=str(subnet.network_address + 1)) + + self.docker.create_openvpn_server( + host_address=str(subnet.network_address + 2), + network_name=f"{self.username}_network", + user=user, + openvpn_port=openvpn_port, + http_port=http_port, + mount_path=f"/home/{self.username}/ctf-data/{user}/Dockovpn_data/", + ) + + if not os.path.exists(f"{self.save_path}/data/{user}"): + self.docker.get_openvpn_config(user=user, http_port=http_port, save_path=self.save_path) + + def start_containers(self, user: str, containers: List) -> None: + pass \ No newline at end of file diff --git a/src/hosts_functions.py b/src/hosts_functions.py deleted file mode 100644 index c4d3c1e..0000000 --- a/src/hosts_functions.py +++ /dev/null @@ -1,294 +0,0 @@ -""" -This module provides functionalities for managing SSH connections and operations, -including host reachability checks via ping and SSH, executing commands over SSH, -and transferring and extracting tar files on remote hosts. - -Functions: -- check_host_reachability_with_ping(host_ips): Checks the reachability of hosts using ping. -- check_ssh_connection(host_info): Verifies SSH connectivity and credentials for a given host. -- check_host_reachability_with_SSH(host_infos): Checks SSH connectivity and credentials for a list of hosts. -- execute_ssh_command(user_host, command, remote_port=22): Executes a command on a remote host via SSH. -- send_and_extract_tar_via_ssh(tar_file_path, host_username, remote_host, remote_path, remote_port=22): Sends a tar file to a remote host and extracts it. -""" - -import subprocess -import sys -import time -import os -import paramiko - - -def check_host_reachability_with_ping(host_ips): - """ - Checks the reachability of a list of hosts using the ping command. - - Args: - host_ips (list of str): List of host IP addresses to check. - - Raises: - SystemExit: If any host is unreachable, prints the unreachable hosts and exits the program. - """ - unreachable_hosts = [] - - for host in host_ips: - try: - result = subprocess.run( - ["ping", "-c", "1", host], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - if result.returncode != 0: - unreachable_hosts.append(host) - except Exception as e: - print(f"Error pinging host {host}: {e}") - unreachable_hosts.append(host) - - if unreachable_hosts: - print("The following hosts are unreachable:") - for host in unreachable_hosts: - print(f"- {host}") - print( - "Please check if you are connected to the correct network to ensure connectivity with these hosts." - ) - sys.exit(1) - else: - print("All hosts are reachable with ping.") - - -def check_ssh_connection(host_info): - """ - Checks SSH connectivity and credentials for a given host. Helper function. - - Args: - host_info (str): Host information in the format 'user@host'. - - Returns: - bool: True if SSH connection is successful, False otherwise. - """ - try: - # Attempt to SSH into the host - result = subprocess.run( - ["ssh", host_info, "exit"], capture_output=True, text=True, timeout=10 - ) - - if result.returncode != 0: - if "Permission denied" in result.stderr: - print( - f"SSH connection to {host_info} failed due to incorrect username or password." - ) - else: - print(f"SSH connection to {host_info} failed: {result.stderr}") - return False - return True - except subprocess.TimeoutExpired: - print(f"SSH connection to {host_info} timed out.") - return False - except Exception as e: - print(f"Error attempting SSH connection to {host_info}: {e}") - return False - - -# Uses also the unsername so you can deduce if the host Ip is wrong or the username! -def check_host_reachability_with_SSH(host_infos): - """ - Checks SSH connectivity and credentials for a list of hosts. - - Args: - host_infos (list of str): List of host information in the format 'user@host'. - - Raises: - SystemExit: If any host has incorrect SSH credentials or is unreachable, prints the problematic hosts and exits the program. - """ - unreachable_hosts = [] - - for host_info in host_infos: - if not check_ssh_connection(host_info): - unreachable_hosts.append(host_info) - time.sleep(1) - - if unreachable_hosts: - print("The following hosts are unreachable or have incorrect SSH credentials:") - for host_info in unreachable_hosts: - print(f"- {host_info}") - print( - "Please check if you are connected to the correct network and using the correct SSH host-username." - ) - sys.exit(1) - else: - print("All SSH connections to hosts were successful.") - - -def execute_ssh_command(user_host, command, remote_port=22): - """ - Executes a command on a remote host via SSH. - - Args: - user_host (str): User and host information in the format 'user@host'. - command (str): The command to execute on the remote host. - remote_port (int, optional): The port to use for SSH (default is 22). - - Returns: - tuple: A tuple containing the command's output and error messages. - """ - # Parse the user and host from the user_host variable - username, remote_host = user_host.split("@") - - # Create an SSH client - ssh = paramiko.SSHClient() - - # Load SSH host keys - ssh.load_system_host_keys() - - # Add missing host keys - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - - try: - # Connect to the remote host using SSH agent for authentication - ssh.connect(remote_host, port=remote_port, username=username) - - # Execute the command - stdin, stdout, stderr = ssh.exec_command(command) - - # Read the output and error streams - output = stdout.read().decode() - error = stderr.read().decode() - - # Print the output and error (if any) - if output: - print("Output:\n", output) - if error: - print("Error:\n", error) - - return output, error - - except Exception as e: - print(f"An error occurred: {e}") - return None, str(e) - finally: - # Close the SSH connection - ssh.close() - - -def send_and_extract_tar_via_ssh( - tar_file_path, host_username, remote_host, remote_path, remote_port=22 -): - """ - Sends a tar file to a remote host via SSH and extracts it. - - Args: - tar_file_path (str): The path to the tar file on the local system. - host_username (str): The username to use for SSH. - remote_host (str): The remote host to connect to. - remote_path (str): The path on the remote host where the tar file will be placed and extracted. - remote_port (int, optional): The port to use for SSH (default is 22). - - Raises: - PermissionError: If there is a permission issue on the remote host. - """ - - # Create an SSH client - ssh = paramiko.SSHClient() - - # Load SSH host keys - ssh.load_system_host_keys() - - # Add missing host keys - ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) - - try: - # Connect to the remote host using SSH agent for authentication - print(f"Connecting to {remote_host} as {host_username}...") - ssh.connect(remote_host, port=remote_port, username=host_username) - - # Extract the remote directory path - remote_dir = os.path.dirname(remote_path) - - # Ensure the remote directory exists - print(f"Ensuring the remote directory {remote_dir} exists...") - mkdir_command = f"mkdir -p {remote_dir}" - stdin, stdout, stderr = ssh.exec_command(mkdir_command) - mkdir_error = stderr.read().decode().strip() - if mkdir_error: - raise PermissionError( - f"Failed to create directory {remote_dir}: {mkdir_error}" - ) - stdout.channel.recv_exit_status() # Wait for the command to complete - - # Use SFTP to copy the tar file - print(f"Copying {tar_file_path} to {remote_host}:{remote_path}...") - sftp = ssh.open_sftp() - sftp.put(tar_file_path, remote_path) - sftp.close() - - print(f"File {tar_file_path} successfully sent to {remote_host}:{remote_path}") - - # Ensure correct permissions for the remote path and extract the tar file - print(f"Extracting tar file {remote_path} on {remote_host}...") - extract_command = f"tar -xf {remote_path} -C {remote_dir}" - stdin, stdout, stderr = ssh.exec_command(extract_command) - stdout.channel.recv_exit_status() # Wait for the command to complete - - # Read the output and error streams - output = stdout.read().decode().strip() - error = stderr.read().decode().strip() - - # Print the output and error (if any) - if output: - print("Output:\n", output) - if error: - print("Error:\n", error) - raise PermissionError(f"Failed to extract tar file: {error}") - - except PermissionError as pe: - print(f"A permission error occurred: {pe}") - except Exception as e: - print(f"An error occurred: {e}") - finally: - # Close the SSH connection - print("Closing the SSH connection...") - ssh.close() - - -def extract_ovpn_info(file_path): - """ - Extracts the host IP address, port number, and subnet from an OpenVPN configuration file. - - Args: - file_path (str): Path to the OpenVPN configuration file. - - Returns: - tuple: A tuple containing: - - host_ip_address (str): The IP address found after the 'remote' keyword. - - port_number (int): The port number found after the IP address on the 'remote' line. - - subnet (str): The subnet extracted from the 'route' line (first three sections of the IP address). - """ - if not os.path.exists(file_path): - print(f"File {file_path} does not exist.") - return None - - host_ip_address = None - port_number = None - subnet = None - - with open(file_path, "r") as file: - lines = file.readlines() - - for line in lines: - if line.startswith("remote "): - parts = line.split() - if len(parts) == 3: - host_ip_address = parts[1] - port_number = int(parts[2]) - - if line.startswith("route "): - parts = line.split() - if len(parts) >= 2: - ip_parts = parts[1].split(".") - if len(ip_parts) >= 3: - subnet = f"{ip_parts[0]}.{ip_parts[1]}.{ip_parts[2]}" - - if host_ip_address and port_number and subnet: - return host_ip_address, port_number, subnet - else: - print("Failed to extract all necessary information.") - return None diff --git a/src/log_config.py b/src/log_config.py index 2889e94..906e5e0 100644 --- a/src/log_config.py +++ b/src/log_config.py @@ -29,6 +29,7 @@ log_colors=log_colors, ) + class CustomHandler(logging.StreamHandler): """ Handles the different styles of logging messages with respect to their level. @@ -49,6 +50,7 @@ def format(self, record) -> str: return simple_formatter.format(record) return detailed_formatter.format(record) + def get_logger(module_name: str = "base") -> logging.Logger: """ Creates or retrieves a logger for a specific module. @@ -74,4 +76,4 @@ def get_logger(module_name: str = "base") -> logging.Logger: logger.setLevel(logging.DEBUG if debug_enabled else logging.INFO) - return logger \ No newline at end of file + return logger diff --git a/src/ovpn_helper_functions.py b/src/ovpn_helper_functions.py deleted file mode 100644 index f3140b9..0000000 --- a/src/ovpn_helper_functions.py +++ /dev/null @@ -1,178 +0,0 @@ -""" -This module provides functionality for downloading and modifying OpenVPN configuration files. - -Functions: -- curl_client_ovpn_file_version(container, host_address, user_name, counter, save_path, max_retries_counter=0, max_retries=10): Downloads an OpenVPN configuration file from a remote server with retry logic. -- modify_ovpn_file(file_path, new_port, new_route_ip): Modifies an OpenVPN configuration file to update the remote port and add specific route settings. -- modify_ovpn_file_change_host(file_path, new_ip, new_port): Changes the IP address and port in the 'remote' line of an OpenVPN configuration file. -""" - -import subprocess -import os -import time -import logging - -class DownloadError(Exception): - """Custom exception for download errors.""" - - pass - - -def curl_client_ovpn_file_version( - host_address, - user_name, - counter, - save_path, - max_retries_counter=0, - max_retries=10, -): - """ - Downloads a .conf version of the client OpenVPN configuration file from a specified URL with retry logic. - - Args: - container: Docker container object (unused in this function but kept for consistency). - host_address (str): Address of the host serving the file. - user_name (str): Name of the user. - counter (int): Counter for port calculation. - save_path (str): Path to the directory where the file will be saved. - max_retries_counter (int, optional): Current retry attempt count. - max_retries (int, optional): Maximum number of retry attempts. - """ - save_directory = f"{save_path}/data/{user_name}" - url = f"http://{host_address}:{80 + counter}/client.ovpn" - - try: - os.makedirs(save_directory, exist_ok=True) - command = f"curl -o {save_directory}/client.ovpn {url}" - - while max_retries_counter < max_retries: - try: - subprocess.run(command, shell=True, check=True) - logging.info(f"File downloaded successfully to {save_directory}/client.ovpn") - return - except subprocess.CalledProcessError: - max_retries_counter += 1 - time.sleep(3) - logging.info(f"Retrying... ({max_retries_counter}/{max_retries})") - except Exception as e: - logging.error(f"Unexpected error: {e}") - max_retries_counter += 1 - time.sleep(3) - logging.info(f"Retrying... ({max_retries_counter}/{max_retries})") - - logging.info(f"Download failed after {max_retries} retries.") - raise DownloadError("Max retries exceeded.") - - except Exception as e: - logging.error(f"Error: An unexpected error occurred - {e}") - raise DownloadError(f"An unexpected error occurred - {e}") - - -class RemoteLineNotFoundError(Exception): - """Custom exception raised when no 'remote' line is found in the OpenVPN configuration file.""" - - pass - - -def modify_ovpn_file(file_path, new_port, new_route_ip): - """ - Modifies an OpenVPN configuration file to update the remote port and add specific route settings. - - Args: - file_path (str): Path to the OpenVPN configuration file. - new_port (int): New port number to replace in the 'remote' line. - new_route_ip (str): New IP address for the route setting. - """ - if not os.path.exists(file_path): - print(f"File {file_path} does not exist.") - return - - modified_lines = [] - - with open(file_path, "r") as file: - lines = file.readlines() - - remote_line_found = False - - for line in lines: - if line.startswith("remote "): - parts = line.split() - if len(parts) == 3: - parts[-1] = str(new_port) - line = " ".join(parts) + "\n" - remote_line_found = True - - modified_lines.append(line) - modified_lines.append("route-nopull\n") - modified_lines.append(f"route {new_route_ip}\n") - modified_lines.append('pull-filter ignore "redirect-gateway"\n') - modified_lines.append('pull-filter ignore "dhcp-option"\n') - modified_lines.append('pull-filter ignore "route"\n') - else: - modified_lines.append(line) - - if not remote_line_found: - raise RemoteLineNotFoundError("No 'remote' line found in the file.") - - with open(file_path, "w") as file: - file.writelines(modified_lines) - - -def modify_ovpn_file_change_host(file_path, new_ip, new_port, username): - """ - Changes the IP address and port in the 'remote' line of an OpenVPN configuration file - only if the IP address or port is different from the current values. - - Args: - file_path (str): Path to the OpenVPN configuration file. - new_ip (str): New IP address to replace in the 'remote' line. - new_port (int): New port number to replace in the 'remote' line. - username (str): Username associated with the OpenVPN configuration file. - - Returns: - str: The username if the configuration was changed, otherwise None. - """ - if not os.path.exists(file_path): - print(f"File {file_path} does not exist.") - return None - - modified_lines = [] - - with open(file_path, "r") as file: - lines = file.readlines() - - remote_line_found = False - change_needed = False - - for line in lines: - if line.startswith("remote "): - parts = line.split() - if len(parts) == 3: - current_ip = parts[1] - current_port = parts[2] - - if current_ip != new_ip or current_port != str(new_port): - parts[1] = str(new_ip) - parts[2] = str(new_port) - line = " ".join(parts) + "\n" - change_needed = True - - remote_line_found = True - - modified_lines.append(line) - - if not remote_line_found: - raise RemoteLineNotFoundError("No 'remote' line found in the file.") - - if change_needed: - with open(file_path, "w") as file: - file.writelines(modified_lines) - logging.info( - f"IP address and port in the 'remote' line of {file_path} have been successfully modified." - ) - return username - else: - logging.info( - f"No change needed for {username}. The IP address and port are already correct." - ) - return None diff --git a/src/pyyaml_functions.py b/src/pyyaml_functions.py deleted file mode 100644 index c02d625..0000000 --- a/src/pyyaml_functions.py +++ /dev/null @@ -1,219 +0,0 @@ -""" -This module provides core functionalities for reading and validating configuration data -from a YAML file and extracting relevant information for setting up Docker containers and -network configurations in the CTF (Capture The Flag) environment. - -Key Features: -1. Validation of YAML Configuration: Ensures that all required fields in the YAML configuration - are present, properly formatted, and contain valid values. -2. Docker Setup Support: Extracts and returns necessary data to set up Docker containers for users - based on the configuration. -3. Extracts host information and user names from formatted host strings. - -Functions: -- read_data_from_yaml(data): Validates and extracts configuration data from a provided dictionary. -- extract_hosts(hosts): Extracts the host IP addresses from each string in a list of hosts. -- find_host_username_by_ip(hosts, existing_host_ip): Finds and returns the username associated with a - given IP address from a list of host strings. - -""" - -import re -import docker_functions as doc_func - - -def read_data_from_yaml(data): - """ - Validates and extracts configuration data from a provided dictionary. - - This function ensures that all required fields are present, correctly formatted, - and contain valid values. It also converts specific subnet values to integers - and validates that the hosts are in the correct format. - - Args: - data (dict): Configuration data expected to contain the following keys: - - "containers": List of Docker container names. - - "users": List of user names. - - "identityFile": Path to the private SSH key used for host login. - - "hosts": List of host addresses where the Docker containers will run. - - "subnet_first_part": The first part of the subnet IP address as a list containing one string. - - "subnet_second_part": The second part of the subnet IP address as a list containing one string. - - "subnet_third_part": The third part of the subnet IP address as a list containing one string. - - Returns: - tuple: A tuple containing: - - containers (list): List of Docker containers to be started for each user. - - users (list): List of users. - - key (list): List containing the path to the private SSH key for host login. - - hosts (list): List of hosts where the Docker containers are running. - - subnet_first_part (int): First part of the subnet IP address. - - subnet_second_part (int): Second part of the subnet IP address. - - subnet_third_part (int): Third part of the subnet IP address. - - Raises: - ValueError: If any required fields are missing, not lists, contain invalid values, - or if host addresses do not match the 'username@ip_address' format. - """ - required_fields = [ - "containers", - "users", - "identityFile", - "hosts", - "subnet_first_part", - "subnet_second_part", - "subnet_third_part", - ] - - for field in required_fields: - if field not in data: - raise ValueError(f"Missing {field} field in YAML data") - - # Extract and return the relevant data - containers = data.get("containers", []) - users = data.get("users", []) - key = data.get("identityFile", []) - hosts = data.get("hosts", []) - subnet_first_part = data.get("subnet_first_part", []) - subnet_second_part = data.get("subnet_second_part", []) - subnet_third_part = data.get("subnet_third_part", []) - - # Ensure each entry is a list - for field_name, field_value in [ - ("containers", containers), - ("users", users), - ("identityFile", key), - ("hosts", hosts), - ("subnet_first_part", subnet_first_part), - ("subnet_second_part", subnet_second_part), - ("subnet_third_part", subnet_third_part), - ]: - if not isinstance(field_value, list): - raise ValueError(f"Expected '{field_name}' to be a list") - - # Ensure lists are not empty - if not containers: - raise ValueError("Expected 'containers' to be a non-empty list") - if not users: - raise ValueError("Expected 'users' to be a non-empty list") - if not hosts: - raise ValueError("Expected 'hosts' to be a non-empty list") - if not key: - raise ValueError("Expected 'identityFile' to be a non-empty list") - if not subnet_first_part: - raise ValueError("Expected 'subnet_first_part' to be a non-empty list") - if not subnet_second_part: - raise ValueError("Expected 'subnet_second_part' to be a non-empty list") - if not subnet_third_part: - raise ValueError("Expected 'subnet_third_part' to be a non-empty list") - - # Ensure subnet fields contain exactly one value and convert to integer - for subnet_field, subnet_value in [ - ("subnet_first_part", subnet_first_part), - ("subnet_second_part", subnet_second_part), - ("subnet_third_part", subnet_third_part), - ]: - if len(subnet_value) != 1: - raise ValueError(f"Expected '{subnet_field}' to contain exactly one value") - - try: - subnet_value[0] = int(subnet_value[0]) - except ValueError: - raise ValueError(f"Expected '{subnet_field}' to contain an integer value") - - # Ensure lists are not empty - if not containers: - raise ValueError("Expected 'containers' to be a non-empty list") - if not users: - raise ValueError("Expected 'users' to be a non-empty list") - if not hosts: - raise ValueError("Expected 'hosts' to be a non-empty list") - if not key: - raise ValueError("Expected 'identityFile' to be a non-empty list") - - # Ensure each host follows the 'username@ip_address' format - host_pattern = re.compile(r"^[\w._-]+@\d{1,3}(?:\.\d{1,3}){3}$") - for host in hosts: - if not host_pattern.match(host): - raise ValueError( - f"Expected 'hosts' entries to be in the format 'username@ip_address', but got '{host}'" - ) - - for container in containers: - doc_func.check_image_existence(container) - - # Convert singular values to lists if needed - containers = containers if isinstance(containers, list) else [containers] - users = users if isinstance(users, list) else [users] - hosts = hosts if isinstance(hosts, list) else [hosts] - - # Extract the singular values - subnet_first_part = subnet_first_part[0] - subnet_second_part = subnet_second_part[0] - subnet_third_part = subnet_third_part[0] - - return ( - containers, - users, - key, - hosts, - subnet_first_part, - subnet_second_part, - subnet_third_part, - ) - - -def extract_hosts(hosts): - """ - Extracts the host part from each string in a list of hosts. - - Args: - hosts (list): List of host strings, each containing an '@' symbol. - - Returns: - list: A list of extracted host parts. - - Raises: - ValueError: If a string does not contain exactly one '@' symbol or is empty. - """ - extracted_hosts = [] - for host in hosts: - if not host: - raise ValueError("Host string cannot be empty") - - parts = host.split("@") - if len(parts) != 2: - raise ValueError( - f"Host string must contain exactly one '@' symbol, but got: '{host}'" - ) - - extracted_hosts.append(parts[1]) - - return extracted_hosts - - -def find_host_username_by_ip(hosts, existing_host_ip): - """ - Finds the username associated with a given IP address from a list of host entries. - - Args: - hosts (list): A list of host entries in the format 'username@ipaddress'. - existing_host_ip (str): The IP address to find in the hosts list. - - Returns: - str: The username associated with the given IP address, or None if not found. - - """ - for host in hosts: - parts = host.split("@") - if len(parts) != 2: - raise ValueError( - f"Host string must contain exactly one '@' symbol, but got: '{host}'" - ) - username, ip_address = host.split("@") - if ip_address == existing_host_ip: - return username - - print( - f"Warning: The IP address {existing_host_ip} in the client.ovpn is not defined in the YAML configuration for the CTF-Creator." - ) - return None diff --git a/src/schema.yaml b/src/schema.yaml new file mode 100644 index 0000000..f19a01d --- /dev/null +++ b/src/schema.yaml @@ -0,0 +1,18 @@ +# ctf_main_schema.yaml +--- +name: str(required=True) # Name of the YAML config must be a string +containers: list(include('container'), required=True) # List of Docker containers +users: list(str(), required=True, unique=True) # List of users (should be unique) +hosts: list(include('host'), required=True) # List of hosts +subnet: ip(required=True) # subnet + +--- +host: + ip: str(required=True) + username: str(required=True) + identity_file: path(required=True) + +--- +container: + image: str(required=True) + enviroment: list(str(), required=False) \ No newline at end of file diff --git a/src/utils.py b/src/utils.py new file mode 100644 index 0000000..1fb8ba7 --- /dev/null +++ b/src/utils.py @@ -0,0 +1,10 @@ +import os + +from yamale.validators import DefaultValidators, Validator + +class Path(Validator): + """ Custom Path validator """ + tag = 'path' + + def _is_valid(self, value): + return os.path.isfile(value) \ No newline at end of file diff --git a/src/validation_functions.py b/src/validation_functions.py deleted file mode 100644 index 2fcf206..0000000 --- a/src/validation_functions.py +++ /dev/null @@ -1,63 +0,0 @@ -""" -This module provides functionality to validate file paths and ensure they meet specific criteria. - -The main functionality includes: - -1. Validating a directory path intended for saving user data. -2. Validating a configuration file path to ensure it is a valid YAML file. - -Functions: -- validate_save_path(ctx, param, value): Validates the provided save path, ensuring it is a valid directory path. If the path does not exist, it attempts to create it. Raises an error if the path is invalid or cannot be created. -- validate_yaml_file(ctx, param, value): Validates the provided configuration file path. Ensures the file exists, is a valid file (not a directory), has a .yaml or .yml extension, and contains valid YAML content. Raises an error if any of these conditions are not met. -""" - -import os -import click -import yaml - - -def validate_save_path(ctx, param, value): - """ - Validate the save_path provided by the user. - Ensure it's a valid path where a directory can be created. - """ - # Check if the path already exists - if os.path.exists(value): - # If it exists, ensure it's a directory - if not os.path.isdir(value): - raise click.BadParameter(f"The path '{value}' is not a directory.") - else: - # If it doesn't exist, check if the directory can be created - try: - os.makedirs(value, exist_ok=True) - except Exception as e: - raise click.BadParameter(f"The path '{value}' is invalid: {e}") - - return value - - -def validate_yaml_file(ctx, param, value): - """ - Validate the config file path provided by the user.s - Ensure it's a valid YAML file with the correct extension. - """ - # Check if the path exists - if not os.path.exists(value): - raise click.BadParameter(f"The file '{value}' does not exist.") - - # Check if it's a file (not a directory) - if not os.path.isfile(value): - raise click.BadParameter(f"The path '{value}' is not a file.") - - # Check if the file has a .yaml or .yml extension - if not (value.endswith(".yaml") or value.endswith(".yml")): - raise click.BadParameter("The file must have a .yaml or .yml extension.") - - # Try to load the file using a YAML parser to ensure it's a valid YAML - try: - with open(value, "r") as file: - yaml.safe_load(file) - except yaml.YAMLError as exc: - raise click.BadParameter(f"The file '{value}' is not a valid YAML file: {exc}") - - return value