Skip to content

Commit

Permalink
Scipy fitting works can't get astropy fitting to work
Browse files Browse the repository at this point in the history
  • Loading branch information
samaloney committed Oct 4, 2024
1 parent a722920 commit 30e3a84
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 66 deletions.
138 changes: 74 additions & 64 deletions examples/fitting_simulated_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import astropy.units as u
from astropy.modeling import fitting
from astropy.modeling.functional_models import Gaussian1D, Linear1D
from astropy.visualization import quantity_support

from sunkit_spex.data.simulated_data import simulate_square_response_matrix
from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func
Expand Down Expand Up @@ -51,10 +53,12 @@
# use a straight line model for a continuum, Gaussian for a line
ph_model = StraightLineModel(**sim_cont) + GaussianModel(**sim_line)

plt.figure()
plt.plot(ph_energies, ph_model(ph_energies))
plt.title("Simulated Photon Spectrum")
plt.show()
with quantity_support():
plt.figure()
plt.plot(ph_energies, ph_model(ph_energies))
plt.xlabel(f"Energy [{ph_energies.unit}]")
plt.title("Simulated Photon Spectrum")
plt.show()

#####################################################
#
Expand All @@ -65,17 +69,23 @@
matrix=srm * u.ct / u.ph, input_axis=SpectralAxis(ph_energies), output_axis=SpectralAxis(ph_energies)
)

plt.figure()
plt.imshow(
srm,
origin="lower",
extent=(ph_energies[0].value, ph_energies[-1].value, ph_energies[0].value, ph_energies[-1].value),
norm=LogNorm(),
)
plt.ylabel("Photon Energies [keV]")
plt.xlabel("Count Energies [keV]")
plt.title("Simulated SRM")
plt.show()
with quantity_support():
plt.figure()
plt.imshow(
srm_model.matrix.value,
origin="lower",
extent=(
srm_model.inputs_axis[0].value,
srm_model.inputs_axis[-1].value,
srm_model.output_axis[0].value,
srm_model.output_axis[-1].value,
),
norm=LogNorm(),
)
plt.ylabel(f"Photon Energies [{srm_model.inputs_axis.unit}]")
plt.xlabel(f"Count Energies [{srm_model.output_axis.unit}]")
plt.title("Simulated SRM")
plt.show()

#####################################################
#
Expand All @@ -99,25 +109,25 @@
sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model.value) * u.ct
)

obs_spec = Spectrum(sim_count_model_wn, spectral_axis=ph_energies)
obs_spec = Spectrum(sim_count_model_wn.reshape(-1), spectral_axis=ph_energies)

#####################################################
#
# Can plot all the different components in the simulated count spectrum

plt.figure()
plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features")
plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature")
plt.plot(ph_energies, sim_count_model, label="total sim. spectrum")
plt.plot(obs_spec._spectral_axis, obs_spec.data, label="total sim. spectrum + noise", lw=0.5)
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum")
plt.legend()
with quantity_support():
plt.figure()
plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features")
plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature")
plt.plot(ph_energies, sim_count_model, label="total sim. spectrum")
plt.plot(obs_spec._spectral_axis, obs_spec.data, label="total sim. spectrum + noise", lw=0.5)
plt.xlabel(f"Energy [{ph_energies.unit}]")
plt.title("Simulated Count Spectrum")
plt.legend()

plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold")
plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold")
plt.show()
plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold")
plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold")
plt.show()

#####################################################
#
Expand All @@ -142,14 +152,14 @@

opt_res = scipy_minimize(minimize_func, count_model_4fit.parameters, (obs_spec, count_model_4fit, chi_squared))

plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies.value, *opt_res.x), ls=":", label="model fit")
plt.xlabel("Energy [keV]")
plt.ylabel("cts s$^{-1}$ keV$^{-1}$")
plt.title("Simulated Count Spectrum Fit with Scipy")
plt.legend()
plt.show()
with quantity_support():
plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies.value, *opt_res.x), ls=":", label="model fit")
plt.xlabel(f"Energy [{ph_energies.unit}]")
plt.title("Simulated Count Spectrum Fit with Scipy")
plt.legend()
plt.show()


#####################################################
Expand All @@ -158,12 +168,12 @@
#
# Try and ensure we start fresh with new model definitions

ph_mod_4astropyfit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line)
count_model_4astropyfit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss)
ph_mod_4astropyfit = Linear1D(**guess_cont) + Gaussian1D(**guess_line)
count_model_4astropyfit = (ph_mod_4astropyfit | srm_model) + Gaussian1D(**guess_gauss)

astropy_fit = fitting.LevMarLSQFitter()

astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, sim_count_model_wn)
astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, obs_spec.data << obs_spec.unit)

plt.figure()
plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise")
Expand All @@ -178,28 +188,28 @@
#
# Display a table of the fitted results

plt.figure(layout="constrained")

row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss))
column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit")
true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values()))
guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values()))
scipy_vals = opt_res.x
astropy_vals = astropy_fitted_result.parameters
cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T
cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str)

plt.axis("off")
plt.table(
cellText=cell_text,
cellColours=None,
cellLoc="center",
rowLabels=row_labels,
rowColours=None,
colLabels=column_labels,
colColours=None,
colLoc="center",
bbox=[0, 0, 1, 1],
)

plt.show()
# plt.figure(layout="constrained")
#
# row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss))
# column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit")
# true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values()))
# guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values()))
# scipy_vals = opt_res.x
# astropy_vals = astropy_fitted_result.parameters
# cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T
# cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str)
#
# plt.axis("off")
# plt.table(
# cellText=cell_text,
# cellColours=None,
# cellLoc="center",
# rowLabels=row_labels,
# rowColours=None,
# colLabels=column_labels,
# colColours=None,
# colLoc="center",
# bbox=[0, 0, 1, 1],
# )
#
# plt.show()
18 changes: 16 additions & 2 deletions sunkit_spex/models/instrument_response.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,31 @@
"""Module for model components required for instrument response models."""

import astropy.units as u
from astropy.modeling import Fittable1DModel, Parameter

__all__ = ["MatrixModel"]


class MatrixModel(Fittable1DModel):
# matrix = Parameter(description="The matrix with which to multiply the input.", fixed=True)

def __init__(self, matrix, input_axis, output_axis):
self.matrix = Parameter(default=matrix, description="The matrix with which to multiply the input.", fixed=True)
self.inputs_axis = input_axis
self.output_axis = output_axis
super().__init__()

def evaluate(self, model_y):
def evaluate(self, x):
# Requires input must have a specific dimensionality
return model_y @ self.matrix
return x @ self.matrix

@property
def input_units(self):
return {"x": u.ph}

@property
def output_units(self):
return {"y": u.ct}

def _parameter_units_for_data_units(self, inputs_unit, outputs_unit):
return {"x": u.ph, "y": u.ct}

0 comments on commit 30e3a84

Please sign in to comment.