From 49765001eb742ca2a7d450dda7b902c6fb22c262 Mon Sep 17 00:00:00 2001 From: Jennifer Pollack Date: Mon, 15 Jan 2024 17:23:03 +0100 Subject: [PATCH] Added doc strings and removed old shape metrics class --- src/wf_psf/plotting/plots_interface.py | 175 +++++++------------------ 1 file changed, 48 insertions(+), 127 deletions(-) diff --git a/src/wf_psf/plotting/plots_interface.py b/src/wf_psf/plotting/plots_interface.py index f64cd8f3..5b8b8b5f 100644 --- a/src/wf_psf/plotting/plots_interface.py +++ b/src/wf_psf/plotting/plots_interface.py @@ -351,6 +351,12 @@ def __init__(self, plotting_params, metrics, list_of_stars, plots_dir): self.plots_dir = plots_dir def plot(self): + """Plot. + + A generic function to generate plots for the train and test + metrics. + + """ e1_req_euclid = 2e-04 e2_req_euclid = 2e-04 R2_req_euclid = 1e-03 @@ -362,17 +368,40 @@ def plot(self): # Plot for e1 for k, v in metrics_data.items(): - self.make_plot( + self.make_shape_metrics_plot( metrics_data[k]["rmse"], metrics_data[k]["std_rmse"], plot_dataset, k, ) - def prepare_metrics_data( self, plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid ): + """Prepare Metrics Data. + + A function to prepare the metrics data for plotting. + + Parameters + ---------- + plot_dataset: str + A string representing the dataset, i.e. training or test metrics. + + e1_req_euclid: float + A float denoting the Euclid requirement for the `e1` shape metric. + + e2_req_euclid: float + A float denoting the Euclid requirement for the `e2` shape metric. + + R2_req_euclid: float + A float denoting the Euclid requirement for the `R2` shape metric. + + Returns + ------- + shape_metrics_data: dict + A dictionary containing the shape metrics data from a set of runs. + + """ shape_metrics_data = { "e1": {"rmse": [], "std_rmse": []}, "e2": {"rmse": [], "std_rmse": []}, @@ -406,7 +435,23 @@ def prepare_metrics_data( return shape_metrics_data - def make_plot(self, rmse_data, std_rmse_data, plot_dataset, metric): + def make_shape_metrics_plot(self, rmse_data, std_rmse_data, plot_dataset, metric): + """Make Shape Metrics Plot. + + A function to produce plots for the shape metrics. + + Parameters + ---------- + rmse_data: list + A list of dictionaries where each dictionary stores run as the key and the Root Mean Square Error (rmse) relative to the Euclid requirements as the value. + std_rmse_data: list + A list of dictionaries where each dictionary stores run as the key and the Standard Deviation of the Root Mean Square Error (rmse) as the value. + plot_dataset: str + A string denoting whether metrics are for the train or test datasets. + metric: str + A string representing the type of shape metric, i.e., e1, e2, or R2. + + """ make_plot( x_axis=self.list_of_stars, y_axis=rmse_data, @@ -424,130 +469,6 @@ def make_plot(self, rmse_data, std_rmse_data, plot_dataset, metric): ) -class OldShapeMetricsPlotHandler: - """ShapeMetricsPlotHandler class. - - A class to handle plot parameters shape - metrics results. - - Parameters - ---------- - id: str - Class ID name - plotting_params: Recursive Namespace object - Recursive Namespace Object containing plotting parameters - metrics: list - Dictionary containing list of metrics - list_of_stars: list - List containing the number of stars used for each training data set - plots_dir: str - Output directory for metrics plots - - """ - - id = "shape_metrics" - - def __init__(self, plotting_params, metrics, list_of_stars, plots_dir): - self.plotting_params = plotting_params - self.metrics = metrics - self.list_of_stars = list_of_stars - self.plots_dir = plots_dir - - def plot(self): - """Plot. - - A function to generate plots for the train and test - metrics. - - """ - # Define common data - # Common data - e1_req_euclid = 2e-04 - e2_req_euclid = 2e-04 - R2_req_euclid = 1e-03 - for plot_dataset in ["test_metrics", "train_metrics"]: - e1_rmse = [] - e1_std_rmse = [] - e2_rmse = [] - e2_std_rmse = [] - rmse_R2_meanR2 = [] - std_rmse_R2_meanR2 = [] - metrics_id = [] - - for k, v in self.metrics.items(): - for metrics_data in v: - run_id = list(metrics_data.keys())[0] - metrics_id.append(run_id + "-" + k) - - e1_rmse.append( - { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ - "shape_results_dict" - ]["rmse_e1"] - / e1_req_euclid - } - ) - e1_std_rmse.append( - { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ - "shape_results_dict" - ]["std_rmse_e1"] - } - ) - - e2_rmse.append( - { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ - "shape_results_dict" - ]["rmse_e2"] - / e2_req_euclid - } - ) - e2_std_rmse.append( - { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ - "shape_results_dict" - ]["std_rmse_e2"] - } - ) - - rmse_R2_meanR2.append( - { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ - "shape_results_dict" - ]["rmse_R2_meanR2"] - / R2_req_euclid - } - ) - - std_rmse_R2_meanR2.append( - { - (k + "-" + run_id): metrics_data[run_id][0][plot_dataset][ - "shape_results_dict" - ]["std_rmse_R2_meanR2"] - } - ) - - make_plot( - x_axis=self.list_of_stars, - y_axis=e1_rmse, - y_axis_err=e1_std_rmse, - label=metrics_id, - plot_title="Stars " + plot_dataset + ".\nShape RMSE", - x_axis_label="Number of stars", - y_left_axis_label="Absolute error", - y_right_axis_label="Relative error [%]", - filename=os.path.join( - self.plots_dir, - plot_dataset - + "_nstars_" - + "_".join(str(nstar) for nstar in self.list_of_stars) - + "_Shape_RMSE.png", - ), - plot_show=self.plotting_params.plot_show, - ) - - def get_number_of_stars(metrics): """Get Number of Stars.