Skip to content

Commit

Permalink
Added missing rel_rmse to right-axis in metric plots
Browse files Browse the repository at this point in the history
  • Loading branch information
jeipollack committed Mar 5, 2024
1 parent 4976500 commit fe44792
Showing 1 changed file with 46 additions and 12 deletions.
58 changes: 46 additions & 12 deletions src/wf_psf/plotting/plots_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def make_plot(
x_axis,
y_axis,
y_axis_err,
y2_axis,
label,
plot_title,
x_axis_label,
Expand All @@ -69,6 +70,8 @@ def make_plot(
y-axis values
y_axis_err: list
Error values for y-axis points
y2_axis: list
y2-axis values for right axis
label: str
Label for the points
plot_title: str
Expand Down Expand Up @@ -114,7 +117,7 @@ def make_plot(
kwargs = dict(
linewidth=2, linestyle="dashed", markersize=4, marker="^", alpha=0.5
)
ax2.plot(x_axis[it], y_axis[it][k], **kwargs)
ax2.plot(x_axis[it], y2_axis[it][k], **kwargs)

plt.savefig(filename)

Expand Down Expand Up @@ -158,6 +161,7 @@ def __init__(
metric_name,
rmse,
std_rmse,
rel_rmse,
plot_title,
plots_dir,
):
Expand All @@ -166,6 +170,7 @@ def __init__(
self.metric_name = metric_name
self.rmse = rmse
self.std_rmse = std_rmse
self.rel_rmse = rel_rmse
self.plot_title = plot_title
self.plots_dir = plots_dir
self.list_of_stars = list_of_stars
Expand All @@ -189,6 +194,7 @@ def get_metrics(self, dataset):
"""
rmse = []
std_rmse = []
rel_rmse = []
metrics_id = []
for k, v in self.metrics.items():
for metrics_data in v:
Expand All @@ -210,7 +216,15 @@ def get_metrics(self, dataset):
}
)

return metrics_id, rmse, std_rmse
rel_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][dataset][
self.metric_name
][self.rel_rmse]
}
)

return metrics_id, rmse, std_rmse, rel_rmse

def plot(self):
"""Plot.
Expand All @@ -220,11 +234,12 @@ def plot(self):
"""
for plot_dataset in ["test_metrics", "train_metrics"]:
metrics_id, rmse, std_rmse = self.get_metrics(plot_dataset)
metrics_id, rmse, std_rmse, rel_rmse = self.get_metrics(plot_dataset)
make_plot(
x_axis=self.list_of_stars,
y_axis=rmse,
y_axis_err=std_rmse,
y2_axis=rel_rmse,
label=metrics_id,
plot_title="Stars " + plot_dataset + self.plot_title,
x_axis_label="Number of stars",
Expand Down Expand Up @@ -295,6 +310,7 @@ def plot(self):
for plot_dataset in ["test_metrics", "train_metrics"]:
y_axis = []
y_axis_err = []
y2_axis = []
metrics_id = []

for k, v in self.metrics.items():
Expand All @@ -316,11 +332,19 @@ def plot(self):
]["mono_metric"]["std_rmse_lda"]
}
)
y2_axis.append(
{
(k + "-" + run_id): metrics_data[run_id][0][
plot_dataset
]["mono_metric"]["rel_rmse_lda"]
}
)

make_plot(
x_axis=[lambda_list for _ in range(len(y_axis))],
y_axis=y_axis,
y_axis_err=y_axis_err,
y2_axis=y2_axis,
label=metrics_id,
plot_title="Stars "
+ plot_dataset # type: ignore
Expand Down Expand Up @@ -371,6 +395,7 @@ def plot(self):
self.make_shape_metrics_plot(
metrics_data[k]["rmse"],
metrics_data[k]["std_rmse"],
metrics_data[k]["rel_rmse"],
plot_dataset,
k,
)
Expand All @@ -386,13 +411,10 @@ def prepare_metrics_data(
----------
plot_dataset: str
A string representing the dataset, i.e. training or test metrics.
e1_req_euclid: float
A float denoting the Euclid requirement for the `e1` shape metric.
e2_req_euclid: float
A float denoting the Euclid requirement for the `e2` shape metric.
R2_req_euclid: float
A float denoting the Euclid requirement for the `R2` shape metric.
Expand All @@ -403,9 +425,9 @@ def prepare_metrics_data(
"""
shape_metrics_data = {
"e1": {"rmse": [], "std_rmse": []},
"e2": {"rmse": [], "std_rmse": []},
"R2_meanR2": {"rmse": [], "std_rmse": []},
"e1": {"rmse": [], "std_rmse": [], "rel_rmse": []},
"e2": {"rmse": [], "std_rmse": [], "rel_rmse": []},
"R2_meanR2": {"rmse": [], "std_rmse": [], "rel_rmse": []},
}

for k, v in self.metrics.items():
Expand All @@ -427,25 +449,32 @@ def prepare_metrics_data(
)

shape_metrics_data[metric]["rmse"].append(
{f"{k}-{run_id}": relative_metric_rmse}
{f"{k}-{run_id}": metric_rmse}
)
shape_metrics_data[metric]["std_rmse"].append(
{f"{k}-{run_id}": metric_std_rmse}
)
shape_metrics_data[metric]["rel_rmse"].append(
{f"{k}-{run_id}": relative_metric_rmse}
)

return shape_metrics_data

def make_shape_metrics_plot(self, rmse_data, std_rmse_data, plot_dataset, metric):
def make_shape_metrics_plot(
self, rmse_data, std_rmse_data, rel_rmse_data, plot_dataset, metric
):
"""Make Shape Metrics Plot.
A function to produce plots for the shape metrics.
Parameters
----------
rmse_data: list
A list of dictionaries where each dictionary stores run as the key and the Root Mean Square Error (rmse) relative to the Euclid requirements as the value.
A list of dictionaries where each dictionary stores run as the key and the Root Mean Square Error (rmse).
std_rmse_data: list
A list of dictionaries where each dictionary stores run as the key and the Standard Deviation of the Root Mean Square Error (rmse) as the value.
rel_rmse_data: list
A list of dictionaries where each dictionary stores run as the key and the Root Mean Square Error (rmse) relative to the Euclid requirements as the value.
plot_dataset: str
A string denoting whether metrics are for the train or test datasets.
metric: str
Expand All @@ -456,6 +485,7 @@ def make_shape_metrics_plot(self, rmse_data, std_rmse_data, plot_dataset, metric
x_axis=self.list_of_stars,
y_axis=rmse_data,
y_axis_err=std_rmse_data,
y2_axis=rel_rmse_data,
label=[key for item in rmse_data for key in item],
plot_title=f"Stars {plot_dataset}. Shape {metric.upper()} RMSE",
x_axis_label="Number of stars",
Expand Down Expand Up @@ -513,16 +543,19 @@ def plot_metrics(plotting_params, list_of_metrics, metrics_confs, plot_saving_pa
"poly_metric": {
"rmse": "rmse",
"std_rmse": "std_rmse",
"rel_rmse": "rel_rmse",
"plot_title": ".\nPolychromatic pixel RMSE @ Euclid resolution",
},
"opd_metric": {
"rmse": "rmse_opd",
"std_rmse": "rmse_std_opd",
"rel_rmse": "rel_rmse_opd",
"plot_title": ".\nOPD RMSE",
},
"shape_results_dict": {
"rmse": "pix_rmse",
"std_rmse": "pix_rmse_std",
"rel_rmse": "rel_pix_rmse",
"plot_title": "\nPixel RMSE @ 3x Euclid resolution",
},
}
Expand All @@ -537,6 +570,7 @@ def plot_metrics(plotting_params, list_of_metrics, metrics_confs, plot_saving_pa
k,
v["rmse"],
v["std_rmse"],
v["rel_rmse"],
v["plot_title"],
plot_saving_path,
)
Expand Down

0 comments on commit fe44792

Please sign in to comment.