Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug plotting missing Shape Metrics e2 and R2 #111

Merged
merged 5 commits into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 136 additions & 81 deletions src/wf_psf/plotting/plots_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def make_plot(
x_axis,
y_axis,
y_axis_err,
y2_axis,
label,
plot_title,
x_axis_label,
Expand All @@ -69,6 +70,8 @@ def make_plot(
y-axis values
y_axis_err: list
Error values for y-axis points
y2_axis: list
y2-axis values for right axis
label: str
Label for the points
plot_title: str
Expand Down Expand Up @@ -114,7 +117,7 @@ def make_plot(
kwargs = dict(
linewidth=2, linestyle="dashed", markersize=4, marker="^", alpha=0.5
)
ax2.plot(x_axis[it], y_axis[it][k], **kwargs)
ax2.plot(x_axis[it], y2_axis[it][k], **kwargs)

plt.savefig(filename)

Expand Down Expand Up @@ -158,6 +161,7 @@ def __init__(
metric_name,
rmse,
std_rmse,
rel_rmse,
plot_title,
plots_dir,
):
Expand All @@ -166,6 +170,7 @@ def __init__(
self.metric_name = metric_name
self.rmse = rmse
self.std_rmse = std_rmse
self.rel_rmse = rel_rmse
self.plot_title = plot_title
self.plots_dir = plots_dir
self.list_of_stars = list_of_stars
Expand All @@ -189,6 +194,7 @@ def get_metrics(self, dataset):
"""
rmse = []
std_rmse = []
rel_rmse = []
metrics_id = []
for k, v in self.metrics.items():
for metrics_data in v:
Expand All @@ -210,7 +216,15 @@ def get_metrics(self, dataset):
}
)

return metrics_id, rmse, std_rmse
rel_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][dataset][
self.metric_name
][self.rel_rmse]
}
)

return metrics_id, rmse, std_rmse, rel_rmse

def plot(self):
"""Plot.
Expand All @@ -220,11 +234,12 @@ def plot(self):

"""
for plot_dataset in ["test_metrics", "train_metrics"]:
metrics_id, rmse, std_rmse = self.get_metrics(plot_dataset)
metrics_id, rmse, std_rmse, rel_rmse = self.get_metrics(plot_dataset)
make_plot(
x_axis=self.list_of_stars,
y_axis=rmse,
y_axis_err=std_rmse,
y2_axis=rel_rmse,
label=metrics_id,
plot_title="Stars " + plot_dataset + self.plot_title,
x_axis_label="Number of stars",
Expand Down Expand Up @@ -295,6 +310,7 @@ def plot(self):
for plot_dataset in ["test_metrics", "train_metrics"]:
y_axis = []
y_axis_err = []
y2_axis = []
metrics_id = []

for k, v in self.metrics.items():
Expand All @@ -316,11 +332,19 @@ def plot(self):
]["mono_metric"]["std_rmse_lda"]
}
)
y2_axis.append(
{
(k + "-" + run_id): metrics_data[run_id][0][
plot_dataset
]["mono_metric"]["rel_rmse_lda"]
}
)

make_plot(
x_axis=[lambda_list for _ in range(len(y_axis))],
y_axis=y_axis,
y_axis_err=y_axis_err,
y2_axis=y2_axis,
label=metrics_id,
plot_title="Stars "
+ plot_dataset # type: ignore
Expand All @@ -343,10 +367,8 @@ def plot(self):

class ShapeMetricsPlotHandler:
"""ShapeMetricsPlotHandler class.

A class to handle plot parameters shape
metrics results.

Parameters
----------
id: str
Expand All @@ -359,7 +381,6 @@ class ShapeMetricsPlotHandler:
List containing the number of stars used for each training data set
plots_dir: str
Output directory for metrics plots

"""

id = "shape_metrics"
Expand All @@ -373,96 +394,126 @@ def __init__(self, plotting_params, metrics, list_of_stars, plots_dir):
def plot(self):
"""Plot.

A function to generate plots for the train and test
A generic function to generate plots for the train and test
metrics.

"""
# Define common data
# Common data
e1_req_euclid = 2e-04
e2_req_euclid = 2e-04
R2_req_euclid = 1e-03

for plot_dataset in ["test_metrics", "train_metrics"]:
e1_rmse = []
e1_std_rmse = []
e2_rmse = []
e2_std_rmse = []
rmse_R2_meanR2 = []
std_rmse_R2_meanR2 = []
metrics_id = []
metrics_data = self.prepare_metrics_data(
plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid
)

for k, v in self.metrics.items():
for metrics_data in v:
run_id = list(metrics_data.keys())[0]
metrics_id.append(run_id + "-" + k)

e1_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["rmse_e1"]
/ e1_req_euclid
}
)
e1_std_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["std_rmse_e1"]
}
# Plot for e1
for k, v in metrics_data.items():
self.make_shape_metrics_plot(
metrics_data[k]["rmse"],
metrics_data[k]["std_rmse"],
metrics_data[k]["rel_rmse"],
plot_dataset,
k,
)

def prepare_metrics_data(
self, plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid
):
"""Prepare Metrics Data.

A function to prepare the metrics data for plotting.

Parameters
----------
plot_dataset: str
A string representing the dataset, i.e. training or test metrics.
e1_req_euclid: float
A float denoting the Euclid requirement for the `e1` shape metric.
e2_req_euclid: float
A float denoting the Euclid requirement for the `e2` shape metric.
R2_req_euclid: float
A float denoting the Euclid requirement for the `R2` shape metric.

Returns
-------
shape_metrics_data: dict
A dictionary containing the shape metrics data from a set of runs.

"""
shape_metrics_data = {
"e1": {"rmse": [], "std_rmse": [], "rel_rmse": []},
"e2": {"rmse": [], "std_rmse": [], "rel_rmse": []},
"R2_meanR2": {"rmse": [], "std_rmse": [], "rel_rmse": []},
}

for k, v in self.metrics.items():
for metrics_data in v:
run_id = list(metrics_data.keys())[0]

for metric in ["e1", "e2", "R2_meanR2"]:
metric_rmse = metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
][f"rmse_{metric}"]
metric_std_rmse = metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
][f"std_rmse_{metric}"]

relative_metric_rmse = metric_rmse / (
e1_req_euclid
if metric == "e1"
else (e2_req_euclid if metric == "e2" else R2_req_euclid)
)

e2_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["rmse_e2"]
/ e2_req_euclid
}
shape_metrics_data[metric]["rmse"].append(
{f"{k}-{run_id}": metric_rmse}
)
e2_std_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["std_rmse_e2"]
}
shape_metrics_data[metric]["std_rmse"].append(
{f"{k}-{run_id}": metric_std_rmse}
)

rmse_R2_meanR2.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["rmse_R2_meanR2"]
/ R2_req_euclid
}
shape_metrics_data[metric]["rel_rmse"].append(
{f"{k}-{run_id}": relative_metric_rmse}
)

std_rmse_R2_meanR2.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["std_rmse_R2_meanR2"]
}
)
return shape_metrics_data

make_plot(
x_axis=self.list_of_stars,
y_axis=e1_rmse,
y_axis_err=e1_std_rmse,
label=metrics_id,
plot_title="Stars " + plot_dataset + ".\nShape RMSE",
x_axis_label="Number of stars",
y_left_axis_label="Absolute error",
y_right_axis_label="Relative error [%]",
filename=os.path.join(
self.plots_dir,
plot_dataset
+ "_nstars_"
+ "_".join(str(nstar) for nstar in self.list_of_stars)
+ "_Shape_RMSE.png",
),
plot_show=self.plotting_params.plot_show,
)
def make_shape_metrics_plot(
self, rmse_data, std_rmse_data, rel_rmse_data, plot_dataset, metric
):
"""Make Shape Metrics Plot.

A function to produce plots for the shape metrics.

Parameters
----------
rmse_data: list
A list of dictionaries where each dictionary stores run as the key and the Root Mean Square Error (rmse).
std_rmse_data: list
A list of dictionaries where each dictionary stores run as the key and the Standard Deviation of the Root Mean Square Error (rmse) as the value.
rel_rmse_data: list
A list of dictionaries where each dictionary stores run as the key and the Root Mean Square Error (rmse) relative to the Euclid requirements as the value.
plot_dataset: str
A string denoting whether metrics are for the train or test datasets.
metric: str
A string representing the type of shape metric, i.e., e1, e2, or R2.

"""
make_plot(
x_axis=self.list_of_stars,
y_axis=rmse_data,
y_axis_err=std_rmse_data,
y2_axis=rel_rmse_data,
label=[key for item in rmse_data for key in item],
plot_title=f"Stars {plot_dataset}. Shape {metric.upper()} RMSE",
x_axis_label="Number of stars",
y_left_axis_label="Absolute error",
y_right_axis_label="Relative error [%]",
filename=os.path.join(
self.plots_dir,
f"{plot_dataset}_nstars_{'_'.join(str(nstar) for nstar in self.list_of_stars)}_Shape_{metric.upper()}_RMSE.png",
),
plot_show=self.plotting_params.plot_show,
)


def get_number_of_stars(metrics):
Expand Down Expand Up @@ -509,16 +560,19 @@ def plot_metrics(plotting_params, list_of_metrics, metrics_confs, plot_saving_pa
"poly_metric": {
"rmse": "rmse",
"std_rmse": "std_rmse",
"rel_rmse": "rel_rmse",
"plot_title": ".\nPolychromatic pixel RMSE @ Euclid resolution",
},
"opd_metric": {
"rmse": "rmse_opd",
"std_rmse": "rmse_std_opd",
"rel_rmse": "rel_rmse_opd",
"plot_title": ".\nOPD RMSE",
},
"shape_results_dict": {
"rmse": "pix_rmse",
"std_rmse": "pix_rmse_std",
"rel_rmse": "rel_pix_rmse",
"plot_title": "\nPixel RMSE @ 3x Euclid resolution",
},
}
Expand All @@ -533,6 +587,7 @@ def plot_metrics(plotting_params, list_of_metrics, metrics_confs, plot_saving_pa
k,
v["rmse"],
v["std_rmse"],
v["rel_rmse"],
v["plot_title"],
plot_saving_path,
)
Expand Down
8 changes: 4 additions & 4 deletions src/wf_psf/tests/metrics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def main_metrics(training_params):
return np.load(os.path.join(main_dir, metrics_filename), allow_pickle=True)[()]


@pytest.mark.skip(reason="Requires gpu")
@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
def test_eval_metrics_polychromatic_lowres(
training_params,
weights_path_basename,
Expand Down Expand Up @@ -156,7 +156,7 @@ def test_eval_metrics_polychromatic_lowres(
assert ratio_rel_std_rmse < tol


@pytest.mark.skip(reason="Requires gpu")
@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
def test_evaluate_metrics_opd(
training_params,
weights_path_basename,
Expand Down Expand Up @@ -206,7 +206,7 @@ def test_evaluate_metrics_opd(
assert ratio_rel_rmse_std_opd < tol


@pytest.mark.skip(reason="Requires gpu")
@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
def test_eval_metrics_mono_rmse(
training_params,
weights_path_basename,
Expand Down Expand Up @@ -271,7 +271,7 @@ def test_eval_metrics_mono_rmse(
assert ratio_rel_rmse_std_mono < tol


@pytest.mark.skip(reason="Requires gpu")
@pytest.mark.skipif("GITHUB_ENV" in os.environ, reason="Skipping GPU tests in CI")
def test_evaluate_metrics_shape(
training_params,
weights_path_basename,
Expand Down
Loading
Loading