From b0da103dc3b1759388333d0692511d6b289da30a Mon Sep 17 00:00:00 2001 From: Kris Cooper <43237137+KriSun95@users.noreply.github.com> Date: Thu, 19 Sep 2024 09:14:52 -0500 Subject: [PATCH] Astropy model fitting (#155) * Initial structure. * Added docstrings. * Putting together an example * Numpy has moved their expections to another module so added a try statement to avoid a crash. * Finished documenting the example. * Pre-commit fixes * Added some docs mentioning no error is used in the fitting. * Create 155.feature.rst Add changelog. * Small edits. * Adopting Shanes review suggestions. * Ran tox * Codespell showed I cannot spell. * Removed the fitting example in the wrong folder. * Finishing up tests and some general tidying up. * Some tox tidying up. * Added docs/sg_execution_times.rst file to gitignore. * Deleted docs/sg_execution_times.rst file. --------- Co-authored-by: Kristopher Cooper --- .gitignore | 1 + changelog/155.feature.rst | 1 + examples/fitting_simulated_data.py | 197 ++++++++++++++++++ sunkit_spex/data/README.rst | 2 + sunkit_spex/data/simulated_data.py | 53 +++++ .../fitting/objective_functions/__init__.py | 0 .../optimising_functions.py | 36 ++++ .../fitting/optimizer_tools/__init__.py | 0 .../optimizer_tools/minimizer_tools.py | 38 ++++ sunkit_spex/fitting/statistics/__init__.py | 0 sunkit_spex/fitting/statistics/gaussian.py | 29 +++ sunkit_spex/fitting/tests/__init__.py | 4 + .../fitting/tests/test_objective_functions.py | 39 ++++ .../fitting/tests/test_optimizer_tools.py | 31 +++ sunkit_spex/fitting/tests/test_statistics.py | 30 +++ sunkit_spex/models/instrument_response.py | 14 ++ sunkit_spex/models/models.py | 27 +++ sunkit_spex/tests/test_data.py | 28 +++ sunkit_spex/tests/test_models.py | 54 +++++ 19 files changed, 584 insertions(+) create mode 100644 changelog/155.feature.rst create mode 100644 examples/fitting_simulated_data.py create mode 100644 sunkit_spex/data/simulated_data.py create mode 100644 sunkit_spex/fitting/objective_functions/__init__.py create mode 100644 sunkit_spex/fitting/objective_functions/optimising_functions.py create mode 100644 sunkit_spex/fitting/optimizer_tools/__init__.py create mode 100644 sunkit_spex/fitting/optimizer_tools/minimizer_tools.py create mode 100644 sunkit_spex/fitting/statistics/__init__.py create mode 100644 sunkit_spex/fitting/statistics/gaussian.py create mode 100644 sunkit_spex/fitting/tests/__init__.py create mode 100644 sunkit_spex/fitting/tests/test_objective_functions.py create mode 100644 sunkit_spex/fitting/tests/test_optimizer_tools.py create mode 100644 sunkit_spex/fitting/tests/test_statistics.py create mode 100644 sunkit_spex/tests/test_data.py create mode 100644 sunkit_spex/tests/test_models.py diff --git a/.gitignore b/.gitignore index a4b89a33..35249128 100644 --- a/.gitignore +++ b/.gitignore @@ -147,6 +147,7 @@ docs/_build docs/generated docs/api docs/whatsnew/latest_changelog.txt +docs/sg_execution_times.rst sunkit_spex/version.py htmlcov/ diff --git a/changelog/155.feature.rst b/changelog/155.feature.rst new file mode 100644 index 00000000..4b01d13d --- /dev/null +++ b/changelog/155.feature.rst @@ -0,0 +1 @@ +Add code for using Astropy models in relation to a simple example of forward-fitting X-ray spectroscopy. diff --git a/examples/fitting_simulated_data.py b/examples/fitting_simulated_data.py new file mode 100644 index 00000000..003fac1d --- /dev/null +++ b/examples/fitting_simulated_data.py @@ -0,0 +1,197 @@ +""" +====================== +Fitting Simulated Data +====================== + +This is a file to show a very basic fitting of data where the model are +generated in a different space (photon-space) which are converted using +a square response matrix to the data-space (count-space). + +.. note:: + Caveats: + + * The response is square so the count and photon energy axes are identical. + * No errors are included in the fitting statistic. + +""" + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import LogNorm + +from astropy.modeling import fitting + +from sunkit_spex.data.simulated_data import simulate_square_response_matrix +from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func +from sunkit_spex.fitting.optimizer_tools.minimizer_tools import scipy_minimize +from sunkit_spex.fitting.statistics.gaussian import chi_squared +from sunkit_spex.models.instrument_response import MatrixModel +from sunkit_spex.models.models import GaussianModel, StraightLineModel + +##################################################### +# +# Start by creating simulated data and instrument. +# This would all be provided by a given observation. +# +# Can define the photon energies + +start, inc = 1.6, 0.04 +stop = 80 + inc / 2 +ph_energies = np.arange(start, stop, inc) + +##################################################### +# +# Let's start making a simulated photon spectrum + +sim_cont = {"slope": -1, "intercept": 100} +sim_line = {"amplitude": 100, "mean": 30, "stddev": 2} +# use a straight line model for a continuum, Gaussian for a line +ph_model = StraightLineModel(**sim_cont) + GaussianModel(**sim_line) + +plt.figure() +plt.plot(ph_energies, ph_model(ph_energies)) +plt.xlabel("Energy [keV]") +plt.ylabel("ph s$^{-1}$ cm$^{-2}$ keV$^{-1}$") +plt.title("Simulated Photon Spectrum") +plt.show() + +##################################################### +# +# Now want a response matrix + +srm = simulate_square_response_matrix(ph_energies.size) +srm_model = MatrixModel(matrix=srm) + +plt.figure() +plt.imshow( + srm, origin="lower", extent=[ph_energies[0], ph_energies[-1], ph_energies[0], ph_energies[-1]], norm=LogNorm() +) +plt.ylabel("Photon Energies [keV]") +plt.xlabel("Count Energies [keV]") +plt.title("Simulated SRM") +plt.show() + +##################################################### +# +# Start work on a count model + +sim_gauss = {"amplitude": 70, "mean": 40, "stddev": 2} +# the brackets are very necessary +ct_model = (ph_model | srm_model) + GaussianModel(**sim_gauss) + +##################################################### +# +# Generate simulated count data to (almost) fit + +sim_count_model = ct_model(ph_energies) + +##################################################### +# +# Add some noise +np_rand = np.random.default_rng(seed=10) +sim_count_model_wn = sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model) + +##################################################### +# +# Can plot all the different components in the simulated count spectrum + +plt.figure() +plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features") +plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature") +plt.plot(ph_energies, sim_count_model, label="total sim. spectrum") +plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise", lw=0.5) +plt.xlabel("Energy [keV]") +plt.ylabel("cts s$^{-1}$ keV$^{-1}$") +plt.title("Simulated Count Spectrum") +plt.legend() + +plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold") +plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold") +plt.show() + +##################################################### +# +# Now we have the simulated data, let's start setting up to fit it +# +# Get some initial guesses that are off from the simulated data above + +guess_cont = {"slope": -0.5, "intercept": 80} +guess_line = {"amplitude": 150, "mean": 32, "stddev": 5} +guess_gauss = {"amplitude": 350, "mean": 39, "stddev": 0.5} + +##################################################### +# +# Define a new model since we have a rough idea of the mode we should use + +ph_mod_4fit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line) +count_model_4fit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss) + +##################################################### +# +# Let's fit the simulated data and plot the result + +opt_res = scipy_minimize( + minimize_func, count_model_4fit.parameters, (sim_count_model_wn, ph_energies, count_model_4fit, chi_squared) +) + +plt.figure() +plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise") +plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies, *opt_res.x), ls=":", label="model fit") +plt.xlabel("Energy [keV]") +plt.ylabel("cts s$^{-1}$ keV$^{-1}$") +plt.title("Simulated Count Spectrum Fit with Scipy") +plt.legend() +plt.show() + + +##################################################### +# +# Now try and fit with Astropy native fitting infrastructure and plot the result +# +# Try and ensure we start fresh with new model definitions + +ph_mod_4astropyfit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line) +count_model_4astropyfit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss) + +astropy_fit = fitting.LevMarLSQFitter() + +astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, sim_count_model_wn) + +plt.figure() +plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise") +plt.plot(ph_energies, astropy_fitted_result(ph_energies), ls=":", label="model fit") +plt.xlabel("Energy [keV]") +plt.ylabel("cts s$^{-1}$ keV$^{-1}$") +plt.title("Simulated Count Spectrum Fit with Astropy") +plt.legend() +plt.show() + +##################################################### +# +# Display a table of the fitted results + +plt.figure(layout="constrained") + +row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss)) +column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit") +true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values())) +guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values())) +scipy_vals = opt_res.x +astropy_vals = astropy_fitted_result.parameters +cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T +cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str) + +plt.axis("off") +plt.table( + cellText=cell_text, + cellColours=None, + cellLoc="center", + rowLabels=row_labels, + rowColours=None, + colLabels=column_labels, + colColours=None, + colLoc="center", + bbox=[0, 0, 1, 1], +) + +plt.show() diff --git a/sunkit_spex/data/README.rst b/sunkit_spex/data/README.rst index 382f6e76..58ca9fbe 100644 --- a/sunkit_spex/data/README.rst +++ b/sunkit_spex/data/README.rst @@ -4,3 +4,5 @@ Data directory This directory contains data files included with the package source code distribution. Note that this is intended only for relatively small files - large files should be externally hosted and downloaded as needed. + +Code used to generate fake data products is also stored here. diff --git a/sunkit_spex/data/simulated_data.py b/sunkit_spex/data/simulated_data.py new file mode 100644 index 00000000..09282dfc --- /dev/null +++ b/sunkit_spex/data/simulated_data.py @@ -0,0 +1,53 @@ +""" +Module to store functions used to generate simulated data products. +""" + +import numpy as np + +__all__ = ["simulate_square_response_matrix"] + + +def simulate_square_response_matrix(size, random_seed=10): + """Generate a square matrix with off-diagonal terms. + + Returns a product to mimic an instrument response matrix. + + Parameters + ---------- + size : `int` + The length of each side of the square response matrix. + + random_seed : `int`, optional + The seed input for the random number generator. This will accept any value input accepted by `numpy.random.default_rng`. + + Returns + ------- + `numpy.ndarray` + The simulated 2D square response matrix. + """ + np_rand = np.random.default_rng(seed=random_seed) + + # fake SRM + fake_srm = np.identity(size) + + # add some off-diagonal terms + for c, r in enumerate(fake_srm): + # add some features into the fake SRM + off_diag = np_rand.random(c) * 0.005 + + # add a diagonal feature + _x = 50 + if c >= _x: + off_diag[-_x] = np_rand.random(1)[0] + + # add a vertical feature in + _y = 200 + __y = 30 + if c > _y + 100: + off_diag[_y - __y // 2 : _y + __y // 2] = np.arange(2 * (__y // 2)) * np_rand.random(2 * (__y // 2)) * 5e-4 + + # put these features in the fake_srm row and normalize + r[: off_diag.size] = off_diag + r /= np.sum(r) + + return fake_srm diff --git a/sunkit_spex/fitting/objective_functions/__init__.py b/sunkit_spex/fitting/objective_functions/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sunkit_spex/fitting/objective_functions/optimising_functions.py b/sunkit_spex/fitting/objective_functions/optimising_functions.py new file mode 100644 index 00000000..4b631264 --- /dev/null +++ b/sunkit_spex/fitting/objective_functions/optimising_functions.py @@ -0,0 +1,36 @@ +""" +This module contains functions that can evaluate models and return a fit statistic. +""" + +__all__ = ["minimize_func"] + + +def minimize_func(params, data_y, model_x, model_func, statistic_func): + """ + Minimization function. + + Parameters + ---------- + params : `ndarray` + Guesses of the independent variables. + + data_y : `ndarray` + The data to be fitted. + + model_x : `ndarray` + The values at which to evaluate `model_func` at with `params`. + + model_func : `astropy.modeling.core._ModelMeta` + The model being fitted to the data. Crucially will have an + `evaluate` method. + + statistic_func : `function` + The chosen function to compare the data and the model. + + Returns + ------- + `float` + The value to be optimized that compares the model to the data. + """ + model_y = model_func.evaluate(model_x, *params) + return statistic_func(data_y, model_y) diff --git a/sunkit_spex/fitting/optimizer_tools/__init__.py b/sunkit_spex/fitting/optimizer_tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sunkit_spex/fitting/optimizer_tools/minimizer_tools.py b/sunkit_spex/fitting/optimizer_tools/minimizer_tools.py new file mode 100644 index 00000000..bf371888 --- /dev/null +++ b/sunkit_spex/fitting/optimizer_tools/minimizer_tools.py @@ -0,0 +1,38 @@ +""" +This module contains functions to wrap around minimizer tools. +""" + +from scipy.optimize import minimize + +__all__ = ["scipy_minimize"] + + +def scipy_minimize(objective_func, param_guesses, objective_func_args, **kwargs): + """A function to optimize fitted parameters to data. + + Parameters + ---------- + objective_func : `function` + The function to be optimized. + + param_guesses : `ndarray` + Initial guesses of the independent variables. + + objective_func_args : `tuple` + Any arguments required to be passed to the objective function + after the param_guesses. + E.g., `objective_func(param_guesses, *objective_func_args)`. + + kwargs : + Passed to `scipy.optimize.minimize`. + A default value for the method is chosen to be "Nelder-Mead". + + Returns + ------- + `scipy.optimize.OptimizeResult` + The optimized result after comparing the model to the data. + """ + + method = kwargs.pop("method", "Nelder-Mead") + + return minimize(objective_func, param_guesses, args=objective_func_args, method=method, **kwargs) diff --git a/sunkit_spex/fitting/statistics/__init__.py b/sunkit_spex/fitting/statistics/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sunkit_spex/fitting/statistics/gaussian.py b/sunkit_spex/fitting/statistics/gaussian.py new file mode 100644 index 00000000..7288784a --- /dev/null +++ b/sunkit_spex/fitting/statistics/gaussian.py @@ -0,0 +1,29 @@ +""" +This module contains functions that compute a fit statistic between two data-sets. +""" + +import numpy as np + +__all__ = ["chi_squared"] + + +def chi_squared(data_y, model_y): + """ + The form to optimise while fitting. + + * No error included here. * + + Parameters + ---------- + data_y : `ndarray` + The data to be fitted. + + model_y : `ndarray` + The model values being fitted. + + Returns + ------- + `float` + The value to be optimized that compares the model to the data. + """ + return np.sum((data_y - model_y) ** 2) diff --git a/sunkit_spex/fitting/tests/__init__.py b/sunkit_spex/fitting/tests/__init__.py new file mode 100644 index 00000000..838b4573 --- /dev/null +++ b/sunkit_spex/fitting/tests/__init__.py @@ -0,0 +1,4 @@ +# Licensed under a 3-clause BSD style license - see LICENSE.rst +""" +This module contains package tests. +""" diff --git a/sunkit_spex/fitting/tests/test_objective_functions.py b/sunkit_spex/fitting/tests/test_objective_functions.py new file mode 100644 index 00000000..516918d0 --- /dev/null +++ b/sunkit_spex/fitting/tests/test_objective_functions.py @@ -0,0 +1,39 @@ +""" +This module contains package tests for the objective functions. +""" + +import numpy as np + +from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func +from sunkit_spex.fitting.statistics.gaussian import chi_squared +from sunkit_spex.models.models import StraightLineModel + + +def test_minimize_func(): + """Test the `minimize_func` function against known outputs.""" + sim_x0 = np.arange(3) + model_params0 = {"slope": 1, "intercept": 0} + sim_model0 = StraightLineModel(**model_params0) + sim_data0 = sim_model0.evaluate(sim_x0, **model_params0) + res0 = minimize_func( + params=tuple(model_params0.values()), + data_y=sim_data0, + model_x=sim_x0, + model_func=sim_model0, + statistic_func=chi_squared, + ) + + sim_x1 = np.arange(3) + model_params1 = {"slope": 1, "intercept": 0} + sim_model1 = StraightLineModel(**model_params1) + sim_data1 = sim_model1.evaluate(sim_x1, **model_params1)[::-1] + res1 = minimize_func( + params=tuple(model_params1.values()), + data_y=sim_data1, + model_x=sim_x1, + model_func=sim_model1, + statistic_func=chi_squared, + ) + + assert res0 == 0 + assert res1 == 8 diff --git a/sunkit_spex/fitting/tests/test_optimizer_tools.py b/sunkit_spex/fitting/tests/test_optimizer_tools.py new file mode 100644 index 00000000..04eabc91 --- /dev/null +++ b/sunkit_spex/fitting/tests/test_optimizer_tools.py @@ -0,0 +1,31 @@ +""" +This module contains package tests for the optimizer functions. +""" + +import numpy as np +from numpy.testing import assert_allclose + +from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func +from sunkit_spex.fitting.optimizer_tools.minimizer_tools import scipy_minimize +from sunkit_spex.fitting.statistics.gaussian import chi_squared +from sunkit_spex.models.models import StraightLineModel + + +def test_scipy_minimize(): + """Test the `scipy_minimize` function against known outputs.""" + sim_x0 = np.arange(3) + model_params0 = {"slope": 1, "intercept": 0} + model_param_values0 = tuple(model_params0.values()) + sim_model0 = StraightLineModel(**model_params0) + sim_data0 = sim_model0.evaluate(sim_x0, **model_params0) + opt_res0 = scipy_minimize(minimize_func, model_param_values0, (sim_data0, sim_x0, sim_model0, chi_squared)) + + sim_x1 = np.arange(3) + model_params1 = {"slope": 8, "intercept": 5} + model_param_values1 = tuple(model_params1.values()) + sim_model1 = StraightLineModel(**model_params1) + sim_data1 = sim_model1.evaluate(sim_x1, **model_params1) + opt_res1 = scipy_minimize(minimize_func, model_param_values1, (sim_data1, sim_x1, sim_model1, chi_squared)) + + assert_allclose(opt_res0.x, model_param_values0, rtol=1e-3) + assert_allclose(opt_res1.x, model_param_values1, rtol=1e-3) diff --git a/sunkit_spex/fitting/tests/test_statistics.py b/sunkit_spex/fitting/tests/test_statistics.py new file mode 100644 index 00000000..16b56946 --- /dev/null +++ b/sunkit_spex/fitting/tests/test_statistics.py @@ -0,0 +1,30 @@ +""" +This module contains package tests for the statistics functions. +""" + +import numpy as np + +from sunkit_spex.fitting.statistics.gaussian import chi_squared + + +def test_chi_squared(): + sim_data0 = np.array([0]) + sim_model0 = sim_data0 + chi_s0 = chi_squared(sim_data0, sim_model0) + + sim_data1 = np.array([1]) + sim_model1 = sim_data1 + chi_s1 = chi_squared(sim_data1, sim_model1) + + sim_data2 = np.array([1, 2, 3]) + sim_model2 = sim_data2 + chi_s2 = chi_squared(sim_data2, sim_model2) + + sim_data3 = np.array([1, 2, 3]) + sim_model3 = sim_data3[::-1] + chi_s3 = chi_squared(sim_data3, sim_model3) + + assert chi_s0 == 0 + assert chi_s1 == 0 + assert chi_s2 == 0 + assert chi_s3 == 8 diff --git a/sunkit_spex/models/instrument_response.py b/sunkit_spex/models/instrument_response.py index ff473b4a..f5004b14 100644 --- a/sunkit_spex/models/instrument_response.py +++ b/sunkit_spex/models/instrument_response.py @@ -1 +1,15 @@ """Module for model components required for instrument response models.""" + +from astropy.modeling import Fittable1DModel, Parameter + +__all__ = ["MatrixModel"] + + +class MatrixModel(Fittable1DModel): + def __init__(self, matrix): + self.matrix = Parameter(default=matrix, description="The matrix with which to multiply the input.", fixed=True) + super().__init__() + + def evaluate(self, model_y): + # Requires input must have a specific dimensionality + return model_y @ self.matrix diff --git a/sunkit_spex/models/models.py b/sunkit_spex/models/models.py index c0084f3a..47055cc6 100644 --- a/sunkit_spex/models/models.py +++ b/sunkit_spex/models/models.py @@ -1 +1,28 @@ """Module for generic mathematical models.""" + +import numpy as np + +from astropy.modeling import Fittable1DModel, Parameter + +__all__ = ["StraightLineModel", "GaussianModel"] + + +class StraightLineModel(Fittable1DModel): + slope = Parameter(default=1, description="Gradient of a straight line model.") + intercept = Parameter(default=0, description="Y-intercept of a straight line model.") + + @staticmethod + def evaluate(x, slope, intercept): + """Evaluate the straight line model at `x` with parameters `slope` and `intercept`.""" + return slope * x + intercept + + +class GaussianModel(Fittable1DModel): + amplitude = Parameter(default=1, min=0, description="Scalar for Gaussian.") + mean = Parameter(default=0, min=0, description="X-offset for Gaussian.") + stddev = Parameter(default=1, description="Sigma for Gaussian.") + + @staticmethod + def evaluate(x, amplitude, mean, stddev): + """Evaluate the Gaussian model at `x` with parameters `amplitude`, `mean`, and `stddev`.""" + return amplitude * np.e ** (-((x - mean) ** 2) / (2 * stddev**2)) diff --git a/sunkit_spex/tests/test_data.py b/sunkit_spex/tests/test_data.py new file mode 100644 index 00000000..1dfe07e8 --- /dev/null +++ b/sunkit_spex/tests/test_data.py @@ -0,0 +1,28 @@ +""" +This module contains package tests for the data functions. +""" + +import numpy as np +from numpy.testing import assert_allclose + +from sunkit_spex.data.simulated_data import simulate_square_response_matrix + + +def test_simulate_square_response_matrix(): + """Ensure `simulate_square_response_matrix` behaviour does not change.""" + array0 = simulate_square_response_matrix(0) + exp_res0 = np.identity(0) + + array1 = simulate_square_response_matrix(1) + exp_res1 = [[1]] + + array2 = simulate_square_response_matrix(2) + exp_res2 = [[1, 0], [0.00475727, 0.99524273]] + + array3 = simulate_square_response_matrix(3) + exp_res3 = [[1, 0, 0.0], [0.00475727, 0.99524273, 0.0], [0.00103306, 0.00412088, 0.99484607]] + + assert_allclose(array0, exp_res0, rtol=1e-3) + assert_allclose(array1, exp_res1, rtol=1e-3) + assert_allclose(array2, exp_res2, rtol=1e-3) + assert_allclose(array3, exp_res3, rtol=1e-3) diff --git a/sunkit_spex/tests/test_models.py b/sunkit_spex/tests/test_models.py new file mode 100644 index 00000000..a9663662 --- /dev/null +++ b/sunkit_spex/tests/test_models.py @@ -0,0 +1,54 @@ +""" +This module contains package tests for package models. +""" + +import numpy as np +from numpy.testing import assert_allclose, assert_array_equal + +from sunkit_spex.data.simulated_data import simulate_square_response_matrix +from sunkit_spex.models.instrument_response import MatrixModel +from sunkit_spex.models.models import GaussianModel, StraightLineModel + + +def test_StraightLineModel(): + """Test the straight line model evaluation methods to a known output.""" + sim_x0 = np.arange(3) + model_params0 = {"slope": 1, "intercept": 0} + sim_model0 = StraightLineModel(**model_params0) + exp_res0 = [0, 1, 2] + ans0_0 = sim_model0(sim_x0) + ans0_1 = sim_model0.evaluate(sim_x0, *tuple(model_params0.values())) + + assert_allclose(exp_res0, ans0_0, rtol=1e-3) + assert_allclose(ans0_0, ans0_1, rtol=1e-3) + + +def test_GaussianModel(): + """Test the Gaussian model evaluation methods to a known output.""" + sim_x0 = np.arange(-1, 2) * np.sqrt(2 * np.log(2)) + model_params0 = {"amplitude": 10, "mean": 0, "stddev": 1} + sim_model0 = GaussianModel(**model_params0) + exp_res0 = [5, 10, 5] + ans0_0 = sim_model0(sim_x0) + ans0_1 = sim_model0.evaluate(sim_x0, *tuple(model_params0.values())) + + assert_allclose(exp_res0, ans0_0, rtol=1e-3) + assert_allclose(ans0_0, ans0_1, rtol=1e-3) + + +def test_MatrixModel(): + """Test the matrix model contents and compound model behaviour.""" + size0 = 3 + srm0 = simulate_square_response_matrix(size0) + srm_model0 = MatrixModel(matrix=srm0) + + assert_array_equal(srm_model0.matrix, srm0) + + sim_x0 = np.arange(size0) + model_params0 = {"slope": 1, "intercept": 0} + sim_model0 = StraightLineModel(**model_params0) + comp_model0 = sim_model0 | srm_model0 + comp_res0 = comp_model0(sim_x0) + exp_res0 = [0.00682338, 1.00348448, 1.98969213] + + assert_allclose(comp_res0, exp_res0, rtol=1e-6)