Skip to content

Commit

Permalink
Astropy model fitting (#155)
Browse files Browse the repository at this point in the history
* Initial structure.

* Added docstrings.

* Putting together an example

* Numpy has moved their expections to another module so added a try statement to avoid a crash.

* Finished documenting the example.

* Pre-commit fixes

* Added some docs mentioning no error is used in the fitting.

* Create 155.feature.rst

Add changelog.

* Small edits.

* Adopting Shanes review suggestions.

* Ran tox

* Codespell showed I cannot spell.

* Removed the fitting example in the wrong folder.

* Finishing up tests and some general tidying up.

* Some tox tidying up.

* Added docs/sg_execution_times.rst file to gitignore.

* Deleted docs/sg_execution_times.rst file.

---------

Co-authored-by: Kristopher Cooper <[email protected]>
  • Loading branch information
KriSun95 and Kristopher Cooper authored Sep 19, 2024
1 parent 375bb9a commit b0da103
Show file tree
Hide file tree
Showing 19 changed files with 584 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ docs/_build
docs/generated
docs/api
docs/whatsnew/latest_changelog.txt
docs/sg_execution_times.rst
sunkit_spex/version.py
htmlcov/

Expand Down
1 change: 1 addition & 0 deletions changelog/155.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add code for using Astropy models in relation to a simple example of forward-fitting X-ray spectroscopy.
197 changes: 197 additions & 0 deletions examples/fitting_simulated_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""
======================
Fitting Simulated Data
======================
This is a file to show a very basic fitting of data where the model are
generated in a different space (photon-space) which are converted using
a square response matrix to the data-space (count-space).
.. note::
Caveats:
* The response is square so the count and photon energy axes are identical.
* No errors are included in the fitting statistic.
"""

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm

from astropy.modeling import fitting

from sunkit_spex.data.simulated_data import simulate_square_response_matrix
from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func
from sunkit_spex.fitting.optimizer_tools.minimizer_tools import scipy_minimize
from sunkit_spex.fitting.statistics.gaussian import chi_squared
from sunkit_spex.models.instrument_response import MatrixModel
from sunkit_spex.models.models import GaussianModel, StraightLineModel

#####################################################
#
# Start by creating simulated data and instrument.
# This would all be provided by a given observation.
#
# Can define the photon energies

start, inc = 1.6, 0.04
stop = 80 + inc / 2
ph_energies = np.arange(start, stop, inc)

#####################################################
#
# Let's start making a simulated photon spectrum

sim_cont = {"slope": -1, "intercept": 100}
sim_line = {"amplitude": 100, "mean": 30, "stddev": 2}
# use a straight line model for a continuum, Gaussian for a line
ph_model = StraightLineModel(**sim_cont) + GaussianModel(**sim_line)

plt.figure()
plt.plot(ph_energies, ph_model(ph_energies))
plt.xlabel("Energy [keV]")
plt.ylabel("ph s$^{-1}$ cm$^{-2}$ keV$^{-1}$")
plt.title("Simulated Photon Spectrum")
plt.show()

#####################################################
#
# Now want a response matrix

srm = simulate_square_response_matrix(ph_energies.size)
srm_model = MatrixModel(matrix=srm)

plt.figure()
plt.imshow(
srm, origin="lower", extent=[ph_energies[0], ph_energies[-1], ph_energies[0], ph_energies[-1]], norm=LogNorm()
)
plt.ylabel("Photon Energies [keV]")
plt.xlabel("Count Energies [keV]")
plt.title("Simulated SRM")
plt.show()

#####################################################
#
# Start work on a count model

sim_gauss = {"amplitude": 70, "mean": 40, "stddev": 2}
# the brackets are very necessary
ct_model = (ph_model | srm_model) + GaussianModel(**sim_gauss)

#####################################################
#
# Generate simulated count data to (almost) fit

sim_count_model = ct_model(ph_energies)

#####################################################
#
# Add some noise
np_rand = np.random.default_rng(seed=10)
sim_count_model_wn = sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model)

#####################################################
#
# Can plot all the different components in the simulated count spectrum

plt.figure()
plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features")
plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature")
plt.plot(ph_energies, sim_count_model, label="total sim. spectrum")
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise", lw=0.5)
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum")
plt.legend()

plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold")
plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold")
plt.show()

#####################################################
#
# Now we have the simulated data, let's start setting up to fit it
#
# Get some initial guesses that are off from the simulated data above

guess_cont = {"slope": -0.5, "intercept": 80}
guess_line = {"amplitude": 150, "mean": 32, "stddev": 5}
guess_gauss = {"amplitude": 350, "mean": 39, "stddev": 0.5}

#####################################################
#
# Define a new model since we have a rough idea of the mode we should use

ph_mod_4fit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line)
count_model_4fit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss)

#####################################################
#
# Let's fit the simulated data and plot the result

opt_res = scipy_minimize(
minimize_func, count_model_4fit.parameters, (sim_count_model_wn, ph_energies, count_model_4fit, chi_squared)
)

plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies, *opt_res.x), ls=":", label="model fit")
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum Fit with Scipy")
plt.legend()
plt.show()


#####################################################
#
# Now try and fit with Astropy native fitting infrastructure and plot the result
#
# Try and ensure we start fresh with new model definitions

ph_mod_4astropyfit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line)
count_model_4astropyfit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss)

astropy_fit = fitting.LevMarLSQFitter()

astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, sim_count_model_wn)

plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(ph_energies, astropy_fitted_result(ph_energies), ls=":", label="model fit")
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum Fit with Astropy")
plt.legend()
plt.show()

#####################################################
#
# Display a table of the fitted results

plt.figure(layout="constrained")

row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss))
column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit")
true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values()))
guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values()))
scipy_vals = opt_res.x
astropy_vals = astropy_fitted_result.parameters
cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T
cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str)

plt.axis("off")
plt.table(
cellText=cell_text,
cellColours=None,
cellLoc="center",
rowLabels=row_labels,
rowColours=None,
colLabels=column_labels,
colColours=None,
colLoc="center",
bbox=[0, 0, 1, 1],
)

plt.show()
2 changes: 2 additions & 0 deletions sunkit_spex/data/README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,5 @@ Data directory
This directory contains data files included with the package source
code distribution. Note that this is intended only for relatively small files
- large files should be externally hosted and downloaded as needed.

Code used to generate fake data products is also stored here.
53 changes: 53 additions & 0 deletions sunkit_spex/data/simulated_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
"""
Module to store functions used to generate simulated data products.
"""

import numpy as np

__all__ = ["simulate_square_response_matrix"]


def simulate_square_response_matrix(size, random_seed=10):
"""Generate a square matrix with off-diagonal terms.
Returns a product to mimic an instrument response matrix.
Parameters
----------
size : `int`
The length of each side of the square response matrix.
random_seed : `int`, optional
The seed input for the random number generator. This will accept any value input accepted by `numpy.random.default_rng`.
Returns
-------
`numpy.ndarray`
The simulated 2D square response matrix.
"""
np_rand = np.random.default_rng(seed=random_seed)

# fake SRM
fake_srm = np.identity(size)

# add some off-diagonal terms
for c, r in enumerate(fake_srm):
# add some features into the fake SRM
off_diag = np_rand.random(c) * 0.005

# add a diagonal feature
_x = 50
if c >= _x:
off_diag[-_x] = np_rand.random(1)[0]

# add a vertical feature in
_y = 200
__y = 30
if c > _y + 100:
off_diag[_y - __y // 2 : _y + __y // 2] = np.arange(2 * (__y // 2)) * np_rand.random(2 * (__y // 2)) * 5e-4

# put these features in the fake_srm row and normalize
r[: off_diag.size] = off_diag
r /= np.sum(r)

return fake_srm
Empty file.
36 changes: 36 additions & 0 deletions sunkit_spex/fitting/objective_functions/optimising_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
"""
This module contains functions that can evaluate models and return a fit statistic.
"""

__all__ = ["minimize_func"]


def minimize_func(params, data_y, model_x, model_func, statistic_func):
"""
Minimization function.
Parameters
----------
params : `ndarray`
Guesses of the independent variables.
data_y : `ndarray`
The data to be fitted.
model_x : `ndarray`
The values at which to evaluate `model_func` at with `params`.
model_func : `astropy.modeling.core._ModelMeta`
The model being fitted to the data. Crucially will have an
`evaluate` method.
statistic_func : `function`
The chosen function to compare the data and the model.
Returns
-------
`float`
The value to be optimized that compares the model to the data.
"""
model_y = model_func.evaluate(model_x, *params)
return statistic_func(data_y, model_y)
Empty file.
38 changes: 38 additions & 0 deletions sunkit_spex/fitting/optimizer_tools/minimizer_tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""
This module contains functions to wrap around minimizer tools.
"""

from scipy.optimize import minimize

__all__ = ["scipy_minimize"]


def scipy_minimize(objective_func, param_guesses, objective_func_args, **kwargs):
"""A function to optimize fitted parameters to data.
Parameters
----------
objective_func : `function`
The function to be optimized.
param_guesses : `ndarray`
Initial guesses of the independent variables.
objective_func_args : `tuple`
Any arguments required to be passed to the objective function
after the param_guesses.
E.g., `objective_func(param_guesses, *objective_func_args)`.
kwargs :
Passed to `scipy.optimize.minimize`.
A default value for the method is chosen to be "Nelder-Mead".
Returns
-------
`scipy.optimize.OptimizeResult`
The optimized result after comparing the model to the data.
"""

method = kwargs.pop("method", "Nelder-Mead")

return minimize(objective_func, param_guesses, args=objective_func_args, method=method, **kwargs)
Empty file.
29 changes: 29 additions & 0 deletions sunkit_spex/fitting/statistics/gaussian.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""
This module contains functions that compute a fit statistic between two data-sets.
"""

import numpy as np

__all__ = ["chi_squared"]


def chi_squared(data_y, model_y):
"""
The form to optimise while fitting.
* No error included here. *
Parameters
----------
data_y : `ndarray`
The data to be fitted.
model_y : `ndarray`
The model values being fitted.
Returns
-------
`float`
The value to be optimized that compares the model to the data.
"""
return np.sum((data_y - model_y) ** 2)
4 changes: 4 additions & 0 deletions sunkit_spex/fitting/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
This module contains package tests.
"""
Loading

0 comments on commit b0da103

Please sign in to comment.