Skip to content

Commit

Permalink
Added doc strings and removed old shape metrics class
Browse files Browse the repository at this point in the history
  • Loading branch information
Jennifer Pollack authored and jeipollack committed Mar 5, 2024
1 parent 8665e11 commit 4976500
Showing 1 changed file with 48 additions and 127 deletions.
175 changes: 48 additions & 127 deletions src/wf_psf/plotting/plots_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,12 @@ def __init__(self, plotting_params, metrics, list_of_stars, plots_dir):
self.plots_dir = plots_dir

def plot(self):
"""Plot.
A generic function to generate plots for the train and test
metrics.
"""
e1_req_euclid = 2e-04
e2_req_euclid = 2e-04
R2_req_euclid = 1e-03
Expand All @@ -362,17 +368,40 @@ def plot(self):

# Plot for e1
for k, v in metrics_data.items():
self.make_plot(
self.make_shape_metrics_plot(
metrics_data[k]["rmse"],
metrics_data[k]["std_rmse"],
plot_dataset,
k,
)


def prepare_metrics_data(
self, plot_dataset, e1_req_euclid, e2_req_euclid, R2_req_euclid
):
"""Prepare Metrics Data.
A function to prepare the metrics data for plotting.
Parameters
----------
plot_dataset: str
A string representing the dataset, i.e. training or test metrics.
e1_req_euclid: float
A float denoting the Euclid requirement for the `e1` shape metric.
e2_req_euclid: float
A float denoting the Euclid requirement for the `e2` shape metric.
R2_req_euclid: float
A float denoting the Euclid requirement for the `R2` shape metric.
Returns
-------
shape_metrics_data: dict
A dictionary containing the shape metrics data from a set of runs.
"""
shape_metrics_data = {
"e1": {"rmse": [], "std_rmse": []},
"e2": {"rmse": [], "std_rmse": []},
Expand Down Expand Up @@ -406,7 +435,23 @@ def prepare_metrics_data(

return shape_metrics_data

def make_plot(self, rmse_data, std_rmse_data, plot_dataset, metric):
def make_shape_metrics_plot(self, rmse_data, std_rmse_data, plot_dataset, metric):
"""Make Shape Metrics Plot.
A function to produce plots for the shape metrics.
Parameters
----------
rmse_data: list
A list of dictionaries where each dictionary stores run as the key and the Root Mean Square Error (rmse) relative to the Euclid requirements as the value.
std_rmse_data: list
A list of dictionaries where each dictionary stores run as the key and the Standard Deviation of the Root Mean Square Error (rmse) as the value.
plot_dataset: str
A string denoting whether metrics are for the train or test datasets.
metric: str
A string representing the type of shape metric, i.e., e1, e2, or R2.
"""
make_plot(
x_axis=self.list_of_stars,
y_axis=rmse_data,
Expand All @@ -424,130 +469,6 @@ def make_plot(self, rmse_data, std_rmse_data, plot_dataset, metric):
)


class OldShapeMetricsPlotHandler:
"""ShapeMetricsPlotHandler class.
A class to handle plot parameters shape
metrics results.
Parameters
----------
id: str
Class ID name
plotting_params: Recursive Namespace object
Recursive Namespace Object containing plotting parameters
metrics: list
Dictionary containing list of metrics
list_of_stars: list
List containing the number of stars used for each training data set
plots_dir: str
Output directory for metrics plots
"""

id = "shape_metrics"

def __init__(self, plotting_params, metrics, list_of_stars, plots_dir):
self.plotting_params = plotting_params
self.metrics = metrics
self.list_of_stars = list_of_stars
self.plots_dir = plots_dir

def plot(self):
"""Plot.
A function to generate plots for the train and test
metrics.
"""
# Define common data
# Common data
e1_req_euclid = 2e-04
e2_req_euclid = 2e-04
R2_req_euclid = 1e-03
for plot_dataset in ["test_metrics", "train_metrics"]:
e1_rmse = []
e1_std_rmse = []
e2_rmse = []
e2_std_rmse = []
rmse_R2_meanR2 = []
std_rmse_R2_meanR2 = []
metrics_id = []

for k, v in self.metrics.items():
for metrics_data in v:
run_id = list(metrics_data.keys())[0]
metrics_id.append(run_id + "-" + k)

e1_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["rmse_e1"]
/ e1_req_euclid
}
)
e1_std_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["std_rmse_e1"]
}
)

e2_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["rmse_e2"]
/ e2_req_euclid
}
)
e2_std_rmse.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["std_rmse_e2"]
}
)

rmse_R2_meanR2.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["rmse_R2_meanR2"]
/ R2_req_euclid
}
)

std_rmse_R2_meanR2.append(
{
(k + "-" + run_id): metrics_data[run_id][0][plot_dataset][
"shape_results_dict"
]["std_rmse_R2_meanR2"]
}
)

make_plot(
x_axis=self.list_of_stars,
y_axis=e1_rmse,
y_axis_err=e1_std_rmse,
label=metrics_id,
plot_title="Stars " + plot_dataset + ".\nShape RMSE",
x_axis_label="Number of stars",
y_left_axis_label="Absolute error",
y_right_axis_label="Relative error [%]",
filename=os.path.join(
self.plots_dir,
plot_dataset
+ "_nstars_"
+ "_".join(str(nstar) for nstar in self.list_of_stars)
+ "_Shape_RMSE.png",
),
plot_show=self.plotting_params.plot_show,
)


def get_number_of_stars(metrics):
"""Get Number of Stars.
Expand Down

0 comments on commit 4976500

Please sign in to comment.