Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Update example to use spec object and units #163

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 93 additions & 75 deletions examples/fitting_simulated_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,19 @@
import numpy as np
from matplotlib.colors import LogNorm

import astropy.units as u
from astropy.modeling import fitting
from astropy.modeling.functional_models import Gaussian1D, Linear1D
from astropy.visualization import quantity_support

from sunkit_spex.data.simulated_data import simulate_square_response_matrix
from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func
from sunkit_spex.fitting.optimizer_tools.minimizer_tools import scipy_minimize
from sunkit_spex.fitting.statistics.gaussian import chi_squared
from sunkit_spex.models.instrument_response import MatrixModel
from sunkit_spex.models.models import GaussianModel, StraightLineModel
from sunkit_spex.spectrum import Spectrum
from sunkit_spex.spectrum.spectrum import SpectralAxis

#####################################################
#
Expand All @@ -37,87 +42,102 @@

start, inc = 1.6, 0.04
stop = 80 + inc / 2
ph_energies = np.arange(start, stop, inc)
ph_energies = np.arange(start, stop, inc) * u.keV

#####################################################
#
# Let's start making a simulated photon spectrum

sim_cont = {"slope": -1, "intercept": 100}
sim_line = {"amplitude": 100, "mean": 30, "stddev": 2}
sim_cont = {"slope": -1 * u.ph / u.keV, "intercept": 100 * u.ph}
sim_line = {"amplitude": 100 * u.ph, "mean": 30 * u.keV, "stddev": 2 * u.keV}
# use a straight line model for a continuum, Gaussian for a line
ph_model = StraightLineModel(**sim_cont) + GaussianModel(**sim_line)

plt.figure()
plt.plot(ph_energies, ph_model(ph_energies))
plt.xlabel("Energy [keV]")
plt.ylabel("ph s$^{-1}$ cm$^{-2}$ keV$^{-1}$")
plt.title("Simulated Photon Spectrum")
plt.show()
with quantity_support():
plt.figure()
plt.plot(ph_energies, ph_model(ph_energies))
plt.xlabel(f"Energy [{ph_energies.unit}]")
plt.title("Simulated Photon Spectrum")
plt.show()

#####################################################
#
# Now want a response matrix

srm = simulate_square_response_matrix(ph_energies.size)
srm_model = MatrixModel(matrix=srm)

plt.figure()
plt.imshow(
srm, origin="lower", extent=[ph_energies[0], ph_energies[-1], ph_energies[0], ph_energies[-1]], norm=LogNorm()
srm_model = MatrixModel(
matrix=srm * u.ct / u.ph, input_axis=SpectralAxis(ph_energies), output_axis=SpectralAxis(ph_energies)
)
plt.ylabel("Photon Energies [keV]")
plt.xlabel("Count Energies [keV]")
plt.title("Simulated SRM")
plt.show()

with quantity_support():
plt.figure()
plt.imshow(
srm_model.matrix.value,
origin="lower",
extent=(
srm_model.inputs_axis[0].value,
srm_model.inputs_axis[-1].value,
srm_model.output_axis[0].value,
srm_model.output_axis[-1].value,
),
norm=LogNorm(),
)
plt.ylabel(f"Photon Energies [{srm_model.inputs_axis.unit}]")
plt.xlabel(f"Count Energies [{srm_model.output_axis.unit}]")
plt.title("Simulated SRM")
plt.show()

#####################################################
#
# Start work on a count model

sim_gauss = {"amplitude": 70, "mean": 40, "stddev": 2}
sim_gauss = {"amplitude": 70 * u.ct, "mean": 40 * u.keV, "stddev": 2 * u.keV}
# the brackets are very necessary
ct_model = (ph_model | srm_model) + GaussianModel(**sim_gauss)

#####################################################
#
# Generate simulated count data to (almost) fit

sim_count_model = ct_model(ph_energies)
sim_count_model = ct_model(srm_model.inputs_axis)

#####################################################
#
# Add some noise
np_rand = np.random.default_rng(seed=10)
sim_count_model_wn = sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model)
sim_count_model_wn = (
sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model.value) * u.ct
)

obs_spec = Spectrum(sim_count_model_wn.reshape(-1), spectral_axis=ph_energies)

#####################################################
#
# Can plot all the different components in the simulated count spectrum

plt.figure()
plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features")
plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature")
plt.plot(ph_energies, sim_count_model, label="total sim. spectrum")
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise", lw=0.5)
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum")
plt.legend()
with quantity_support():
plt.figure()
plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features")
plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature")
plt.plot(ph_energies, sim_count_model, label="total sim. spectrum")
plt.plot(obs_spec._spectral_axis, obs_spec.data, label="total sim. spectrum + noise", lw=0.5)
plt.xlabel(f"Energy [{ph_energies.unit}]")
plt.title("Simulated Count Spectrum")
plt.legend()

plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold")
plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold")
plt.show()
plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold")
plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold")
plt.show()

#####################################################
#
# Now we have the simulated data, let's start setting up to fit it
#
# Get some initial guesses that are off from the simulated data above

guess_cont = {"slope": -0.5, "intercept": 80}
guess_line = {"amplitude": 150, "mean": 32, "stddev": 5}
guess_gauss = {"amplitude": 350, "mean": 39, "stddev": 0.5}
guess_cont = {"slope": -0.5 * u.ph / u.keV, "intercept": 80 * u.ph}
guess_line = {"amplitude": 150 * u.ph, "mean": 32 * u.keV, "stddev": 5 * u.keV}
guess_gauss = {"amplitude": 350 * u.ct, "mean": 39 * u.keV, "stddev": 0.5 * u.keV}

#####################################################
#
Expand All @@ -130,18 +150,16 @@
#
# Let's fit the simulated data and plot the result

opt_res = scipy_minimize(
minimize_func, count_model_4fit.parameters, (sim_count_model_wn, ph_energies, count_model_4fit, chi_squared)
)
opt_res = scipy_minimize(minimize_func, count_model_4fit.parameters, (obs_spec, count_model_4fit, chi_squared))

plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies, *opt_res.x), ls=":", label="model fit")
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum Fit with Scipy")
plt.legend()
plt.show()
with quantity_support():
plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies.value, *opt_res.x), ls=":", label="model fit")
plt.xlabel(f"Energy [{ph_energies.unit}]")
plt.title("Simulated Count Spectrum Fit with Scipy")
plt.legend()
plt.show()


#####################################################
Expand All @@ -150,12 +168,12 @@
#
# Try and ensure we start fresh with new model definitions

ph_mod_4astropyfit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line)
count_model_4astropyfit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss)
ph_mod_4astropyfit = Linear1D(**guess_cont) + Gaussian1D(**guess_line)
count_model_4astropyfit = (ph_mod_4astropyfit | srm_model) + Gaussian1D(**guess_gauss)

astropy_fit = fitting.LevMarLSQFitter()

astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, sim_count_model_wn)
astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, obs_spec.data << obs_spec.unit)

plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
Expand All @@ -170,28 +188,28 @@
#
# Display a table of the fitted results

plt.figure(layout="constrained")

row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss))
column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit")
true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values()))
guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values()))
scipy_vals = opt_res.x
astropy_vals = astropy_fitted_result.parameters
cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T
cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str)

plt.axis("off")
plt.table(
cellText=cell_text,
cellColours=None,
cellLoc="center",
rowLabels=row_labels,
rowColours=None,
colLabels=column_labels,
colColours=None,
colLoc="center",
bbox=[0, 0, 1, 1],
)

plt.show()
# plt.figure(layout="constrained")
#
# row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss))
# column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit")
# true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values()))
# guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values()))
# scipy_vals = opt_res.x
# astropy_vals = astropy_fitted_result.parameters
# cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T
# cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str)
#
# plt.axis("off")
# plt.table(
# cellText=cell_text,
# cellColours=None,
# cellLoc="center",
# rowLabels=row_labels,
# rowColours=None,
# colLabels=column_labels,
# colColours=None,
# colLoc="center",
# bbox=[0, 0, 1, 1],
# )
#
# plt.show()
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
__all__ = ["minimize_func"]


def minimize_func(params, data_y, model_x, model_func, statistic_func):
def minimize_func(params, obs_spec, model_func, statistic_func):
"""
Minimization function.

Expand All @@ -32,5 +32,5 @@ def minimize_func(params, data_y, model_x, model_func, statistic_func):
`float`
The value to be optimized that compares the model to the data.
"""
model_y = model_func.evaluate(model_x, *params)
return statistic_func(data_y, model_y)
model_y = model_func.evaluate(obs_spec._spectral_axis.value, *params)
return statistic_func(obs_spec.data, model_y)
22 changes: 19 additions & 3 deletions sunkit_spex/models/instrument_response.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,31 @@
"""Module for model components required for instrument response models."""

import astropy.units as u
from astropy.modeling import Fittable1DModel, Parameter

__all__ = ["MatrixModel"]


class MatrixModel(Fittable1DModel):
def __init__(self, matrix):
# matrix = Parameter(description="The matrix with which to multiply the input.", fixed=True)

def __init__(self, matrix, input_axis, output_axis):
self.matrix = Parameter(default=matrix, description="The matrix with which to multiply the input.", fixed=True)
self.inputs_axis = input_axis
self.output_axis = output_axis
super().__init__()

def evaluate(self, model_y):
def evaluate(self, x):
# Requires input must have a specific dimensionality
return model_y @ self.matrix
return x @ self.matrix

@property
def input_units(self):
return {"x": u.ph}

@property
def output_units(self):
return {"y": u.ct}

def _parameter_units_for_data_units(self, inputs_unit, outputs_unit):
return {"x": u.ph, "y": u.ct}
Loading