Skip to content

Commit

Permalink
Small edits.
Browse files Browse the repository at this point in the history
  • Loading branch information
Kristopher Cooper committed Jun 27, 2024
1 parent 9157d23 commit ae82112
Showing 1 changed file with 39 additions and 26 deletions.
65 changes: 39 additions & 26 deletions sunkit_spex/fitting/examples/simple_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
# this should keep "random" stuff is the same ech run
np.random.seed(seed=10)

# get some fake photon energies


def photon_energies(start, stop, inc):
""" Get a `ndarray` of energies. """
Expand Down Expand Up @@ -70,7 +68,9 @@ def response_matrix(photon_energies):
_y = 200
__y = 30
if c > _y+100:
off_diag[_y-__y//2:_y+__y//2] = np.arange(2*(__y//2))*np.random.rand(2*(__y//2))*5e-4
off_diag[_y-__y//2:_y+__y//2] = (np.arange(2*(__y//2))
* np.random.rand(2*(__y//2))
* 5e-4)

# put these features in the fake_srm row and normalize
r[:off_diag.size] = off_diag
Expand Down Expand Up @@ -129,8 +129,13 @@ def plot_fake_count_spectrum_fit(axis,
title="Fake Count Spectrum Fit"):
""" Plot the fitted result. """

axis.plot(photon_energies, total_count_spectrum_wnoise, label="total fake spectrum + noise")
axis.plot(ph_energies, fitted_count_spectrum, ls=":", label="model fit")
axis.plot(photon_energies,
total_count_spectrum_wnoise,
label="total fake spectrum + noise")
axis.plot(ph_energies,
fitted_count_spectrum,
ls=":",
label="model fit")
axis.set_xlabel("Energy [keV]")
axis.set_ylabel("cts s$^{-1}$ keV$^{-1}$")
axis.set_title(title)
Expand All @@ -152,10 +157,10 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text):


if __name__ == "__main__":
# *******************************************************************
# ******************************************************************
# Start by creating fake data and instrument.
# This would all be provided by a given observation.
# *******************************************************************
# ******************************************************************

# define the photon energies
start, inc = 1.6, 0.04
Expand All @@ -166,8 +171,8 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text):
fake_cont = {"m": -1, "c": 100}
fake_line = {"a": 100, "b": 30, "c": 2}
# use a straight line model for a continuum, Gaussian for a line
ph_model = StraightLinePhotonModel(**fake_cont) + \
GaussianPhotonModel(**fake_line)
ph_model = (StraightLinePhotonModel(**fake_cont)
+ GaussianPhotonModel(**fake_line))

# now want a response matrix
srm = response_matrix(ph_energies)
Expand All @@ -181,23 +186,27 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text):
# generate fake count data to (almost) fit
fake_count_model = ct_model(ph_energies)
# add some noise
fake_count_model_wn = fake_count_model + \
(2*np.random.rand(fake_count_model.size)-1)*np.sqrt(fake_count_model)
fake_count_model_wn = (fake_count_model
+ (2*np.random.rand(fake_count_model.size)-1)
* np.sqrt(fake_count_model))

# *******************************************************************
# ******************************************************************
# Now we have the fake data, let's start setting up to fit it
# *******************************************************************
# ******************************************************************

# get some initial guesses that are off from the fake data above
guess_cont = {"m": -0.5, "c": 80} # original {"m":-1, "c":100}
guess_line = {"a": 150, "b": 32, "c": 5} # original {"a":100, "b":30, "c":2}
guess_gauss = {"a": 350, "b": 39, "c": 0.5} # original {"a":70, "b":40, "c":2}
# original {"m":-1, "c":100}
guess_cont = {"m": -0.5, "c": 80}
# original {"a":100, "b":30, "c":2}
guess_line = {"a": 150, "b": 32, "c": 5}
# original {"a":70, "b":40, "c":2}
guess_gauss = {"a": 350, "b": 39, "c": 0.5}

# define a new model since we have a rough idea of the mode we should use
ph_mod_4fit = StraightLinePhotonModel(**guess_cont) + \
GaussianPhotonModel(**guess_line)
count_model_4fit = (ph_mod_4fit | srm_model) + \
GaussianCountModel(**guess_gauss)
ph_mod_4fit = (StraightLinePhotonModel(**guess_cont)
+ GaussianPhotonModel(**guess_line))
count_model_4fit = ((ph_mod_4fit | srm_model)
+ GaussianCountModel(**guess_gauss))

# let's fit the fake data
opt_res = scipy_minimize(minimize_func,
Expand All @@ -207,9 +216,9 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text):
count_model_4fit,
chi_squared))

# *******************************************************************
# ******************************************************************
# Now try and fit with Astropy native fitting infrastructure
# *******************************************************************
# ******************************************************************

# try and ensure we start fresh with new model definitions
ph_mod_4astropyfit = StraightLinePhotonModel(**guess_cont) + \
Expand All @@ -223,9 +232,9 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text):
ph_energies,
fake_count_model_wn)

# *******************************************************************
# ******************************************************************
# Plot the results
# *******************************************************************
# ******************************************************************

fig = plt.figure(layout="constrained", figsize=(14, 7))

Expand Down Expand Up @@ -269,7 +278,8 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text):
plot_fake_count_spectrum_fit(ax4,
ph_energies,
fake_count_model_wn,
count_model_4fit.evaluate(ph_energies, *opt_res.x),
count_model_4fit.evaluate(ph_energies,
*opt_res.x),
title="Fake Count Spectrum Fit with Scipy")

# the count spectrum fitted with Astropy
Expand All @@ -295,7 +305,10 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text):
scipy_vals = opt_res.x
astropy_vals = astropy_fitted_result.parameters
cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T
cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str)
cell_text = np.round(np.vstack((true_vals,
guess_vals,
scipy_vals,
astropy_vals)).T, 2).astype(str)
plot_table_of_results(ax6, row_labels, column_labels, cell_text)

if SAVE_FIG:
Expand Down

0 comments on commit ae82112

Please sign in to comment.