From 363c2586bfda592a7ebb6b559e1772cc7aa6f43e Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Thu, 21 Mar 2024 09:08:40 +0100 Subject: [PATCH 1/6] print results table at the end of an evaluation session --- src/eva/core/trainers/_recorder.py | 65 ++++++++++++++++++++++++++--- src/eva/core/trainers/functional.py | 2 +- 2 files changed, 61 insertions(+), 6 deletions(-) diff --git a/src/eva/core/trainers/_recorder.py b/src/eva/core/trainers/_recorder.py index 64950729..783e05f8 100644 --- a/src/eva/core/trainers/_recorder.py +++ b/src/eva/core/trainers/_recorder.py @@ -5,18 +5,41 @@ import os import statistics import sys -from typing import Any, Dict, List, Mapping +from typing import Any, Dict, List, Mapping, TypedDict from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT from lightning_fabric.utilities import cloud_io from loguru import logger from omegaconf import OmegaConf +from rich import console as rich_console +from rich import table as rich_table from toolz import dicttoolz SESSION_METRICS = Mapping[str, List[float]] """Session metrics type-hint.""" +class SESSION_STATISTICS(TypedDict): + """Type-hint for aggregated metrics of multiple runs with mean & stdev.""" + + mean: float + stdev: float + values: List[float] + + +class STAGE_RESULTS(TypedDict): + """Type-hint for metrics statstics for val & test stages.""" + + val: List[Dict[str, SESSION_STATISTICS]] + test: List[Dict[str, SESSION_STATISTICS]] + + +class RESULTS_DICT(TypedDict): + """Type-hint for the final results dictionary.""" + + metrics: STAGE_RESULTS + + class SessionRecorder: """Multi-run (session) summary logger.""" @@ -67,13 +90,13 @@ def update( self._update_validation_metrics(validation_scores) self._update_test_metrics(test_scores) - def compute(self) -> Dict[str, List[Dict[str, Any]]]: + def compute(self) -> STAGE_RESULTS: """Computes and returns the session statistics.""" validation_statistics = list(map(_calculate_statistics, self._validation_metrics)) test_statistics = list(map(_calculate_statistics, self._test_metrics)) return {"val": validation_statistics, "test": test_statistics} - def export(self) -> Dict[str, Any]: + def export(self) -> RESULTS_DICT: """Exports the results.""" statistics = self.compute() return {"metrics": statistics} @@ -81,6 +104,7 @@ def export(self) -> Dict[str, Any]: def save(self) -> None: """Saves the recorded results.""" results = self.export() + _print_results(results) _save_json(results, self.filename) self._save_config() @@ -125,10 +149,10 @@ def _init_session_metrics(n_datasets: int) -> List[SESSION_METRICS]: return [collections.defaultdict(list) for _ in range(n_datasets)] -def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, float | List[float]]: +def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, SESSION_STATISTICS]: """Calculate the metric statistics of a dataset session run.""" - def _calculate_metric_statistics(values: List[float]) -> Dict[str, float | List[float]]: + def _calculate_metric_statistics(values: List[float]) -> SESSION_STATISTICS: """Calculates and returns the metric statistics.""" mean = statistics.mean(values) stdev = statistics.stdev(values) if len(values) > 1 else 0 @@ -147,3 +171,34 @@ def _save_json(data: Dict[str, Any], save_as: str = "data.json"): fs.makedirs(output_dir, exist_ok=True) with fs.open(save_as, "w") as file: json.dump(data, file, indent=4, sort_keys=True) + + +def _print_results(results: RESULTS_DICT) -> None: + """Prints the results to the console.""" + for stage in ["val", "test"]: 
+ for dataset_idx in range(len(results["metrics"][stage])): + _print_table(results["metrics"][stage][dataset_idx], stage, dataset_idx) + + +def _print_table(metrics_dict: Dict[str, SESSION_STATISTICS], stage: str, dataset_idx: int): + """Prints the metrics of a single dataset as a table.""" + metrics_table = rich_table.Table( + title=f"\n{stage.capitalize()} Dataset {dataset_idx}", title_style="bold" + ) + metrics_table.add_column("Metric", style="cyan") + metrics_table.add_column("Mean", justify="right", style="magenta") + metrics_table.add_column("Stdev", justify="right", style="magenta") + + n_runs = len(metrics_dict[next(iter(metrics_dict))]["values"]) + for i in range(n_runs): + metrics_table.add_column(f"Run {i}", justify="right", style="magenta") + + for metric_name, metric_dict in metrics_dict.items(): + row = [metric_name, metric_dict["mean"], metric_dict["stdev"]] + [ + metric_dict["values"][i] for i in range(n_runs) + ] + row = [str(entry) for entry in row] + metrics_table.add_row(*row) + + console = rich_console.Console() + console.print(metrics_table) diff --git a/src/eva/core/trainers/functional.py b/src/eva/core/trainers/functional.py index 00f81f5e..9e630524 100644 --- a/src/eva/core/trainers/functional.py +++ b/src/eva/core/trainers/functional.py @@ -82,7 +82,7 @@ def fit_and_validate( A tuple of with the validation and the test metrics (if exists). """ trainer.fit(model, datamodule=datamodule) - validation_scores = trainer.validate(datamodule=datamodule) + validation_scores = trainer.validate(datamodule=datamodule, verbose=False) test_scores = None if datamodule.datasets.test is None else trainer.test(datamodule=datamodule) return validation_scores, test_scores From 35ead6efc319284485a0aacb3b272c490993a0ea Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Thu, 21 Mar 2024 09:19:41 +0100 Subject: [PATCH 2/6] added rich & fixed type hint --- pdm.lock | 22 +++++++++++----------- pyproject.toml | 1 + src/eva/core/trainers/_recorder.py | 4 ++-- src/eva/core/trainers/functional.py | 6 +++++- 4 files changed, 19 insertions(+), 14 deletions(-) diff --git a/pdm.lock b/pdm.lock index 8dc4367f..086bd069 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "dev", "docs", "all", "typecheck", "lint", "vision", "test"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:4f1bba07786b708fe7be9ff1c02c816a26da0556ab6f29352786ce4f468501df" +content_hash = "sha256:d4a0761420ef16deca511b3f9846a2b1061e03116ca8100f3134f0e31ebe203d" [[package]] name = "absl-py" @@ -842,7 +842,7 @@ name = "markdown-it-py" version = "3.0.0" requires_python = ">=3.8" summary = "Python port of markdown-it. Markdown parsing, done right!" -groups = ["dev", "lint"] +groups = ["default", "dev", "lint"] dependencies = [ "mdurl~=0.1", ] @@ -896,7 +896,7 @@ name = "mdurl" version = "0.1.2" requires_python = ">=3.7" summary = "Markdown URL utilities" -groups = ["dev", "lint"] +groups = ["default", "dev", "lint"] files = [ {file = "mdurl-0.1.2-py3-none-any.whl", hash = "sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8"}, {file = "mdurl-0.1.2.tar.gz", hash = "sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba"}, @@ -1689,7 +1689,7 @@ name = "pygments" version = "2.17.2" requires_python = ">=3.7" summary = "Pygments is a syntax highlighting package written in Python." 
-groups = ["dev", "docs", "lint", "test"] +groups = ["default", "dev", "docs", "lint", "test"] files = [ {file = "pygments-2.17.2-py3-none-any.whl", hash = "sha256:b27c2826c47d0f3219f29554824c30c5e8945175d888647acd804ddd04af846c"}, {file = "pygments-2.17.2.tar.gz", hash = "sha256:da46cec9fd2de5be3a8a784f434e4c4ab670b4ff54d605c4c2717e9d49c4c367"}, @@ -1734,7 +1734,7 @@ files = [ [[package]] name = "pyright" -version = "1.1.354" +version = "1.1.355" requires_python = ">=3.7" summary = "Command line wrapper for pyright" groups = ["dev", "typecheck"] @@ -1742,8 +1742,8 @@ dependencies = [ "nodeenv>=1.6.0", ] files = [ - {file = "pyright-1.1.354-py3-none-any.whl", hash = "sha256:f28d61ae8ae035fc52ded1070e8d9e786051a26a4127bbd7a4ba0399b81b37b5"}, - {file = "pyright-1.1.354.tar.gz", hash = "sha256:b1070dc774ff2e79eb0523fe87f4ba9a90550de7e4b030a2bc9e031864029a1f"}, + {file = "pyright-1.1.355-py3-none-any.whl", hash = "sha256:bf30b6728fd68ae7d09c98292b67152858dd89738569836896df786e52b5fe48"}, + {file = "pyright-1.1.355.tar.gz", hash = "sha256:dca4104cd53d6484e6b1b50b7a239ad2d16d2ffd20030bcf3111b56f44c263bf"}, ] [[package]] @@ -1951,7 +1951,7 @@ name = "rich" version = "13.7.1" requires_python = ">=3.7.0" summary = "Render rich text, tables, progress bars, syntax highlighting, markdown and more to the terminal" -groups = ["dev", "lint"] +groups = ["default", "dev", "lint"] dependencies = [ "markdown-it-py>=2.2.0", "pygments<3.0.0,>=2.13.0", @@ -2386,7 +2386,7 @@ files = [ [[package]] name = "transformers" -version = "4.38.2" +version = "4.39.0" requires_python = ">=3.8.0" summary = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" groups = ["default"] @@ -2403,8 +2403,8 @@ dependencies = [ "tqdm>=4.27", ] files = [ - {file = "transformers-4.38.2-py3-none-any.whl", hash = "sha256:c4029cb9f01b3dd335e52f364c52d2b37c65b4c78e02e6a08b1919c5c928573e"}, - {file = "transformers-4.38.2.tar.gz", hash = "sha256:c5fc7ad682b8a50a48b2a4c05d4ea2de5567adb1bdd00053619dbe5960857dd5"}, + {file = "transformers-4.39.0-py3-none-any.whl", hash = "sha256:7801785b1f016d667467e8c372c1c3653c18fe32ba97952059e3bea79ba22b08"}, + {file = "transformers-4.39.0.tar.gz", hash = "sha256:517a13cd633b10bea01c92ab0b3059762872c7c29da3d223db9d28e926fe330d"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index 77fed9bc..28ea5360 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "onnxruntime>=1.17.1", "onnx>=1.15.0", "toolz>=0.12.1", + "rich>=13.7.1", ] [project.urls] diff --git a/src/eva/core/trainers/_recorder.py b/src/eva/core/trainers/_recorder.py index 783e05f8..74e16d74 100644 --- a/src/eva/core/trainers/_recorder.py +++ b/src/eva/core/trainers/_recorder.py @@ -5,7 +5,7 @@ import os import statistics import sys -from typing import Any, Dict, List, Mapping, TypedDict +from typing import Dict, List, Mapping, TypedDict from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT from lightning_fabric.utilities import cloud_io @@ -161,7 +161,7 @@ def _calculate_metric_statistics(values: List[float]) -> SESSION_STATISTICS: return dicttoolz.valmap(_calculate_metric_statistics, session_metrics) -def _save_json(data: Dict[str, Any], save_as: str = "data.json"): +def _save_json(data: RESULTS_DICT, save_as: str = "data.json"): """Saves data to a json file.""" if not save_as.endswith(".json"): raise ValueError() diff --git a/src/eva/core/trainers/functional.py b/src/eva/core/trainers/functional.py index 9e630524..69889115 100644 --- a/src/eva/core/trainers/functional.py +++ 
b/src/eva/core/trainers/functional.py @@ -83,7 +83,11 @@ def fit_and_validate( """ trainer.fit(model, datamodule=datamodule) validation_scores = trainer.validate(datamodule=datamodule, verbose=False) - test_scores = None if datamodule.datasets.test is None else trainer.test(datamodule=datamodule) + test_scores = ( + None + if datamodule.datasets.test is None + else trainer.test(datamodule=datamodule, verbose=False) + ) return validation_scores, test_scores From eeb3ef0b7eced46c80b43c32041468133caa597a Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Thu, 21 Mar 2024 09:22:33 +0100 Subject: [PATCH 3/6] added print results to the end --- src/eva/core/trainers/_recorder.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/eva/core/trainers/_recorder.py b/src/eva/core/trainers/_recorder.py index 74e16d74..5c4f1e12 100644 --- a/src/eva/core/trainers/_recorder.py +++ b/src/eva/core/trainers/_recorder.py @@ -104,9 +104,9 @@ def export(self) -> RESULTS_DICT: def save(self) -> None: """Saves the recorded results.""" results = self.export() - _print_results(results) _save_json(results, self.filename) self._save_config() + _print_results(results) def reset(self) -> None: """Resets the state of the tracked metrics.""" @@ -175,9 +175,12 @@ def _save_json(data: RESULTS_DICT, save_as: str = "data.json"): def _print_results(results: RESULTS_DICT) -> None: """Prints the results to the console.""" - for stage in ["val", "test"]: - for dataset_idx in range(len(results["metrics"][stage])): - _print_table(results["metrics"][stage][dataset_idx], stage, dataset_idx) + try: + for stage in ["val", "test"]: + for dataset_idx in range(len(results["metrics"][stage])): + _print_table(results["metrics"][stage][dataset_idx], stage, dataset_idx) + except Exception as e: + logger.error(f"Failed to print the results: {e}") def _print_table(metrics_dict: Dict[str, SESSION_STATISTICS], stage: str, dataset_idx: int): From d6056d0ca9d571da7c4170ee3c6decc2c456a2b5 Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Mon, 25 Mar 2024 09:10:57 +0100 Subject: [PATCH 4/6] updated results tables to contain only 3 columns --- src/eva/core/trainers/_recorder.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/eva/core/trainers/_recorder.py b/src/eva/core/trainers/_recorder.py index 5c4f1e12..8e85ed99 100644 --- a/src/eva/core/trainers/_recorder.py +++ b/src/eva/core/trainers/_recorder.py @@ -189,18 +189,18 @@ def _print_table(metrics_dict: Dict[str, SESSION_STATISTICS], stage: str, datase title=f"\n{stage.capitalize()} Dataset {dataset_idx}", title_style="bold" ) metrics_table.add_column("Metric", style="cyan") - metrics_table.add_column("Mean", justify="right", style="magenta") - metrics_table.add_column("Stdev", justify="right", style="magenta") + metrics_table.add_column("Mean", style="magenta") + metrics_table.add_column("Stdev", style="magenta") + metrics_table.add_column("All", style="magenta") n_runs = len(metrics_dict[next(iter(metrics_dict))]["values"]) - for i in range(n_runs): - metrics_table.add_column(f"Run {i}", justify="right", style="magenta") - for metric_name, metric_dict in metrics_dict.items(): - row = [metric_name, metric_dict["mean"], metric_dict["stdev"]] + [ - metric_dict["values"][i] for i in range(n_runs) + row = [ + metric_name, + f'{metric_dict["mean"]:.3f}', + f'{metric_dict["stdev"]:.3f}', + ", ".join(f'{metric_dict["values"][i]:.3f}' for i in range(n_runs)), ] - row = [str(entry) for entry in row] metrics_table.add_row(*row) 
console = rich_console.Console() From fb738e8b0e3e67e5a0464c38d86219fd005d49ab Mon Sep 17 00:00:00 2001 From: ioangatop Date: Wed, 27 Mar 2024 09:16:26 +0100 Subject: [PATCH 5/6] updates --- configs/vision/dino_vit/offline/bach.yaml | 7 +++++-- src/eva/core/trainers/_recorder.py | 6 +++++- src/eva/core/trainers/functional.py | 13 ++++++++----- src/eva/core/trainers/trainer.py | 2 ++ 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/configs/vision/dino_vit/offline/bach.yaml b/configs/vision/dino_vit/offline/bach.yaml index 3d1dd721..2a1159f8 100644 --- a/configs/vision/dino_vit/offline/bach.yaml +++ b/configs/vision/dino_vit/offline/bach.yaml @@ -2,9 +2,12 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 5} + n_runs: &N_RUNS ${oc.env:N_RUNS, 1} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:DINO_BACKBONE, dino_vits16}/offline/bach} - max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + # max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} + max_steps: &MAX_STEPS 10 + # limit_train_batches: 2 + # limit_val_batches: 2 callbacks: - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: diff --git a/src/eva/core/trainers/_recorder.py b/src/eva/core/trainers/_recorder.py index 8e85ed99..f20ff717 100644 --- a/src/eva/core/trainers/_recorder.py +++ b/src/eva/core/trainers/_recorder.py @@ -48,6 +48,7 @@ def __init__( output_dir: str, results_file: str = "results.json", config_file: str = "config.yaml", + verbose: bool = True, ) -> None: """Initializes the recorder. @@ -55,10 +56,12 @@ def __init__( output_dir: The destination folder to save the results. results_file: The name of the results json file. config_file: The name of the yaml configuration file. + verbose: Whether to print the session metrics. """ self._output_dir = output_dir self._results_file = results_file self._config_file = config_file + self._verbose = verbose self._validation_metrics: List[SESSION_METRICS] = [] self._test_metrics: List[SESSION_METRICS] = [] @@ -106,7 +109,8 @@ def save(self) -> None: results = self.export() _save_json(results, self.filename) self._save_config() - _print_results(results) + if self._verbose: + _print_results(results) def reset(self) -> None: """Resets the state of the tracked metrics.""" diff --git a/src/eva/core/trainers/functional.py b/src/eva/core/trainers/functional.py index 69889115..900a8cd4 100644 --- a/src/eva/core/trainers/functional.py +++ b/src/eva/core/trainers/functional.py @@ -16,6 +16,7 @@ def run_evaluation_session( datamodule: datamodules.DataModule, *, n_runs: int = 1, + verbose: bool = True, ) -> None: """Runs a downstream evaluation session out-of-place. @@ -29,11 +30,12 @@ def run_evaluation_session( base_model: The base model module to use. datamodule: The data module. n_runs: The amount of runs (fit and evaluate) to perform. + verbose: Whether to verbose the session metrics. 
""" - recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir) + recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir, verbose=verbose) for run_index in range(n_runs): validation_scores, test_scores = run_evaluation( - base_trainer, base_model, datamodule, run_id=f"run_{run_index}" + base_trainer, base_model, datamodule, run_id=f"run_{run_index}", verbose=not verbose, ) recorder.update(validation_scores, test_scores) recorder.save() @@ -60,13 +62,14 @@ def run_evaluation( """ trainer, model = _utils.clone(base_trainer, base_model) trainer.setup_log_dirs(run_id or "") - return fit_and_validate(trainer, model, datamodule) + return fit_and_validate(trainer, model, datamodule, verbose=True) def fit_and_validate( trainer: eva_trainer.Trainer, model: modules.ModelModule, datamodule: datamodules.DataModule, + verbose: bool = True, ) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]: """Fits and evaluates a model in-place. @@ -82,11 +85,11 @@ def fit_and_validate( A tuple of with the validation and the test metrics (if exists). """ trainer.fit(model, datamodule=datamodule) - validation_scores = trainer.validate(datamodule=datamodule, verbose=False) + validation_scores = trainer.validate(datamodule=datamodule, verbose=verbose) test_scores = ( None if datamodule.datasets.test is None - else trainer.test(datamodule=datamodule, verbose=False) + else trainer.test(datamodule=datamodule, verbose=verbose) ) return validation_scores, test_scores diff --git a/src/eva/core/trainers/trainer.py b/src/eva/core/trainers/trainer.py index 25877fc9..db7a774c 100644 --- a/src/eva/core/trainers/trainer.py +++ b/src/eva/core/trainers/trainer.py @@ -77,6 +77,7 @@ def run_evaluation_session( self, model: modules.ModelModule, datamodule: datamodules.DataModule, + verbose: bool = True, ) -> None: """Runs an evaluation session out-of-place. 
@@ -94,4 +95,5 @@ def run_evaluation_session( base_model=model, datamodule=datamodule, n_runs=self._n_runs, + verbose=False, ) From 51d1951a3b02e8ea03c7d85f3b7066e81651e8cb Mon Sep 17 00:00:00 2001 From: ioangatop Date: Thu, 28 Mar 2024 14:15:51 +0100 Subject: [PATCH 6/6] update verbose ops --- configs/vision/dino_vit/offline/bach.yaml | 7 ++----- src/eva/core/trainers/_recorder.py | 2 +- src/eva/core/trainers/functional.py | 16 +++++++++++++--- src/eva/core/trainers/trainer.py | 3 +-- 4 files changed, 17 insertions(+), 11 deletions(-) diff --git a/configs/vision/dino_vit/offline/bach.yaml b/configs/vision/dino_vit/offline/bach.yaml index 2a1159f8..3d1dd721 100644 --- a/configs/vision/dino_vit/offline/bach.yaml +++ b/configs/vision/dino_vit/offline/bach.yaml @@ -2,12 +2,9 @@ trainer: class_path: eva.Trainer init_args: - n_runs: &N_RUNS ${oc.env:N_RUNS, 1} + n_runs: &N_RUNS ${oc.env:N_RUNS, 5} default_root_dir: &OUTPUT_ROOT ${oc.env:OUTPUT_ROOT, logs/${oc.env:DINO_BACKBONE, dino_vits16}/offline/bach} - # max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} - max_steps: &MAX_STEPS 10 - # limit_train_batches: 2 - # limit_val_batches: 2 + max_steps: &MAX_STEPS ${oc.env:MAX_STEPS, 12500} callbacks: - class_path: lightning.pytorch.callbacks.LearningRateMonitor init_args: diff --git a/src/eva/core/trainers/_recorder.py b/src/eva/core/trainers/_recorder.py index f20ff717..121dd713 100644 --- a/src/eva/core/trainers/_recorder.py +++ b/src/eva/core/trainers/_recorder.py @@ -174,7 +174,7 @@ def _save_json(data: RESULTS_DICT, save_as: str = "data.json"): fs = cloud_io.get_filesystem(output_dir, anon=False) fs.makedirs(output_dir, exist_ok=True) with fs.open(save_as, "w") as file: - json.dump(data, file, indent=4, sort_keys=True) + json.dump(data, file, indent=2, sort_keys=True) def _print_results(results: RESULTS_DICT) -> None: diff --git a/src/eva/core/trainers/functional.py b/src/eva/core/trainers/functional.py index 900a8cd4..582601f0 100644 --- a/src/eva/core/trainers/functional.py +++ b/src/eva/core/trainers/functional.py @@ -30,12 +30,17 @@ def run_evaluation_session( base_model: The base model module to use. datamodule: The data module. n_runs: The amount of runs (fit and evaluate) to perform. - verbose: Whether to verbose the session metrics. + verbose: Whether to verbose the session metrics instead of + these of each individual runs and vice-versa. """ recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir, verbose=verbose) for run_index in range(n_runs): validation_scores, test_scores = run_evaluation( - base_trainer, base_model, datamodule, run_id=f"run_{run_index}", verbose=not verbose, + base_trainer, + base_model, + datamodule, + run_id=f"run_{run_index}", + verbose=not verbose, ) recorder.update(validation_scores, test_scores) recorder.save() @@ -47,6 +52,7 @@ def run_evaluation( datamodule: datamodules.DataModule, *, run_id: str | None = None, + verbose: bool = True, ) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]: """Fits and evaluates a model out-of-place. @@ -56,13 +62,15 @@ def run_evaluation( datamodule: The data module. run_id: The run id to be appended to the output log directory. If `None`, it will use the log directory of the trainer as is. + verbose: Whether to print the validation and test metrics + in the end of the training. Returns: A tuple of with the validation and the test metrics (if exists). 
""" trainer, model = _utils.clone(base_trainer, base_model) trainer.setup_log_dirs(run_id or "") - return fit_and_validate(trainer, model, datamodule, verbose=True) + return fit_and_validate(trainer, model, datamodule, verbose=verbose) def fit_and_validate( @@ -80,6 +88,8 @@ def fit_and_validate( trainer: The trainer module to use and update in-place. model: The model module to use and update in-place. datamodule: The data module. + verbose: Whether to print the validation and test metrics + in the end of the training. Returns: A tuple of with the validation and the test metrics (if exists). diff --git a/src/eva/core/trainers/trainer.py b/src/eva/core/trainers/trainer.py index db7a774c..1a7af4e1 100644 --- a/src/eva/core/trainers/trainer.py +++ b/src/eva/core/trainers/trainer.py @@ -77,7 +77,6 @@ def run_evaluation_session( self, model: modules.ModelModule, datamodule: datamodules.DataModule, - verbose: bool = True, ) -> None: """Runs an evaluation session out-of-place. @@ -95,5 +94,5 @@ def run_evaluation_session( base_model=model, datamodule=datamodule, n_runs=self._n_runs, - verbose=False, + verbose=self._n_runs > 1, )