diff --git a/examples/fitting_simulated_data.py b/examples/fitting_simulated_data.py index 11afa7ff..bdf41371 100644 --- a/examples/fitting_simulated_data.py +++ b/examples/fitting_simulated_data.py @@ -21,6 +21,8 @@ import astropy.units as u from astropy.modeling import fitting +from astropy.modeling.functional_models import Gaussian1D, Linear1D +from astropy.visualization import quantity_support from sunkit_spex.data.simulated_data import simulate_square_response_matrix from sunkit_spex.fitting.objective_functions.optimising_functions import minimize_func @@ -51,10 +53,12 @@ # use a straight line model for a continuum, Gaussian for a line ph_model = StraightLineModel(**sim_cont) + GaussianModel(**sim_line) -plt.figure() -plt.plot(ph_energies, ph_model(ph_energies)) -plt.title("Simulated Photon Spectrum") -plt.show() +with quantity_support(): + plt.figure() + plt.plot(ph_energies, ph_model(ph_energies)) + plt.xlabel(f"Energy [{ph_energies.unit}]") + plt.title("Simulated Photon Spectrum") + plt.show() ##################################################### # @@ -65,17 +69,23 @@ matrix=srm * u.ct / u.ph, input_axis=SpectralAxis(ph_energies), output_axis=SpectralAxis(ph_energies) ) -plt.figure() -plt.imshow( - srm, - origin="lower", - extent=(ph_energies[0].value, ph_energies[-1].value, ph_energies[0].value, ph_energies[-1].value), - norm=LogNorm(), -) -plt.ylabel("Photon Energies [keV]") -plt.xlabel("Count Energies [keV]") -plt.title("Simulated SRM") -plt.show() +with quantity_support(): + plt.figure() + plt.imshow( + srm_model.matrix.value, + origin="lower", + extent=( + srm_model.inputs_axis[0].value, + srm_model.inputs_axis[-1].value, + srm_model.output_axis[0].value, + srm_model.output_axis[-1].value, + ), + norm=LogNorm(), + ) + plt.ylabel(f"Photon Energies [{srm_model.inputs_axis.unit}]") + plt.xlabel(f"Count Energies [{srm_model.output_axis.unit}]") + plt.title("Simulated SRM") + plt.show() ##################################################### # @@ -99,25 +109,25 @@ sim_count_model + (2 * np_rand.random(sim_count_model.size) - 1) * np.sqrt(sim_count_model.value) * u.ct ) -obs_spec = Spectrum(sim_count_model_wn, spectral_axis=ph_energies) +obs_spec = Spectrum(sim_count_model_wn.reshape(-1), spectral_axis=ph_energies) ##################################################### # # Can plot all the different components in the simulated count spectrum -plt.figure() -plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features") -plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature") -plt.plot(ph_energies, sim_count_model, label="total sim. spectrum") -plt.plot(obs_spec._spectral_axis, obs_spec.data, label="total sim. spectrum + noise", lw=0.5) -plt.xlabel("Energy [keV]") -plt.ylabel("cts s$^{-1}$ keV$^{-1}$") -plt.title("Simulated Count Spectrum") -plt.legend() +with quantity_support(): + plt.figure() + plt.plot(ph_energies, (ph_model | srm_model)(ph_energies), label="photon model features") + plt.plot(ph_energies, GaussianModel(**sim_gauss)(ph_energies), label="gaussian feature") + plt.plot(ph_energies, sim_count_model, label="total sim. spectrum") + plt.plot(obs_spec._spectral_axis, obs_spec.data, label="total sim. spectrum + noise", lw=0.5) + plt.xlabel(f"Energy [{ph_energies.unit}]") + plt.title("Simulated Count Spectrum") + plt.legend() -plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold") -plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold") -plt.show() + plt.text(80, 170, "(ph_model(sl,in,am1,mn1,sd1) | srm)", ha="right", c="tab:blue", weight="bold") + plt.text(80, 150, "+ Gaussian(am2,mn2,sd2)", ha="right", c="tab:orange", weight="bold") + plt.show() ##################################################### # @@ -142,14 +152,14 @@ opt_res = scipy_minimize(minimize_func, count_model_4fit.parameters, (obs_spec, count_model_4fit, chi_squared)) -plt.figure() -plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise") -plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies.value, *opt_res.x), ls=":", label="model fit") -plt.xlabel("Energy [keV]") -plt.ylabel("cts s$^{-1}$ keV$^{-1}$") -plt.title("Simulated Count Spectrum Fit with Scipy") -plt.legend() -plt.show() +with quantity_support(): + plt.figure() + plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise") + plt.plot(ph_energies, count_model_4fit.evaluate(ph_energies.value, *opt_res.x), ls=":", label="model fit") + plt.xlabel(f"Energy [{ph_energies.unit}]") + plt.title("Simulated Count Spectrum Fit with Scipy") + plt.legend() + plt.show() ##################################################### @@ -158,12 +168,12 @@ # # Try and ensure we start fresh with new model definitions -ph_mod_4astropyfit = StraightLineModel(**guess_cont) + GaussianModel(**guess_line) -count_model_4astropyfit = (ph_mod_4fit | srm_model) + GaussianModel(**guess_gauss) +ph_mod_4astropyfit = Linear1D(**guess_cont) + Gaussian1D(**guess_line) +count_model_4astropyfit = (ph_mod_4astropyfit | srm_model) + Gaussian1D(**guess_gauss) astropy_fit = fitting.LevMarLSQFitter() -astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, sim_count_model_wn) +astropy_fitted_result = astropy_fit(count_model_4astropyfit, ph_energies, obs_spec.data << obs_spec.unit) plt.figure() plt.plot(ph_energies, sim_count_model_wn, label="total sim. spectrum + noise") @@ -178,28 +188,28 @@ # # Display a table of the fitted results -plt.figure(layout="constrained") - -row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss)) -column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit") -true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values())) -guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values())) -scipy_vals = opt_res.x -astropy_vals = astropy_fitted_result.parameters -cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T -cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str) - -plt.axis("off") -plt.table( - cellText=cell_text, - cellColours=None, - cellLoc="center", - rowLabels=row_labels, - rowColours=None, - colLabels=column_labels, - colColours=None, - colLoc="center", - bbox=[0, 0, 1, 1], -) - -plt.show() +# plt.figure(layout="constrained") +# +# row_labels = tuple(sim_cont) + tuple(f"{p}1" for p in tuple(sim_line)) + tuple(f"{p}2" for p in tuple(sim_gauss)) +# column_labels = ("True Values", "Guess Values", "Scipy Fit", "Astropy Fit") +# true_vals = np.array(tuple(sim_cont.values()) + tuple(sim_line.values()) + tuple(sim_gauss.values())) +# guess_vals = np.array(tuple(guess_cont.values()) + tuple(guess_line.values()) + tuple(guess_gauss.values())) +# scipy_vals = opt_res.x +# astropy_vals = astropy_fitted_result.parameters +# cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T +# cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str) +# +# plt.axis("off") +# plt.table( +# cellText=cell_text, +# cellColours=None, +# cellLoc="center", +# rowLabels=row_labels, +# rowColours=None, +# colLabels=column_labels, +# colColours=None, +# colLoc="center", +# bbox=[0, 0, 1, 1], +# ) +# +# plt.show() diff --git a/sunkit_spex/models/instrument_response.py b/sunkit_spex/models/instrument_response.py index f90d3734..da30b3c9 100644 --- a/sunkit_spex/models/instrument_response.py +++ b/sunkit_spex/models/instrument_response.py @@ -1,17 +1,31 @@ """Module for model components required for instrument response models.""" +import astropy.units as u from astropy.modeling import Fittable1DModel, Parameter __all__ = ["MatrixModel"] class MatrixModel(Fittable1DModel): + # matrix = Parameter(description="The matrix with which to multiply the input.", fixed=True) + def __init__(self, matrix, input_axis, output_axis): self.matrix = Parameter(default=matrix, description="The matrix with which to multiply the input.", fixed=True) self.inputs_axis = input_axis self.output_axis = output_axis super().__init__() - def evaluate(self, model_y): + def evaluate(self, x): # Requires input must have a specific dimensionality - return model_y @ self.matrix + return x @ self.matrix + + @property + def input_units(self): + return {"x": u.ph} + + @property + def output_units(self): + return {"y": u.ct} + + def _parameter_units_for_data_units(self, inputs_unit, outputs_unit): + return {"x": u.ph, "y": u.ct}