-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Initial structure. * Added docstrings. * Putting together an example * Numpy has moved their expections to another module so added a try statement to avoid a crash. * Finished documenting the example. * Pre-commit fixes * Added some docs mentioning no error is used in the fitting. * Create 155.feature.rst Add changelog. * Small edits. * Adopting Shanes review suggestions. * Ran tox * Codespell showed I cannot spell. * Removed the fitting example in the wrong folder. * Finishing up tests and some general tidying up. * Some tox tidying up. * Added docs/sg_execution_times.rst file to gitignore. * Deleted docs/sg_execution_times.rst file. --------- Co-authored-by: Kristopher Cooper <[email protected]>
- Loading branch information
Showing
19 changed files
with
584 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Add code for using Astropy models in relation to a simple example of forward-fitting X-ray spectroscopy. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
""" | ||
====================== | ||
Fitting Simulated Data | ||
====================== | ||
This is a file to show a very basic fitting of data where the model are | ||
generated in a different space (photon-space) which are converted using | ||
a square response matrix to the data-space (count-space). | ||
.. note:: | ||
Caveats: | ||
* The response is square so the count and photon energy axes are identical. | ||
* No errors are included in the fitting statistic. | ||
""" | ||
|
||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from matplotlib.colors import LogNorm | ||
|
||
from astropy.modeling import fitting | ||
|
||
from sunkit_spex.data.simulated_data import simulate_square_response_matrix | ||
from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func | ||
from sunkit_spex.fitting.optimizer_tools.minimizer_tools import scipy_minimize | ||
from sunkit_spex.fitting.statistics.gaussian import chi_squared | ||
from sunkit_spex.models.instrument_response import MatrixModel | ||
from sunkit_spex.models.models import GaussianModel, StraightLineModel | ||
|
||
##################################################### | ||
# | ||
# Start by creating simulated data and instrument. | ||
# This would all be provided by a given observation. | ||
# | ||
# Can define the photon energies | ||
|
||
start, inc = 1.6, 0.04 | ||
stop = 80 + inc / 2 | ||
ph_energies = np.arange(start, stop, inc) | ||
|
||
##################################################### | ||
# | ||
# Let's start making a simulated photon spectrum | ||
|
||
sim_cont = {"slope": -1, "intercept": 100} | ||
sim_line = {"amplitude": 100, "mean": 30, "stddev": 2} | ||
# use a straight line model for a continuum, Gaussian for a line | ||
ph_model = StraightLineModel(**sim_cont) + GaussianModel(**sim_line) | ||
|
||
plt.figure() | ||
plt.plot(ph_energies, ph_model(ph_energies)) | ||
plt.xlabel("Energy [keV]") | ||
plt.ylabel("ph s$^{-1}$ cm$^{-2}$ keV$^{-1}$") | ||
plt.title("Simulated Photon Spectrum") | ||
plt.show() | ||
|
||
##################################################### | ||
# | ||
# Now want a response matrix | ||
|
||
srm = simulate_square_response_matrix(ph_energies.size) | ||
srm_model = MatrixModel(matrix=srm) | ||
|
||
plt.figure() | ||
plt.imshow( | ||
srm, origin="lower", extent=[ph_energies[0], ph_energies[-1], ph_energies[0], ph_energies[-1]], norm=LogNorm() | ||
) | ||
plt.ylabel("Photon Energies [keV]") | ||
plt.xlabel("Count Energies [keV]") | ||
plt.title("Simulated SRM") | ||
plt.show() | ||
|
||
##################################################### | ||
# | ||
# Start work on a count model | ||
|
||
sim_gauss = {"amplitude": 70, "mean": 40, "stddev": 2} | ||
# the brackets are very necessary | ||
ct_model = (ph_model | srm_model) + GaussianModel(**sim_gauss) | ||
|
||
##################################################### | ||
# | ||
# Generate simulated count data to (almost) fit | ||
|
||
sim_count_model = ct_model(ph_energies) | ||
|
||
##################################################### | ||
# | ||
# Add some noise | ||
np_rand = np.random.default_rng(seed=10) | ||
sim_count_model_wn = sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model) | ||
|
||
##################################################### | ||
# | ||
# Can plot all the different components in the simulated count spectrum | ||
|
||
plt.figure() | ||
plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features") | ||
plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature") | ||
plt.plot(ph_energies, sim_count_model, label="total sim. spectrum") | ||
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise", lw=0.5) | ||
plt.xlabel("Energy [keV]") | ||
plt.ylabel("cts s$^{-1}$ keV$^{-1}$") | ||
plt.title("Simulated Count Spectrum") | ||
plt.legend() | ||
|
||
plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold") | ||
plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold") | ||
plt.show() | ||
|
||
##################################################### | ||
# | ||
# Now we have the simulated data, let's start setting up to fit it | ||
# | ||
# Get some initial guesses that are off from the simulated data above | ||
|
||
guess_cont = {"slope": -0.5, "intercept": 80} | ||
guess_line = {"amplitude": 150, "mean": 32, "stddev": 5} | ||
guess_gauss = {"amplitude": 350, "mean": 39, "stddev": 0.5} | ||
|
||
##################################################### | ||
# | ||
# Define a new model since we have a rough idea of the mode we should use | ||
|
||
ph_mod_4fit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line) | ||
count_model_4fit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss) | ||
|
||
##################################################### | ||
# | ||
# Let's fit the simulated data and plot the result | ||
|
||
opt_res = scipy_minimize( | ||
minimize_func, count_model_4fit.parameters, (sim_count_model_wn, ph_energies, count_model_4fit, chi_squared) | ||
) | ||
|
||
plt.figure() | ||
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise") | ||
plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies, *opt_res.x), ls=":", label="model fit") | ||
plt.xlabel("Energy [keV]") | ||
plt.ylabel("cts s$^{-1}$ keV$^{-1}$") | ||
plt.title("Simulated Count Spectrum Fit with Scipy") | ||
plt.legend() | ||
plt.show() | ||
|
||
|
||
##################################################### | ||
# | ||
# Now try and fit with Astropy native fitting infrastructure and plot the result | ||
# | ||
# Try and ensure we start fresh with new model definitions | ||
|
||
ph_mod_4astropyfit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line) | ||
count_model_4astropyfit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss) | ||
|
||
astropy_fit = fitting.LevMarLSQFitter() | ||
|
||
astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, sim_count_model_wn) | ||
|
||
plt.figure() | ||
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise") | ||
plt.plot(ph_energies, astropy_fitted_result(ph_energies), ls=":", label="model fit") | ||
plt.xlabel("Energy [keV]") | ||
plt.ylabel("cts s$^{-1}$ keV$^{-1}$") | ||
plt.title("Simulated Count Spectrum Fit with Astropy") | ||
plt.legend() | ||
plt.show() | ||
|
||
##################################################### | ||
# | ||
# Display a table of the fitted results | ||
|
||
plt.figure(layout="constrained") | ||
|
||
row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss)) | ||
column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit") | ||
true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values())) | ||
guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values())) | ||
scipy_vals = opt_res.x | ||
astropy_vals = astropy_fitted_result.parameters | ||
cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T | ||
cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str) | ||
|
||
plt.axis("off") | ||
plt.table( | ||
cellText=cell_text, | ||
cellColours=None, | ||
cellLoc="center", | ||
rowLabels=row_labels, | ||
rowColours=None, | ||
colLabels=column_labels, | ||
colColours=None, | ||
colLoc="center", | ||
bbox=[0, 0, 1, 1], | ||
) | ||
|
||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
""" | ||
Module to store functions used to generate simulated data products. | ||
""" | ||
|
||
import numpy as np | ||
|
||
__all__ = ["simulate_square_response_matrix"] | ||
|
||
|
||
def simulate_square_response_matrix(size, random_seed=10): | ||
"""Generate a square matrix with off-diagonal terms. | ||
Returns a product to mimic an instrument response matrix. | ||
Parameters | ||
---------- | ||
size : `int` | ||
The length of each side of the square response matrix. | ||
random_seed : `int`, optional | ||
The seed input for the random number generator. This will accept any value input accepted by `numpy.random.default_rng`. | ||
Returns | ||
------- | ||
`numpy.ndarray` | ||
The simulated 2D square response matrix. | ||
""" | ||
np_rand = np.random.default_rng(seed=random_seed) | ||
|
||
# fake SRM | ||
fake_srm = np.identity(size) | ||
|
||
# add some off-diagonal terms | ||
for c, r in enumerate(fake_srm): | ||
# add some features into the fake SRM | ||
off_diag = np_rand.random(c) * 0.005 | ||
|
||
# add a diagonal feature | ||
_x = 50 | ||
if c >= _x: | ||
off_diag[-_x] = np_rand.random(1)[0] | ||
|
||
# add a vertical feature in | ||
_y = 200 | ||
__y = 30 | ||
if c > _y + 100: | ||
off_diag[_y - __y // 2 : _y + __y // 2] = np.arange(2 * (__y // 2)) * np_rand.random(2 * (__y // 2)) * 5e-4 | ||
|
||
# put these features in the fake_srm row and normalize | ||
r[: off_diag.size] = off_diag | ||
r /= np.sum(r) | ||
|
||
return fake_srm |
Empty file.
36 changes: 36 additions & 0 deletions
36
sunkit_spex/fitting/objective_functions/optimising_functions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
""" | ||
This module contains functions that can evaluate models and return a fit statistic. | ||
""" | ||
|
||
__all__ = ["minimize_func"] | ||
|
||
|
||
def minimize_func(params, data_y, model_x, model_func, statistic_func): | ||
""" | ||
Minimization function. | ||
Parameters | ||
---------- | ||
params : `ndarray` | ||
Guesses of the independent variables. | ||
data_y : `ndarray` | ||
The data to be fitted. | ||
model_x : `ndarray` | ||
The values at which to evaluate `model_func` at with `params`. | ||
model_func : `astropy.modeling.core._ModelMeta` | ||
The model being fitted to the data. Crucially will have an | ||
`evaluate` method. | ||
statistic_func : `function` | ||
The chosen function to compare the data and the model. | ||
Returns | ||
------- | ||
`float` | ||
The value to be optimized that compares the model to the data. | ||
""" | ||
model_y = model_func.evaluate(model_x, *params) | ||
return statistic_func(data_y, model_y) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
""" | ||
This module contains functions to wrap around minimizer tools. | ||
""" | ||
|
||
from scipy.optimize import minimize | ||
|
||
__all__ = ["scipy_minimize"] | ||
|
||
|
||
def scipy_minimize(objective_func, param_guesses, objective_func_args, **kwargs): | ||
"""A function to optimize fitted parameters to data. | ||
Parameters | ||
---------- | ||
objective_func : `function` | ||
The function to be optimized. | ||
param_guesses : `ndarray` | ||
Initial guesses of the independent variables. | ||
objective_func_args : `tuple` | ||
Any arguments required to be passed to the objective function | ||
after the param_guesses. | ||
E.g., `objective_func(param_guesses, *objective_func_args)`. | ||
kwargs : | ||
Passed to `scipy.optimize.minimize`. | ||
A default value for the method is chosen to be "Nelder-Mead". | ||
Returns | ||
------- | ||
`scipy.optimize.OptimizeResult` | ||
The optimized result after comparing the model to the data. | ||
""" | ||
|
||
method = kwargs.pop("method", "Nelder-Mead") | ||
|
||
return minimize(objective_func, param_guesses, args=objective_func_args, method=method, **kwargs) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
""" | ||
This module contains functions that compute a fit statistic between two data-sets. | ||
""" | ||
|
||
import numpy as np | ||
|
||
__all__ = ["chi_squared"] | ||
|
||
|
||
def chi_squared(data_y, model_y): | ||
""" | ||
The form to optimise while fitting. | ||
* No error included here. * | ||
Parameters | ||
---------- | ||
data_y : `ndarray` | ||
The data to be fitted. | ||
model_y : `ndarray` | ||
The model values being fitted. | ||
Returns | ||
------- | ||
`float` | ||
The value to be optimized that compares the model to the data. | ||
""" | ||
return np.sum((data_y - model_y) ** 2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Licensed under a 3-clause BSD style license - see LICENSE.rst | ||
""" | ||
This module contains package tests. | ||
""" |
Oops, something went wrong.