Skip to content

Commit

Permalink
Merge pull request #85 from CosmoStat/feature_unit_tests_run_configs
Browse files Browse the repository at this point in the history
Feature unit tests run configs - MetricsConfigHandler Class
  • Loading branch information
sfarrens authored Nov 16, 2023
2 parents 3dae190 + 52214a0 commit 997b322
Show file tree
Hide file tree
Showing 11 changed files with 514 additions and 87 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ htmlcov/
.coverage.*
.cache
nosetests.xml
pytest.xml
coverage.xml
*.cover
*.py,cover
Expand Down
41 changes: 36 additions & 5 deletions src/wf_psf/psf_models/psf_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tensorflow.python.keras.engine import data_adapter
from wf_psf.utils.utils import PI_zernikes, zernike_generator
from wf_psf.sims.SimPSFToolkit import SimPSFToolkit
import glob
from sys import exit
import logging

Expand All @@ -21,7 +22,13 @@


class PsfModelError(Exception):
pass
"""PSF Model Parameter Error exception class for specific error scenarios."""

def __init__(
self, message="An error with your PSF model parameter settings occurred."
):
self.message = message
super().__init__(self.message)


def register_psfclass(psf_class):
Expand Down Expand Up @@ -67,10 +74,9 @@ def set_psf_model(model_name):

try:
psf_class = PSF_CLASS[model_name]
except KeyError:
logger.exception("PSF model entered is invalid. Check your config settings.")
exit()

except KeyError as e:
logger.exception(e)
raise PsfModelError("PSF model entered is invalid. Check your config settings.")
return psf_class


Expand Down Expand Up @@ -102,6 +108,31 @@ def get_psf_model(model_params, training_hparams, *coeff_matrix):
return psf_class(model_params, training_hparams, *coeff_matrix)


def get_psf_model_weights_filepath(weights_filepath):
"""Get PSF model weights filepath.
A function to return the basename of the user-specified psf model weights path.
Parameters
----------
weights_filepath: str
Basename of the psf model weights to be loaded.
Returns
-------
str
The absolute path concatenated to the basename of the psf model weights to be loaded.
"""
try:
return glob.glob(weights_filepath)[0].split(".")[0]
except IndexError:
logger.exception(
"PSF weights file not found. Check that you've specified the correct weights file in the metrics config file."
)
raise PsfModelError("PSF model weights error.")


def tf_zernike_cube(n_zernikes, pupil_diam):
"""Tensor Flow Zernike Cube.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
---
metrics_conf: metrics_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
metrics:
# Specify the type of model weights to load by entering "psf_model" to load weights of final psf model or "checkpoint" to load weights from a checkpoint callback.
model_save_path: checkpoint
# Choose the training cycle for which to evaluate the psf_model. Can be: 1, 2, ...
saved_training_cycle: 2
# Metrics-only run: Specify model_params for a pre-trained model else leave blank if running training + metrics
# Specify path to Parent Directory of Trained Model
trained_model_path: src/wf_psf/tests/data/validation/main_random_seed
# Name of the Trained Model Config file stored in config sub-directory in the trained_model_path parent directory
trained_model_config: training_config.yaml
#Evaluate the monchromatic RMSE metric.
eval_mono_metric_rmse: True
#Evaluate the OPD RMSE metric.
eval_opd_metric_rmse: True
#Evaluate the super-resolution and the shape RMSE metrics for the train dataset.
eval_train_shape_sr_metric_rmse: True
# Name of Plotting Config file - Enter name of yaml file to run plot metrics else if empty run metrics evaluation only
plotting_config: <enter name of plotting_config .yaml file or leave empty>
ground_truth_model:
model_params:
#Model used as ground truth for the evaluation. Options are: 'poly' for polychromatic and 'physical' [not available].
model_name: poly

# Evaluation parameters
#Number of bins used for the ground truth model poly PSF generation
n_bins_lda: 20

#Downsampling rate to match the oversampled model to the specified telescope's sampling.
output_Q: 3

#Oversampling rate used for the OPD/WFE PSF model.
oversampling_rate: 3

#Dimension of the pixel PSF postage stamp
output_dim: 32

#Dimension of the OPD/Wavefront space."
pupil_diameter: 256

#Boolean to define if we use sample weights based on the noise standard deviation estimation
use_sample_weights: True

#Interpolation type for the physical poly model. Options are: 'none', 'all', 'top_K', 'independent_Zk'."
interpolation_type: None

# SED intepolation points per bin
sed_interp_pts_per_bin: 0

# SED extrapolate
sed_extrapolate: True

# SED interpolate kind
sed_interp_kind: linear

# Standard deviation of the multiplicative SED Gaussian noise.
sed_sigma: 0

#Limits of the PSF field coordinates for the x axis.
x_lims: [0.0, 1.0e+3]

#Limits of the PSF field coordinates for the y axis.
y_lims: [0.0, 1.0e+3]

# Hyperparameters for Parametric model
param_hparams:
# Random seed for Tensor Flow Initialization
random_seed: 3877572

# Parameter for the l2 loss function for the Optical path differences (OPD)/WFE
l2_param: 0.

#Zernike polynomial modes to use on the parametric part.
n_zernikes: 45

#Max polynomial degree of the parametric part.
d_max: 2

#Flag to save optimisation history for parametric model
save_optim_history_param: true

# Hyperparameters for non-parametric model
nonparam_hparams:
#Max polynomial degree of the non-parametric part.
d_max_nonparam: 5

# Number of graph features
num_graph_features: 10

#L1 regularisation parameter for the non-parametric part."
l1_rate: 1.0e-8

#Flag to enable Projected learning for DD_features to be used with `poly` or `semiparametric` model.
project_dd_features: False

#Flag to reset DD_features to be used with `poly` or `semiparametric` model
reset_dd_features: False

#Flag to save optimisation history for non-parametric model
save_optim_history_nonparam: True

metrics_hparams:
# Batch size to use for the evaluation.
batch_size: 16

#Save RMS error for each super resolved PSF in the test dataset in addition to the mean across the FOV."
#Flag to get Super-Resolution pixel PSF RMSE for each individual test star.
#If `True`, the relative pixel RMSE of each star is added to ther saving dictionary.
opt_stars_rel_pix_rmse: False

## Specific parameters
# Parameter for the l2 loss of the OPD.
l2_param: 0.

## Define the resolution at which you'd like to measure the shape of the PSFs
#Downsampling rate from the high-resolution pixel modelling space.
# Recommended value: 1
output_Q: 1

#Dimension of the pixel PSF postage stamp; it should be big enough so that most of the signal is contained inside the postage stamp.
# It also depends on the Q values used.
# Recommended value: 64 or higher
output_dim: 64

Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
training:
# ID name
id_name: _validation
id_name: _sample_w_bis1_2k
# Name of Data Config file
data_config: data_config.yaml
# Metrics Config file - Enter file to run metrics evaluation else if empty run train only
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
training:
# ID name
id_name: _errorsample_w_bis1_2k
# Name of Data Config file
data_config: data_config.yaml
# Metrics Config file - Enter file to run metrics evaluation else if empty run train only
metrics_config: metrics_config.yaml
model_params:
# Model type. Options are: 'mccd', 'graph', 'poly, 'param', 'poly_physical'."
model_name: poly

#Num of wavelength bins to reconstruct polychromatic objects.
n_bins_lda: 8

#Downsampling rate to match the oversampled model to the specified telescope's sampling.
output_Q: 3

#Oversampling rate used for the OPD/WFE PSF model.
oversampling_rate: 3

#Dimension of the pixel PSF postage stamp
output_dim: 32

#Dimension of the OPD/Wavefront space."
pupil_diameter: 256

#Boolean to define if we use sample weights based on the noise standard deviation estimation
use_sample_weights: True

#Interpolation type for the physical poly model. Options are: 'none', 'all', 'top_K', 'independent_Zk'."
interpolation_type: None

# SED intepolation points per bin
sed_interp_pts_per_bin: 0

# SED extrapolate
sed_extrapolate: True

# SED interpolate kind
sed_interp_kind: linear

# Standard deviation of the multiplicative SED Gaussian noise.
sed_sigma: 0

#Limits of the PSF field coordinates for the x axis.
x_lims: [0.0, 1.0e+3]

#Limits of the PSF field coordinates for the y axis.
y_lims: [0.0, 1.0e+3]

# Hyperparameters for Parametric model
param_hparams:
# Random seed for Tensor Flow Initialization
random_seed: 3877572

# Parameter for the l2 loss function for the Optical path differences (OPD)/WFE
l2_param: 0.

#Zernike polynomial modes to use on the parametric part.
n_zernikes: 15

#Max polynomial degree of the parametric part. chg to max_deg_param
d_max: 2

#Flag to save optimisation history for parametric model
save_optim_history_param: true

# Hyperparameters for non-parametric model
nonparam_hparams:

#Max polynomial degree of the non-parametric part. chg to max_deg_nonparam
d_max_nonparam: 5

# Number of graph features
num_graph_features: 10

#L1 regularisation parameter for the non-parametric part."
l1_rate: 1.0e-8

#Flag to enable Projected learning for DD_features to be used with `poly` or `semiparametric` model.
project_dd_features: False

#Flag to reset DD_features to be used with `poly` or `semiparametric` model
reset_dd_features: False

#Flag to save optimisation history for non-parametric model
save_optim_history_nonparam: true

# Training hyperparameters
training_hparams:
n_epochs_params: [2, 2, 2]

n_epochs_non_params: [2, 2, 2]

batch_size: 32

multi_cycle_params:

# Total amount of cycles to perform.
total_cycles: 2

# Train cycle definition. It can be: 'parametric', 'non-parametric', 'complete', 'only-non-parametric' and 'only-parametric'."
cycle_def: complete

# Make checkpoint at every cycle or just save the checkpoint at the end of the training."
save_all_cycles: True

#"Saved cycle to use for the evaluation. Can be 'cycle1', 'cycle2', ..."
saved_cycle: cycle2

# Learning rates for the parametric parts. It should be a str where numeric values are separated by spaces.
learning_rate_params: [1.0e-2, 1.0e-2]

# Learning rates for the non-parametric parts. It should be a str where numeric values are separated by spaces."
learning_rate_non_params: [1.0e-1, 1.0e-1]

# Number of training epochs of the parametric parts. It should be a strign where numeric values are separated by spaces."
n_epochs_params: [20, 20]

# Number of training epochs of the non-parametric parts. It should be a str where numeric values are separated by spaces."
n_epochs_non_params: [100, 120]

23 changes: 23 additions & 0 deletions src/wf_psf/tests/psf_models_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
"""UNIT TESTS FOR PACKAGE MODULE: PSF MODELS.
This module contains unit tests for the wf_psf.psf_models psf_models module.
:Author: Jennifer Pollack <[email protected]>
"""

import pytest
from wf_psf.psf_models import psf_models
from wf_psf.utils.io import FileIOHandler
import os


def test_get_psf_model_weights_filepath():
weights_filepath = "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint*_poly*_sample_w_bis1_2k_cycle2*"

ans = psf_models.get_psf_model_weights_filepath(weights_filepath)
assert (
ans
== "src/wf_psf/tests/data/validation/main_random_seed/checkpoint/checkpoint_callback_poly_sample_w_bis1_2k_cycle2"
)
Loading

0 comments on commit 997b322

Please sign in to comment.