Skip to content

Commit

Permalink
Correction to the MetricsParamHandler to type change
Browse files Browse the repository at this point in the history
  • Loading branch information
jeipollack committed Nov 9, 2023
1 parent d2d1a32 commit 69dafac
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions src/wf_psf/utils/configs_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,12 @@ def call_plot_config_handler_run(self, model_metrics):
] = self.metrics_conf

# Update metric results dict with latest result
plots_config_handler.list_of_metrics_dict[self.file_handler.workdir] = {
self.training_conf.training.model_params.model_name
+ self.training_conf.training.id_name: [model_metrics]
}
plots_config_handler.list_of_metrics_dict[self.file_handler.workdir] = [
{
self.training_conf.training.model_params.model_name
+ self.training_conf.training.id_name: [model_metrics]
}
]

plots_config_handler.run()

Expand Down Expand Up @@ -406,8 +408,7 @@ class PlottingConfigHandler:
Name of plotting configuration file
file_handler: obj
An instance of the FileIOHandler class
metrics_conf: dict
A dictionary containing the metrics configuration parameters
"""

ids = ("plotting_conf",)
Expand Down Expand Up @@ -476,13 +477,15 @@ def _metrics_run_id_name(self, wf_outdir, metrics_params):
Parameters
----------
wf_outdir: str
Name of the wf-psf run output directory
metrics_params: RecursiveNamespace Object
RecursiveNamespace object containing the metrics parameters used to evaluated the trained model.
Returns
-------
metrics_run_id_name: str
String containing the model name and id of the training run
metrics_run_id_name: list
List containing the model name and id for each training run
"""

try:
Expand Down Expand Up @@ -549,7 +552,7 @@ def load_metrics_into_dict(self):

for k, v in self.metrics_confs.items():
run_id_names = self._metrics_run_id_name(k, v)

metrics_dict[k] = []
for run_id_name in run_id_names:
output_path = os.path.join(
Expand All @@ -564,14 +567,14 @@ def load_metrics_into_dict(self):
)
)
try:
metrics_dict[k].append({
run_id_name: [np.load(output_path, allow_pickle=True)[()]]
})
metrics_dict[k].append(
{run_id_name: [np.load(output_path, allow_pickle=True)[()]]}
)
except FileNotFoundError:
logger.error(
"The required file for the plots was not found. Please check your configs settings."
)

return metrics_dict

def run(self):
Expand Down

0 comments on commit 69dafac

Please sign in to comment.