Print results table at the end of an evaluation session #337

Merged: 8 commits on Mar 28, 2024 (changes from all commits shown below).
22 changes: 11 additions & 11 deletions pdm.lock

(Generated lockfile; diff not rendered.)

1 change: 1 addition & 0 deletions pyproject.toml
@@ -43,6 +43,7 @@ dependencies = [
"onnxruntime>=1.17.1",
"onnx>=1.15.0",
"toolz>=0.12.1",
"rich>=13.7.1",
]

[project.urls]
76 changes: 69 additions & 7 deletions src/eva/core/trainers/_recorder.py
@@ -5,18 +5,41 @@
import os
import statistics
import sys
from typing import Any, Dict, List, Mapping
from typing import Dict, List, Mapping, TypedDict

from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
from lightning_fabric.utilities import cloud_io
from loguru import logger
from omegaconf import OmegaConf
from rich import console as rich_console
from rich import table as rich_table
from toolz import dicttoolz

SESSION_METRICS = Mapping[str, List[float]]
"""Session metrics type-hint."""


class SESSION_STATISTICS(TypedDict):
"""Type-hint for aggregated metrics of multiple runs with mean & stdev."""

mean: float
stdev: float
values: List[float]


class STAGE_RESULTS(TypedDict):
"""Type-hint for metrics statstics for val & test stages."""

val: List[Dict[str, SESSION_STATISTICS]]
test: List[Dict[str, SESSION_STATISTICS]]


class RESULTS_DICT(TypedDict):
"""Type-hint for the final results dictionary."""

metrics: STAGE_RESULTS


class SessionRecorder:
"""Multi-run (session) summary logger."""

@@ -25,17 +48,20 @@ def __init__(
output_dir: str,
results_file: str = "results.json",
config_file: str = "config.yaml",
verbose: bool = True,
) -> None:
"""Initializes the recorder.

Args:
output_dir: The destination folder to save the results.
results_file: The name of the results json file.
config_file: The name of the yaml configuration file.
verbose: Whether to print the session metrics.
"""
self._output_dir = output_dir
self._results_file = results_file
self._config_file = config_file
self._verbose = verbose

self._validation_metrics: List[SESSION_METRICS] = []
self._test_metrics: List[SESSION_METRICS] = []
@@ -67,13 +93,13 @@ def update(
self._update_validation_metrics(validation_scores)
self._update_test_metrics(test_scores)

def compute(self) -> Dict[str, List[Dict[str, Any]]]:
def compute(self) -> STAGE_RESULTS:
"""Computes and returns the session statistics."""
validation_statistics = list(map(_calculate_statistics, self._validation_metrics))
test_statistics = list(map(_calculate_statistics, self._test_metrics))
return {"val": validation_statistics, "test": test_statistics}

def export(self) -> Dict[str, Any]:
def export(self) -> RESULTS_DICT:
"""Exports the results."""
statistics = self.compute()
return {"metrics": statistics}
@@ -83,6 +109,8 @@ def save(self) -> None:
results = self.export()
_save_json(results, self.filename)
self._save_config()
if self._verbose:
_print_results(results)

def reset(self) -> None:
"""Resets the state of the tracked metrics."""
@@ -125,10 +153,10 @@ def _init_session_metrics(n_datasets: int) -> List[SESSION_METRICS]:
return [collections.defaultdict(list) for _ in range(n_datasets)]


def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, float | List[float]]:
def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, SESSION_STATISTICS]:
"""Calculate the metric statistics of a dataset session run."""

def _calculate_metric_statistics(values: List[float]) -> Dict[str, float | List[float]]:
def _calculate_metric_statistics(values: List[float]) -> SESSION_STATISTICS:
"""Calculates and returns the metric statistics."""
mean = statistics.mean(values)
stdev = statistics.stdev(values) if len(values) > 1 else 0
@@ -137,7 +165,7 @@ def _calculate_metric_statistics(values: List[float]) -> Dict[str, float | List[
return dicttoolz.valmap(_calculate_metric_statistics, session_metrics)


def _save_json(data: Dict[str, Any], save_as: str = "data.json"):
def _save_json(data: RESULTS_DICT, save_as: str = "data.json"):
"""Saves data to a json file."""
if not save_as.endswith(".json"):
raise ValueError()
@@ -146,4 +174,38 @@ def _save_json(data: Dict[str, Any], save_as: str = "data.json"):
fs = cloud_io.get_filesystem(output_dir, anon=False)
fs.makedirs(output_dir, exist_ok=True)
with fs.open(save_as, "w") as file:
json.dump(data, file, indent=4, sort_keys=True)
json.dump(data, file, indent=2, sort_keys=True)


def _print_results(results: RESULTS_DICT) -> None:
"""Prints the results to the console."""
try:
for stage in ["val", "test"]:
for dataset_idx in range(len(results["metrics"][stage])):
_print_table(results["metrics"][stage][dataset_idx], stage, dataset_idx)
except Exception as e:
logger.error(f"Failed to print the results: {e}")


def _print_table(metrics_dict: Dict[str, SESSION_STATISTICS], stage: str, dataset_idx: int):
"""Prints the metrics of a single dataset as a table."""
metrics_table = rich_table.Table(
title=f"\n{stage.capitalize()} Dataset {dataset_idx}", title_style="bold"
)
metrics_table.add_column("Metric", style="cyan")
metrics_table.add_column("Mean", style="magenta")
metrics_table.add_column("Stdev", style="magenta")
metrics_table.add_column("All", style="magenta")

n_runs = len(metrics_dict[next(iter(metrics_dict))]["values"])
for metric_name, metric_dict in metrics_dict.items():
row = [
metric_name,
f'{metric_dict["mean"]:.3f}',
f'{metric_dict["stdev"]:.3f}',
", ".join(f'{metric_dict["values"][i]:.3f}' for i in range(n_runs)),
]
metrics_table.add_row(*row)

console = rich_console.Console()
console.print(metrics_table)
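
For illustration, the sketch below calls the new _print_results helper on a hand-built dictionary shaped like RESULTS_DICT. The metric name and numbers are fabricated, and the private module path is only assumed from the file header in this diff.

# Illustrative only: the metric name and values are made up, and the module
# path is inferred from src/eva/core/trainers/_recorder.py above.
from eva.core.trainers._recorder import _print_results

results = {
    "metrics": {
        "val": [
            {
                "MulticlassAccuracy": {
                    "mean": 0.912,
                    "stdev": 0.004,
                    "values": [0.908, 0.913, 0.915],
                }
            }
        ],
        "test": [
            {
                "MulticlassAccuracy": {
                    "mean": 0.897,
                    "stdev": 0.006,
                    "values": [0.891, 0.897, 0.903],
                }
            }
        ],
    }
}

_print_results(results)  # renders one rich table per stage and dataset

Each table is titled like "Val Dataset 0" and uses the Metric, Mean, Stdev, and All columns defined in _print_table above.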
27 changes: 22 additions & 5 deletions src/eva/core/trainers/functional.py
@@ -16,6 +16,7 @@ def run_evaluation_session(
datamodule: datamodules.DataModule,
*,
n_runs: int = 1,
verbose: bool = True,
) -> None:
"""Runs a downstream evaluation session out-of-place.

@@ -29,11 +30,17 @@
base_model: The base model module to use.
datamodule: The data module.
n_runs: The amount of runs (fit and evaluate) to perform.
verbose: Whether to print the session metrics instead of
those of each individual run, and vice versa.
"""
recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir)
recorder = _recorder.SessionRecorder(output_dir=base_trainer.default_log_dir, verbose=verbose)
for run_index in range(n_runs):
validation_scores, test_scores = run_evaluation(
base_trainer, base_model, datamodule, run_id=f"run_{run_index}"
base_trainer,
base_model,
datamodule,
run_id=f"run_{run_index}",
verbose=not verbose,
)
recorder.update(validation_scores, test_scores)
recorder.save()
@@ -45,6 +52,7 @@ def run_evaluation(
datamodule: datamodules.DataModule,
*,
run_id: str | None = None,
verbose: bool = True,
) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
"""Fits and evaluates a model out-of-place.

@@ -54,19 +62,22 @@
datamodule: The data module.
run_id: The run id to be appended to the output log directory.
If `None`, it will use the log directory of the trainer as is.
verbose: Whether to print the validation and test metrics
at the end of training.

Returns:
A tuple with the validation and the test metrics (if they exist).
"""
trainer, model = _utils.clone(base_trainer, base_model)
trainer.setup_log_dirs(run_id or "")
return fit_and_validate(trainer, model, datamodule)
return fit_and_validate(trainer, model, datamodule, verbose=verbose)


def fit_and_validate(
trainer: eva_trainer.Trainer,
model: modules.ModelModule,
datamodule: datamodules.DataModule,
verbose: bool = True,
) -> Tuple[_EVALUATE_OUTPUT, _EVALUATE_OUTPUT | None]:
"""Fits and evaluates a model in-place.

@@ -77,13 +88,19 @@
trainer: The trainer module to use and update in-place.
model: The model module to use and update in-place.
datamodule: The data module.
verbose: Whether to print the validation and test metrics
at the end of training.

Returns:
A tuple with the validation and the test metrics (if they exist).
"""
trainer.fit(model, datamodule=datamodule)
validation_scores = trainer.validate(datamodule=datamodule)
test_scores = None if datamodule.datasets.test is None else trainer.test(datamodule=datamodule)
validation_scores = trainer.validate(datamodule=datamodule, verbose=verbose)
test_scores = (
None
if datamodule.datasets.test is None
else trainer.test(datamodule=datamodule, verbose=verbose)
)
return validation_scores, test_scores


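
As a usage sketch only, the snippet below shows how the new verbose flag threads through run_evaluation_session. The trainer, model, and datamodule arguments stand for already-configured eva components, which this diff does not construct, and the import path follows the file location shown above.

from eva.core.trainers import functional


def evaluate_session(trainer, model, datamodule, n_runs=3):
    """Call-pattern sketch: trainer, model and datamodule are assumed to be
    pre-built eva components; only the call into the session runner is shown."""
    functional.run_evaluation_session(
        base_trainer=trainer,
        base_model=model,
        datamodule=datamodule,
        n_runs=n_runs,
        verbose=True,  # print the aggregated session tables
    )

Because run_evaluation_session invokes each run with verbose=not verbose, either the per-run Lightning metrics or the aggregated session tables are printed, never both.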
1 change: 1 addition & 0 deletions src/eva/core/trainers/trainer.py
@@ -94,4 +94,5 @@ def run_evaluation_session(
base_model=model,
datamodule=datamodule,
n_runs=self._n_runs,
verbose=self._n_runs > 1,
)
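
Note on the default wiring: verbose=self._n_runs > 1 means a single-run session behaves as before, with Lightning printing its per-run validation and test metrics, while a multi-run session suppresses the per-run printouts and shows only the aggregated session tables produced by the recorder.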