From 69dafacb4be33e833a9b8917da15bf4a1d33a6c2 Mon Sep 17 00:00:00 2001 From: jeipollack Date: Thu, 9 Nov 2023 13:45:28 +0100 Subject: [PATCH] Correction to the MetricsParamHandler to type change --- src/wf_psf/utils/configs_handler.py | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/wf_psf/utils/configs_handler.py b/src/wf_psf/utils/configs_handler.py index 4202a806..c55409c1 100644 --- a/src/wf_psf/utils/configs_handler.py +++ b/src/wf_psf/utils/configs_handler.py @@ -357,10 +357,12 @@ def call_plot_config_handler_run(self, model_metrics): ] = self.metrics_conf # Update metric results dict with latest result - plots_config_handler.list_of_metrics_dict[self.file_handler.workdir] = { - self.training_conf.training.model_params.model_name - + self.training_conf.training.id_name: [model_metrics] - } + plots_config_handler.list_of_metrics_dict[self.file_handler.workdir] = [ + { + self.training_conf.training.model_params.model_name + + self.training_conf.training.id_name: [model_metrics] + } + ] plots_config_handler.run() @@ -406,8 +408,7 @@ class PlottingConfigHandler: Name of plotting configuration file file_handler: obj An instance of the FileIOHandler class - metrics_conf: dict - A dictionary containing the metrics configuration parameters + """ ids = ("plotting_conf",) @@ -476,13 +477,15 @@ def _metrics_run_id_name(self, wf_outdir, metrics_params): Parameters ---------- + wf_outdir: str + Name of the wf-psf run output directory metrics_params: RecursiveNamespace Object RecursiveNamespace object containing the metrics parameters used to evaluated the trained model. Returns ------- - metrics_run_id_name: str - String containing the model name and id of the training run + metrics_run_id_name: list + List containing the model name and id for each training run """ try: @@ -549,7 +552,7 @@ def load_metrics_into_dict(self): for k, v in self.metrics_confs.items(): run_id_names = self._metrics_run_id_name(k, v) - + metrics_dict[k] = [] for run_id_name in run_id_names: output_path = os.path.join( @@ -564,14 +567,14 @@ def load_metrics_into_dict(self): ) ) try: - metrics_dict[k].append({ - run_id_name: [np.load(output_path, allow_pickle=True)[()]] - }) + metrics_dict[k].append( + {run_id_name: [np.load(output_path, allow_pickle=True)[()]]} + ) except FileNotFoundError: logger.error( "The required file for the plots was not found. Please check your configs settings." ) - + return metrics_dict def run(self):