diff --git a/src/wf_psf/plotting/plots_interface.py b/src/wf_psf/plotting/plots_interface.py index 30e22674..f64cd8f3 100644 --- a/src/wf_psf/plotting/plots_interface.py +++ b/src/wf_psf/plotting/plots_interface.py @@ -342,6 +342,89 @@ def plot(self): class ShapeMetricsPlotHandler: + id = "shape_metrics" + + def __init__(self, plotting_params, metrics, list_of_stars, plots_dir): + self.plotting_params = plotting_params + self.metrics = metrics + self.list_of_stars = list_of_stars + self.plots_dir = plots_dir + + def plot(self): + e1_req_euclid = 2e-04 + e2_req_euclid = 2e-04 + R2_req_euclid = 1e-03 + + for plot_dataset in ["test_metrics", "train_metrics"]: + metrics_data = self.prepare_metrics_data( + plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid + ) + + # Plot for e1 + for k, v in metrics_data.items(): + self.make_plot( + metrics_data[k]["rmse"], + metrics_data[k]["std_rmse"], + plot_dataset, + k, + ) + + + def prepare_metrics_data( + self, plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid + ): + shape_metrics_data = { + "e1": {"rmse": [], "std_rmse": []}, + "e2": {"rmse": [], "std_rmse": []}, + "R2_meanR2": {"rmse": [], "std_rmse": []}, + } + + for k, v in self.metrics.items(): + for metrics_data in v: + run_id = list(metrics_data.keys())[0] + + for metric in ["e1", "e2", "R2_meanR2"]: + metric_rmse = metrics_data[run_id][0][plot_dataset][ + "shape_results_dict" + ][f"rmse_{metric}"] + metric_std_rmse = metrics_data[run_id][0][plot_dataset][ + "shape_results_dict" + ][f"std_rmse_{metric}"] + + relative_metric_rmse = metric_rmse / ( + e1_req_euclid + if metric == "e1" + else (e2_req_euclid if metric == "e2" else R2_req_euclid) + ) + + shape_metrics_data[metric]["rmse"].append( + {f"{k}-{run_id}": relative_metric_rmse} + ) + shape_metrics_data[metric]["std_rmse"].append( + {f"{k}-{run_id}": metric_std_rmse} + ) + + return shape_metrics_data + + def make_plot(self, rmse_data, std_rmse_data, plot_dataset, metric): + make_plot( + x_axis=self.list_of_stars, + y_axis=rmse_data, + y_axis_err=std_rmse_data, + label=[key for item in rmse_data for key in item], + plot_title=f"Stars {plot_dataset}. Shape {metric.upper()} RMSE", + x_axis_label="Number of stars", + y_left_axis_label="Absolute error", + y_right_axis_label="Relative error [%]", + filename=os.path.join( + self.plots_dir, + f"{plot_dataset}_nstars_{'_'.join(str(nstar) for nstar in self.list_of_stars)}_Shape_{metric.upper()}_RMSE.png", + ), + plot_show=self.plotting_params.plot_show, + ) + + +class OldShapeMetricsPlotHandler: """ShapeMetricsPlotHandler class. A class to handle plot parameters shape