From ae8211223bbd8e136b05e7bc16745a891fbcaec2 Mon Sep 17 00:00:00 2001 From: Kristopher Cooper Date: Thu, 27 Jun 2024 15:30:25 -0500 Subject: [PATCH] Small edits. --- .../fitting/examples/simple_fitting.py | 65 +++++++++++-------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/sunkit_spex/fitting/examples/simple_fitting.py b/sunkit_spex/fitting/examples/simple_fitting.py index f7fd7053..e62103e2 100644 --- a/sunkit_spex/fitting/examples/simple_fitting.py +++ b/sunkit_spex/fitting/examples/simple_fitting.py @@ -28,8 +28,6 @@ # this should keep "random" stuff is the same ech run np.random.seed(seed=10) -# get some fake photon energies - def photon_energies(start, stop, inc): """ Get a `ndarray` of energies. """ @@ -70,7 +68,9 @@ def response_matrix(photon_energies): _y = 200 __y = 30 if c > _y+100: - off_diag[_y-__y//2:_y+__y//2] = np.arange(2*(__y//2))*np.random.rand(2*(__y//2))*5e-4 + off_diag[_y-__y//2:_y+__y//2] = (np.arange(2*(__y//2)) + * np.random.rand(2*(__y//2)) + * 5e-4) # put these features in the fake_srm row and normalize r[:off_diag.size] = off_diag @@ -129,8 +129,13 @@ def plot_fake_count_spectrum_fit(axis, title="Fake Count Spectrum Fit"): """ Plot the fitted result. """ - axis.plot(photon_energies, total_count_spectrum_wnoise, label="total fake spectrum + noise") - axis.plot(ph_energies, fitted_count_spectrum, ls=":", label="model fit") + axis.plot(photon_energies, + total_count_spectrum_wnoise, + label="total fake spectrum + noise") + axis.plot(ph_energies, + fitted_count_spectrum, + ls=":", + label="model fit") axis.set_xlabel("Energy [keV]") axis.set_ylabel("cts s$^{-1}$ keV$^{-1}$") axis.set_title(title) @@ -152,10 +157,10 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text): if __name__ == "__main__": - # ******************************************************************* + # ****************************************************************** # Start by creating fake data and instrument. # This would all be provided by a given observation. - # ******************************************************************* + # ****************************************************************** # define the photon energies start, inc = 1.6, 0.04 @@ -166,8 +171,8 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text): fake_cont = {"m": -1, "c": 100} fake_line = {"a": 100, "b": 30, "c": 2} # use a straight line model for a continuum, Gaussian for a line - ph_model = StraightLinePhotonModel(**fake_cont) + \ - GaussianPhotonModel(**fake_line) + ph_model = (StraightLinePhotonModel(**fake_cont) + + GaussianPhotonModel(**fake_line)) # now want a response matrix srm = response_matrix(ph_energies) @@ -181,23 +186,27 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text): # generate fake count data to (almost) fit fake_count_model = ct_model(ph_energies) # add some noise - fake_count_model_wn = fake_count_model + \ - (2*np.random.rand(fake_count_model.size)-1)*np.sqrt(fake_count_model) + fake_count_model_wn = (fake_count_model + + (2*np.random.rand(fake_count_model.size)-1) + * np.sqrt(fake_count_model)) - # ******************************************************************* + # ****************************************************************** # Now we have the fake data, let's start setting up to fit it - # ******************************************************************* + # ****************************************************************** # get some initial guesses that are off from the fake data above - guess_cont = {"m": -0.5, "c": 80} # original {"m":-1, "c":100} - guess_line = {"a": 150, "b": 32, "c": 5} # original {"a":100, "b":30, "c":2} - guess_gauss = {"a": 350, "b": 39, "c": 0.5} # original {"a":70, "b":40, "c":2} + # original {"m":-1, "c":100} + guess_cont = {"m": -0.5, "c": 80} + # original {"a":100, "b":30, "c":2} + guess_line = {"a": 150, "b": 32, "c": 5} + # original {"a":70, "b":40, "c":2} + guess_gauss = {"a": 350, "b": 39, "c": 0.5} # define a new model since we have a rough idea of the mode we should use - ph_mod_4fit = StraightLinePhotonModel(**guess_cont) + \ - GaussianPhotonModel(**guess_line) - count_model_4fit = (ph_mod_4fit | srm_model) + \ - GaussianCountModel(**guess_gauss) + ph_mod_4fit = (StraightLinePhotonModel(**guess_cont) + + GaussianPhotonModel(**guess_line)) + count_model_4fit = ((ph_mod_4fit | srm_model) + + GaussianCountModel(**guess_gauss)) # let's fit the fake data opt_res = scipy_minimize(minimize_func, @@ -207,9 +216,9 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text): count_model_4fit, chi_squared)) - # ******************************************************************* + # ****************************************************************** # Now try and fit with Astropy native fitting infrastructure - # ******************************************************************* + # ****************************************************************** # try and ensure we start fresh with new model definitions ph_mod_4astropyfit = StraightLinePhotonModel(**guess_cont) + \ @@ -223,9 +232,9 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text): ph_energies, fake_count_model_wn) - # ******************************************************************* + # ****************************************************************** # Plot the results - # ******************************************************************* + # ****************************************************************** fig = plt.figure(layout="constrained", figsize=(14, 7)) @@ -269,7 +278,8 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text): plot_fake_count_spectrum_fit(ax4, ph_energies, fake_count_model_wn, - count_model_4fit.evaluate(ph_energies, *opt_res.x), + count_model_4fit.evaluate(ph_energies, + *opt_res.x), title="Fake Count Spectrum Fit with Scipy") # the count spectrum fitted with Astropy @@ -295,7 +305,10 @@ def plot_table_of_results(ax, row_labels, column_labels, cell_text): scipy_vals = opt_res.x astropy_vals = astropy_fitted_result.parameters cell_vals = np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T - cell_text = np.round(np.vstack((true_vals, guess_vals, scipy_vals, astropy_vals)).T, 2).astype(str) + cell_text = np.round(np.vstack((true_vals, + guess_vals, + scipy_vals, + astropy_vals)).T, 2).astype(str) plot_table_of_results(ax6, row_labels, column_labels, cell_text) if SAVE_FIG: