Skip to content

Commit

Permalink
Added command to plot e2 & R2
Browse files Browse the repository at this point in the history
  • Loading branch information
Jennifer Pollack authored and jeipollack committed Mar 5, 2024
1 parent 6135a08 commit 8665e11
Showing 1 changed file with 83 additions and 0 deletions.
83 changes: 83 additions & 0 deletions src/wf_psf/plotting/plots_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,89 @@ def plot(self):


class ShapeMetricsPlotHandler:
id = "shape_metrics"

def __init__(self, plotting_params, metrics, list_of_stars, plots_dir):
self.plotting_params = plotting_params
self.metrics = metrics
self.list_of_stars = list_of_stars
self.plots_dir = plots_dir

def plot(self):
e1_req_euclid = 2e-04
e2_req_euclid = 2e-04
R2_req_euclid = 1e-03

for plot_dataset in ["test_metrics", "train_metrics"]:
metrics_data = self.prepare_metrics_data(
plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid
)

# Plot for e1
for k, v in metrics_data.items():
self.make_plot(
metrics_data[k]["rmse"],
metrics_data[k]["std_rmse"],
plot_dataset,
k,
)


def prepare_metrics_data(
self, plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid
):
shape_metrics_data = {
"e1": {"rmse": [], "std_rmse": []},
"e2": {"rmse": [], "std_rmse": []},
"R2_meanR2": {"rmse": [], "std_rmse": []},
}

for k, v in self.metrics.items():
for metrics_data in v:
run_id = list(metrics_data.keys())[0]

for metric in ["e1", "e2", "R2_meanR2"]:
metric_rmse = metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
][f"rmse_{metric}"]
metric_std_rmse = metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
][f"std_rmse_{metric}"]

relative_metric_rmse = metric_rmse / (
e1_req_euclid
if metric == "e1"
else (e2_req_euclid if metric == "e2" else R2_req_euclid)
)

shape_metrics_data[metric]["rmse"].append(
{f"{k}-{run_id}": relative_metric_rmse}
)
shape_metrics_data[metric]["std_rmse"].append(
{f"{k}-{run_id}": metric_std_rmse}
)

return shape_metrics_data

def make_plot(self, rmse_data, std_rmse_data, plot_dataset, metric):
make_plot(
x_axis=self.list_of_stars,
y_axis=rmse_data,
y_axis_err=std_rmse_data,
label=[key for item in rmse_data for key in item],
plot_title=f"Stars {plot_dataset}. Shape {metric.upper()} RMSE",
x_axis_label="Number of stars",
y_left_axis_label="Absolute error",
y_right_axis_label="Relative error [%]",
filename=os.path.join(
self.plots_dir,
f"{plot_dataset}_nstars_{'_'.join(str(nstar) for nstar in self.list_of_stars)}_Shape_{metric.upper()}_RMSE.png",
),
plot_show=self.plotting_params.plot_show,
)


class OldShapeMetricsPlotHandler:
"""ShapeMetricsPlotHandler class.
A class to handle plot parameters shape
Expand Down

0 comments on commit 8665e11

Please sign in to comment.