Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
samaloney committed Sep 26, 2024
1 parent c340a00 commit a722920
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 22 deletions.
44 changes: 26 additions & 18 deletions examples/fitting_simulated_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
from matplotlib.colors import LogNorm

import astropy.units as u
from astropy.modeling import fitting

from sunkit_spex.data.simulated_data import simulate_square_response_matrix
Expand All @@ -27,6 +28,8 @@
from sunkit_spex.fitting.statistics.gaussian import chi_squared
from sunkit_spex.models.instrument_response import MatrixModel
from sunkit_spex.models.models import GaussianModel, StraightLineModel
from sunkit_spex.spectrum import Spectrum
from sunkit_spex.spectrum.spectrum import SpectralAxis

#####################################################
#
Expand All @@ -37,21 +40,19 @@

start, inc = 1.6, 0.04
stop = 80 + inc / 2
ph_energies = np.arange(start, stop, inc)
ph_energies = np.arange(start, stop, inc) * u.keV

#####################################################
#
# Let's start making a simulated photon spectrum

sim_cont = {"slope": -1, "intercept": 100}
sim_line = {"amplitude": 100, "mean": 30, "stddev": 2}
sim_cont = {"slope": -1 * u.ph / u.keV, "intercept": 100 * u.ph}
sim_line = {"amplitude": 100 * u.ph, "mean": 30 * u.keV, "stddev": 2 * u.keV}
# use a straight line model for a continuum, Gaussian for a line
ph_model = StraightLineModel(**sim_cont) + GaussianModel(**sim_line)

plt.figure()
plt.plot(ph_energies, ph_model(ph_energies))
plt.xlabel("Energy [keV]")
plt.ylabel("ph s$^{-1}$ cm$^{-2}$ keV$^{-1}$")
plt.title("Simulated Photon Spectrum")
plt.show()

Expand All @@ -60,11 +61,16 @@
# Now want a response matrix

srm = simulate_square_response_matrix(ph_energies.size)
srm_model = MatrixModel(matrix=srm)
srm_model = MatrixModel(
matrix=srm * u.ct / u.ph, input_axis=SpectralAxis(ph_energies), output_axis=SpectralAxis(ph_energies)
)

plt.figure()
plt.imshow(
srm, origin="lower", extent=[ph_energies[0], ph_energies[-1], ph_energies[0], ph_energies[-1]], norm=LogNorm()
srm,
origin="lower",
extent=(ph_energies[0].value, ph_energies[-1].value, ph_energies[0].value, ph_energies[-1].value),
norm=LogNorm(),
)
plt.ylabel("Photon Energies [keV]")
plt.xlabel("Count Energies [keV]")
Expand All @@ -75,21 +81,25 @@
#
# Start work on a count model

sim_gauss = {"amplitude": 70, "mean": 40, "stddev": 2}
sim_gauss = {"amplitude": 70 * u.ct, "mean": 40 * u.keV, "stddev": 2 * u.keV}
# the brackets are very necessary
ct_model = (ph_model | srm_model) + GaussianModel(**sim_gauss)

#####################################################
#
# Generate simulated count data to (almost) fit

sim_count_model = ct_model(ph_energies)
sim_count_model = ct_model(srm_model.inputs_axis)

#####################################################
#
# Add some noise
np_rand = np.random.default_rng(seed=10)
sim_count_model_wn = sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model)
sim_count_model_wn = (
sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model.value) * u.ct
)

obs_spec = Spectrum(sim_count_model_wn, spectral_axis=ph_energies)

#####################################################
#
Expand All @@ -99,7 +109,7 @@
plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features")
plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature")
plt.plot(ph_energies, sim_count_model, label="total sim. spectrum")
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise", lw=0.5)
plt.plot(obs_spec._spectral_axis, obs_spec.data, label="total sim. spectrum + noise", lw=0.5)
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum")
Expand All @@ -115,9 +125,9 @@
#
# Get some initial guesses that are off from the simulated data above

guess_cont = {"slope": -0.5, "intercept": 80}
guess_line = {"amplitude": 150, "mean": 32, "stddev": 5}
guess_gauss = {"amplitude": 350, "mean": 39, "stddev": 0.5}
guess_cont = {"slope": -0.5 * u.ph / u.keV, "intercept": 80 * u.ph}
guess_line = {"amplitude": 150 * u.ph, "mean": 32 * u.keV, "stddev": 5 * u.keV}
guess_gauss = {"amplitude": 350 * u.ct, "mean": 39 * u.keV, "stddev": 0.5 * u.keV}

#####################################################
#
Expand All @@ -130,13 +140,11 @@
#
# Let's fit the simulated data and plot the result

opt_res = scipy_minimize(
minimize_func, count_model_4fit.parameters, (sim_count_model_wn, ph_energies, count_model_4fit, chi_squared)
)
opt_res = scipy_minimize(minimize_func, count_model_4fit.parameters, (obs_spec, count_model_4fit, chi_squared))

plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies, *opt_res.x), ls=":", label="model fit")
plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies.value, *opt_res.x), ls=":", label="model fit")
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum Fit with Scipy")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
__all__ = ["minimize_func"]


def minimize_func(params, data_y, model_x, model_func, statistic_func):
def minimize_func(params, obs_spec, model_func, statistic_func):
"""
Minimization function.
Expand All @@ -32,5 +32,5 @@ def minimize_func(params, data_y, model_x, model_func, statistic_func):
`float`
The value to be optimized that compares the model to the data.
"""
model_y = model_func.evaluate(model_x, *params)
return statistic_func(data_y, model_y)
model_y = model_func.evaluate(obs_spec._spectral_axis.value, *params)
return statistic_func(obs_spec.data, model_y)
4 changes: 3 additions & 1 deletion sunkit_spex/models/instrument_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@


class MatrixModel(Fittable1DModel):
def __init__(self, matrix):
def __init__(self, matrix, input_axis, output_axis):
self.matrix = Parameter(default=matrix, description="The matrix with which to multiply the input.", fixed=True)
self.inputs_axis = input_axis
self.output_axis = output_axis
super().__init__()

def evaluate(self, model_y):
Expand Down

0 comments on commit a722920

Please sign in to comment.