From a03a0269fe8aa645ab2928f2e2e7fc115a8ead2a Mon Sep 17 00:00:00 2001 From: Stefan Machmeier Date: Wed, 8 May 2024 10:27:51 +0200 Subject: [PATCH] Include pre commit checks --- .github/workflows/build_test_linux.yml | 2 +- .github/workflows/build_test_macos.yml | 2 +- .github/workflows/build_test_windows.yml | 2 +- .gitignore | 2 +- .pre-commit-config.yaml | 12 +++ .readthedocs.yml | 2 +- Dockerfile | 2 +- README.md | 6 +- docker-compose.yml | 23 +++--- docs/requirements.txt | 2 +- docs/source/usage.rst | 1 - heidgaf/cli.py | 73 ++++++++++--------- heidgaf/detectors/arima_anomaly_detector.py | 3 +- heidgaf/detectors/exponential_thresholding.py | 3 +- heidgaf/detectors/real_time_anomaly.py | 3 +- heidgaf/detectors/thresholding_algorithm.py | 3 +- heidgaf/inspectors/__init__.py | 7 +- heidgaf/inspectors/domain_analyzer.py | 9 ++- heidgaf/inspectors/ip_analyzer.py | 4 +- heidgaf/main.py | 48 ++++++------ heidgaf/models/__init__.py | 23 +++--- requirements.txt | 2 +- tests.py | 1 - 23 files changed, 126 insertions(+), 109 deletions(-) create mode 100644 .pre-commit-config.yaml diff --git a/.github/workflows/build_test_linux.yml b/.github/workflows/build_test_linux.yml index cbd4f13..ccd7667 100644 --- a/.github/workflows/build_test_linux.yml +++ b/.github/workflows/build_test_linux.yml @@ -56,4 +56,4 @@ jobs: - name: Test if: startsWith(matrix.os, 'ubuntu') && !startsWith(matrix.python-version, '3.10') run: | - python -m pytest tests.py \ No newline at end of file + python -m pytest tests.py diff --git a/.github/workflows/build_test_macos.yml b/.github/workflows/build_test_macos.yml index 28d9d82..827d02d 100644 --- a/.github/workflows/build_test_macos.yml +++ b/.github/workflows/build_test_macos.yml @@ -57,4 +57,4 @@ jobs: # On other versions then 3.9, we test only. (without coverage generation) if: startsWith(matrix.os, 'macos') && !startsWith(matrix.python-version, '3.9') && !startsWith(github.ref, 'refs/tags/') run: | - python -m pytest tests.py \ No newline at end of file + python -m pytest tests.py diff --git a/.github/workflows/build_test_windows.yml b/.github/workflows/build_test_windows.yml index e488c5f..bcca3ee 100644 --- a/.github/workflows/build_test_windows.yml +++ b/.github/workflows/build_test_windows.yml @@ -50,4 +50,4 @@ jobs: # On other versions then 3.9, we test only. (without coverage generation) if: startsWith(matrix.os, 'windows') && !startsWith(matrix.python-version, '3.9') && !startsWith(github.ref, 'refs/tags/') run: | - python -m pytest tests.py \ No newline at end of file + python -m pytest tests.py diff --git a/.gitignore b/.gitignore index 0de7181..19a5c5f 100644 --- a/.gitignore +++ b/.gitignore @@ -318,4 +318,4 @@ dmypy.json # Cython debug symbols cython_debug/ -# End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks \ No newline at end of file +# End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..c174cac --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,12 @@ +repos: + - repo: https://github.com/psf/black + rev: 24.4.2 + hooks: + - id: black + language_version: python3.11 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace diff --git a/.readthedocs.yml b/.readthedocs.yml index 602ff7e..b3b2f23 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -11,4 +11,4 @@ python: - requirements: requirements.txt sphinx: - configuration: docs/source/conf.py \ No newline at end of file + configuration: docs/source/conf.py diff --git a/Dockerfile b/Dockerfile index 3f7fb9c..0dd3c44 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,4 +8,4 @@ COPY heidgaf/ heidgaf/ RUN pip --disable-pip-version-check install --no-cache-dir --no-compile . -CMD [ "heidgaf", "-h" ] \ No newline at end of file +CMD [ "heidgaf", "-h" ] diff --git a/README.md b/README.md index 4e6040c..d4eea6e 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ Train your own model: ```sh heidgaf train -m xg -d all -``` +``` ### Data @@ -115,10 +115,10 @@ Based on the following work, we implement heiDGAF to find malicious behaviour su Propose a hybrid DNS tunneling detection system using Tabu-PIO for feature selection. - Classifying Malicious Domains using DNS Traffic Analysis - + - [DeepDGA](https://github.com/roreagan/DeepDGA): Adversarially-Tuned Domain Generation and Detection - + DeepDGA detecting (and generating) domains on a per-domain basis which provides a simple and flexible means to detect known DGA families. It uses GANs to bypass detectors and shows the effectiveness of such solutions. - Kitsune: An Ensemble of Autoencoders for Online Network Intrusion Detection diff --git a/docker-compose.yml b/docker-compose.yml index 550edf0..19197ec 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,10 +1,9 @@ -version: "3.9" - volumes: redis: heidgaf: services: + # TODO Currently not supported. # redis: # image: redis:latest # ports: @@ -19,10 +18,16 @@ services: command: ["heidgaf", "inspect", "-r", "/tmp/data", "-m", "xg"] volumes: - ./data/heicloud:/tmp/data - # deploy: - # resources: - # reservations: - # devices: - # - driver: nvidia - # count: all - # capabilities: [gpu] + memswap_limit: 42G + deploy: + resources: + limits: + cpus: '6' + memory: 32g + reservations: + cpus: '4' + memory: 24g + # devices: + # - driver: nvidia + # count: all + # capabilities: [gpu] diff --git a/docs/requirements.txt b/docs/requirements.txt index 3863c0b..7a2e87a 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,4 +4,4 @@ sphinxcontrib.apidoc==0.5.0 sphinx_autodoc_typehints==2.0.0 nbsphinx==0.9.3 myst_parser==2.0.0 -sphinx_design==0.5.0 \ No newline at end of file +sphinx_design==0.5.0 diff --git a/docs/source/usage.rst b/docs/source/usage.rst index fa69fd2..c24663d 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -11,4 +11,3 @@ To use heiDGAF, first install it using pip: .. code-block:: console (.venv) $ pip install . - diff --git a/heidgaf/cli.py b/heidgaf/cli.py index ac33092..4d3e59c 100644 --- a/heidgaf/cli.py +++ b/heidgaf/cli.py @@ -18,49 +18,47 @@ def cli(): click.echo("Starting heiDGAF CLI") + @cli.command(name="train", context_settings={"show_default": True}) @click.option( - "-m", - "--model", - "model", - required=True, + "-m", + "--model", + "model", + required=True, type=click.Choice(Model), - help="Model for fitting." + help="Model for fitting.", ) @click.option( - "-d", - "--dataset", - "dataset", - required=True, + "-d", + "--dataset", + "dataset", + required=True, type=click.Choice(Dataset), default=Dataset.ALL, - help="Dataset for fitting." + help="Dataset for fitting.", ) @click.option( - "-o", - "--output_dir", - "output_dir", - required=True, + "-o", + "--output_dir", + "output_dir", + required=True, type=click.STRING, - help="Output path of model." + help="Output path of model.", ) def train(model, dataset, output_dir): click.echo("Start training of model.") - trainer = DNSAnalyzerTraining( - model=model, - dataset=dataset - ) + trainer = DNSAnalyzerTraining(model=model, dataset=dataset) trainer.train(output_path=output_dir) @cli.command(name="inspect", context_settings={"show_default": True}) @click.option( - "-r", - "--read", - "input_dir", - required=True, - type=click.Path(), - help="Input directory or file for analyzing." + "-r", + "--read", + "input_dir", + required=True, + type=click.Path(), + help="Input directory or file for analyzing.", ) @click.option( "-dt", @@ -71,12 +69,12 @@ def train(model, dataset, output_dir): help="Sets the anomaly detector.", ) @click.option( - "-m", - "--model", - "model", - required=True, + "-m", + "--model", + "model", + required=True, type=click.Choice(Model), - help="Model for prediction." + help="Model for prediction.", ) @click.option( "-s", @@ -143,7 +141,18 @@ def train(model, dataset, output_dir): help="Sets Redis max connection for caching results.", ) def inspection( - input_dir, detector, model, separator, filetype, lag, influence, n_standard_deviations, redis_host, redis_port, redis_db, redis_max_connection + input_dir, + detector, + model, + separator, + filetype, + lag, + influence, + n_standard_deviations, + redis_host, + redis_port, + redis_db, + redis_max_connection, ): click.echo("Starts processing log lines of DNS traffic.") pipeline = DNSInspectorPipeline( @@ -164,6 +173,4 @@ def inspection( if __name__ == "__main__": - """Default CLI entrypoint for Click interface - """ cli() diff --git a/heidgaf/detectors/arima_anomaly_detector.py b/heidgaf/detectors/arima_anomaly_detector.py index cefbd18..d0d9af4 100644 --- a/heidgaf/detectors/arima_anomaly_detector.py +++ b/heidgaf/detectors/arima_anomaly_detector.py @@ -5,8 +5,7 @@ import numpy as np from statsmodels.tsa.arima.model import ARIMA -from heidgaf.detectors.base_anomaly import (AnomalyDetector, - AnomalyDetectorConfig) +from heidgaf.detectors.base_anomaly import AnomalyDetector, AnomalyDetectorConfig class ARIMAAnomalyDetector(AnomalyDetector): diff --git a/heidgaf/detectors/exponential_thresholding.py b/heidgaf/detectors/exponential_thresholding.py index 56eacce..a16a784 100644 --- a/heidgaf/detectors/exponential_thresholding.py +++ b/heidgaf/detectors/exponential_thresholding.py @@ -2,8 +2,7 @@ import numpy as np -from heidgaf.detectors.base_anomaly import (AnomalyDetector, - AnomalyDetectorConfig) +from heidgaf.detectors.base_anomaly import AnomalyDetector, AnomalyDetectorConfig class EMAAnomalyDetector(AnomalyDetector): diff --git a/heidgaf/detectors/real_time_anomaly.py b/heidgaf/detectors/real_time_anomaly.py index 301f6e6..de4b652 100644 --- a/heidgaf/detectors/real_time_anomaly.py +++ b/heidgaf/detectors/real_time_anomaly.py @@ -2,8 +2,7 @@ import numpy as np -from heidgaf.detectors.base_anomaly import (AnomalyDetector, - AnomalyDetectorConfig) +from heidgaf.detectors.base_anomaly import AnomalyDetector, AnomalyDetectorConfig class RealTimeAnomalyDetector(AnomalyDetector): diff --git a/heidgaf/detectors/thresholding_algorithm.py b/heidgaf/detectors/thresholding_algorithm.py index 17dc090..bd6c16f 100644 --- a/heidgaf/detectors/thresholding_algorithm.py +++ b/heidgaf/detectors/thresholding_algorithm.py @@ -2,8 +2,7 @@ import numpy as np -from heidgaf.detectors.base_anomaly import (AnomalyDetector, - AnomalyDetectorConfig) +from heidgaf.detectors.base_anomaly import AnomalyDetector, AnomalyDetectorConfig class ThresholdingAnomalyDetector(AnomalyDetector): diff --git a/heidgaf/inspectors/__init__.py b/heidgaf/inspectors/__init__.py index afe0ede..a98111a 100644 --- a/heidgaf/inspectors/__init__.py +++ b/heidgaf/inspectors/__init__.py @@ -60,7 +60,7 @@ def __init__(self, config: InspectorConfig) -> None: "thirdleveldomain", "secondleveldomain", "fqdn", - "tld" + "tld", ] ), mean_imputer=Imputer(features_to_impute=[], strategy="mean"), @@ -99,7 +99,7 @@ def warnings(self, data: pl.DataFrame, suspicious: List, id: str) -> pl.DataFram .alias("distro") ) fqdn_distro = fqdn_distro.filter(pl.col("distro") > 0.05) - + # Initialize empty array total_warnings = [data.clear()] @@ -126,9 +126,8 @@ def warnings(self, data: pl.DataFrame, suspicious: List, id: str) -> pl.DataFram with pl.Config(tbl_rows=100): logging.debug(suspicious_data.select(["fqdn"]).unique()) total_warnings.append(suspicious_data) - + return pl.concat(total_warnings) - def update_count( self, diff --git a/heidgaf/inspectors/domain_analyzer.py b/heidgaf/inspectors/domain_analyzer.py index abac483..c831766 100644 --- a/heidgaf/inspectors/domain_analyzer.py +++ b/heidgaf/inspectors/domain_analyzer.py @@ -11,6 +11,7 @@ class DomainInspector(Inspector): Args: Tester (Tester): Configuration. """ + KEY_SECOND_LEVEL_DOMAIN = "secondleveldomain_frequency" KEY_THIRD_LEVEL_DOMAIN = "thirdleveldomain_frequency" KEY_FQDN = "fqdn_frequency" @@ -37,14 +38,14 @@ def run(self, data: pl.DataFrame) -> pl.DataFrame: """ min_date = data.select(["timestamp"]).min().item() max_date = data.select(["timestamp"]).max().item() - + # Filter data with no errors df = data.filter(pl.col("query") != "|").filter( pl.col("query").str.split(".").list.len() != 1 ) - + findings = [] - + # Check anomalies in FQDN logging.info("Analyze FQDN request anomalies") warnings = self.update_count(df, min_date, max_date, "fqdn", self.KEY_FQDN) @@ -63,5 +64,5 @@ def run(self, data: pl.DataFrame) -> pl.DataFrame: df, min_date, max_date, "thirdleveldomain", self.KEY_THIRD_LEVEL_DOMAIN ) findings.append(self.warnings(data, warnings, "thirdleveldomain")) - + return pl.concat(findings) diff --git a/heidgaf/inspectors/ip_analyzer.py b/heidgaf/inspectors/ip_analyzer.py index 87b858f..89cd8e5 100644 --- a/heidgaf/inspectors/ip_analyzer.py +++ b/heidgaf/inspectors/ip_analyzer.py @@ -39,7 +39,7 @@ def run(self, data: pl.DataFrame) -> pl.DataFrame: .filter(pl.col("return_code") != ReturnCode.NOERROR.value) .filter(pl.col("query").str.split(".").list.len() != 1) ) - + findings = [] # Update frequencies based on errors @@ -54,5 +54,5 @@ def run(self, data: pl.DataFrame) -> pl.DataFrame: df, min_date, max_date, "dns_server", self.KEY_DNS_SERVER ) findings.append(self.warnings(data, warnings, "dns_server")) - + return pl.concat(findings) diff --git a/heidgaf/main.py b/heidgaf/main.py index d16ac8c..dbedb4a 100644 --- a/heidgaf/main.py +++ b/heidgaf/main.py @@ -13,8 +13,7 @@ from heidgaf.detectors.arima_anomaly_detector import ARIMAAnomalyDetector from heidgaf.detectors.base_anomaly import AnomalyDetectorConfig from heidgaf.detectors.exponential_thresholding import EMAAnomalyDetector -from heidgaf.detectors.thresholding_algorithm import \ - ThresholdingAnomalyDetector +from heidgaf.detectors.thresholding_algorithm import ThresholdingAnomalyDetector from heidgaf.inspectors import Inspector, InspectorConfig from heidgaf.inspectors.domain_analyzer import DomainInspector from heidgaf.inspectors.ip_analyzer import IPInspector @@ -38,7 +37,7 @@ class FileType(str, Enum): class Separator(str, Enum): SPACE = " " COMMA = "," - + class InspectorFactory: def __init__(self, config) -> None: @@ -47,7 +46,7 @@ def __init__(self, config) -> None: "IP": (IPInspector(config)), "Domain": (DomainInspector(config)), } - + def __getitem__(self, key: str) -> Inspector: if key in self.factory: return self.factory[key] @@ -56,10 +55,11 @@ def __getitem__(self, key: str) -> Inspector: f"source {key} is not supported. Please pass a valid source." ) + class DNSInspectorPipeline: """Main analyzer pipeline. It loads new data and processes it through our analyzers. If an anomaly occurs, our models run""" - - MODELS_URL="https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021" + + MODELS_URL = "https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021" def __init__( self, @@ -76,7 +76,7 @@ def __init__( redis_db=0, redis_max_connections=20, threshold=5, - model=Model.RANDOM_FOREST_CLASSIFIER + model=Model.RANDOM_FOREST_CLASSIFIER, ) -> None: try: self.df_cache = DataFrameRedisCache( @@ -85,7 +85,6 @@ def __init__( except redis.exceptions.ConnectionError: logging.warning("No connection to Redis host") self.df_cache = None - if os.path.isfile(path): logging.debug(f"Processing files: {path}") @@ -103,7 +102,6 @@ def __init__( self.threshold = threshold self.order = order self.model = self.__get_model(model) - def load_data(self, path: str, separator: str) -> pl.DataFrame: """Loads data from csv files @@ -139,7 +137,7 @@ def load_data(self, path: str, separator: str) -> pl.DataFrame: (pl.col("query").str.split(".").alias("labels")), ] ) - + x = x.filter(pl.col("query").str.len_chars() > 0) x = x.filter(pl.col("labels").list.len() > 1) @@ -211,18 +209,18 @@ def run(self): raise NotImplementedError(f"Detector not implemented!") # Run inspectors to find anomalies in data - config = InspectorConfig( - detector, self.df_cache, self.threshold, self.model - ) + config = InspectorConfig(detector, self.df_cache, self.threshold, self.model) factory = InspectorFactory(config) errors = [] for inspector in ["IP", "Domain"]: errors.append(factory[inspector].run(self.data)) - - errors_pl: pl.DataFrame = pl.concat(errors) - - group_errors_pl = errors_pl.group_by(["client_ip", "fqdn"]).count().sort("client_ip") - with pl.Config(tbl_rows=100): + + errors_pl: pl.DataFrame = pl.concat(errors) + + group_errors_pl = ( + errors_pl.group_by(["client_ip", "fqdn"]).count().sort("client_ip") + ) + with pl.Config(fmt_str_lengths=1000): logging.warning(group_errors_pl) def __get_model(self, model_type: Model): @@ -234,11 +232,13 @@ def __get_model(self, model_type: Model): Returns: model: Model to predict data. """ - response = requests.get(f"{self.MODELS_URL}/files/?p=%2F{model_type.value}.pkl&dl=1") - + response = requests.get( + f"{self.MODELS_URL}/files/?p=%2F{model_type.value}.pkl&dl=1" + ) + response.raise_for_status() - - with open(rf'/tmp/{model_type.value}.pkl', 'wb') as f: + + with open(rf"/tmp/{model_type.value}.pkl", "wb") as f: f.write(response.content) - - return joblib.load(f'/tmp/{model_type.value}.pkl') \ No newline at end of file + + return joblib.load(f"/tmp/{model_type.value}.pkl") diff --git a/heidgaf/models/__init__.py b/heidgaf/models/__init__.py index f98fbd0..a8e1c79 100644 --- a/heidgaf/models/__init__.py +++ b/heidgaf/models/__init__.py @@ -36,7 +36,7 @@ def __init__( self.mean_imputer = mean_imputer self.target_encoder = target_encoder self.clf = clf - + # setting device on GPU if available, else CPU self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") logging.info(f"Using device: {self.device}") @@ -74,7 +74,7 @@ def fit(self, x_train: pl.DataFrame, y_train: pl.DataFrame): verbose=2, n_jobs=-1, ) - + start = time() model = clf.fit(x_train.to_numpy(), y_train.to_numpy().ravel()) logging.info( @@ -125,13 +125,12 @@ def predict(self, x): "search": { "eta": list(np.linspace(0.1, 0.6, 6)), "gamma": [int(x) for x in np.linspace(0, 10, 10)], - 'learning_rate': [0.03, 0.01, 0.003, 0.001], - 'min_child_weight': [1,3, 5,7, 10], - 'subsample': [0.6, 0.8, 1.0, 1.2, 1.4], - 'colsample_bytree': [0.6, 0.8, 1.0, 1.2, 1.4], - 'max_depth': [3, 4, 5, 6, 7, 8, 9 ,10, 12, 14], - 'reg_lambda':np.array([0.4, 0.6, 0.8, 1, 1.2, 1.4]) - + "learning_rate": [0.03, 0.01, 0.003, 0.001], + "min_child_weight": [1, 3, 5, 7, 10], + "subsample": [0.6, 0.8, 1.0, 1.2, 1.4], + "colsample_bytree": [0.6, 0.8, 1.0, 1.2, 1.4], + "max_depth": [3, 4, 5, 6, 7, 8, 9, 10, 12, 14], + "reg_lambda": np.array([0.4, 0.6, 0.8, 1, 1.2, 1.4]), }, } xgboost_rf_model = { @@ -145,9 +144,9 @@ def predict(self, x): random_forest_model = { "model": RandomForestClassifier(), "search": { - "n_estimators": [int(x) for x in np.linspace(start = 200, stop = 1000, num = 10)], - "max_features": [42], - "max_depth": [int(x) for x in np.linspace(10, 110, num = 11)], #.append(None), + "n_estimators": [int(x) for x in np.linspace(start=200, stop=1000, num=10)], + "max_features": [42], + "max_depth": [int(x) for x in np.linspace(10, 110, num=11)], # .append(None), "min_samples_split": [2, 5, 10], "min_samples_leaf": [1, 2, 4], "bootstrap": [True, False], diff --git a/requirements.txt b/requirements.txt index 30f7036..c266fd0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,4 +13,4 @@ pytest feature-engineering-polars # cupy-cuda12x statsmodels -xgboost \ No newline at end of file +xgboost diff --git a/tests.py b/tests.py index 787f76f..75adf68 100644 --- a/tests.py +++ b/tests.py @@ -4,4 +4,3 @@ # currently default test to make pipeline happy def test_source_parameter(): assert True - \ No newline at end of file