Skip to content

Commit

Permalink
Fix fitting when using per_ptjet and ROOT fits
Browse files Browse the repository at this point in the history
  • Loading branch information
qgp committed Aug 31, 2024
1 parent 73c2722 commit 83263b4
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions machine_learning_hep/analysis/analyzer_jets.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ def __init__(self, datap, case, typean, period):
self.roo_ws = {}
self.roo_ws_ptjet = {}
self.roows = {}
self.roows_ptjet = {}

#region helpers
def _save_canvas(self, canvas, filename):
Expand Down Expand Up @@ -399,10 +400,9 @@ def fit(self):
self.logger.debug("Opening histogram %s.", name_histo)
if not (h := rfile.Get(name_histo)):
self.logger.critical("Histogram %s not found.", name_histo)
for iptjet, ipt in itertools.product(itertools.chain((None,), range(0, get_nbins(h, 1))),
for iptjet, ipt in itertools.product(itertools.chain((None,), range(get_nbins(h, 1))),
range(get_nbins(h, 2))):
self.logger.debug('fitting %s: %s, %i', level, iptjet, ipt)
roows = self.roows.get(ipt)
axis_ptjet = get_axis(h, 1)
cuts_proj = {2: (ipt+1, ipt+1)}
if iptjet is not None:
Expand All @@ -411,15 +411,15 @@ def fit(self):
else:
jetptlabel = ''
h_invmass = project_hist(h, [0], cuts_proj)
if h_invmass.GetEntries() < 100: # TODO: reconsider criterion
self.logger.error('Not enough entries to fit %s iptjet %s ipt %d',
level, iptjet, ipt)
continue
# Rebin
if (n_rebin := self.cfg("n_rebin", 1)) != 1:
h_invmass.Rebin(n_rebin)
ptrange = (self.bins_candpt[ipt], self.bins_candpt[ipt+1])
if self.cfg('mass_fit'):
if h_invmass.GetEntries() < 100: # TODO: reconsider criterion
self.logger.error('Not enough entries to fit %s iptjet %s ipt %d',
level, iptjet, ipt)
continue
if self.cfg('mass_fit') and iptjet is None:
fit_res, _, func_bkg = self._fit_mass(
h_invmass,
f'fit/h_mass_fitted_pthf-{ptrange[0]}-{ptrange[1]}_{level}.png')
Expand All @@ -440,19 +440,15 @@ def fit(self):
fitcfg = entry
break
self.logger.debug("Using fit config for %i: %s", ipt, fitcfg)
# check
if iptjet is not None and not fitcfg.get('per_ptjet'):
continue
if h_invmass.GetEntries() < 100: # TODO: reconsider criterion
self.logger.warning('Not enough entries to fit for %s iptjet %s ipt %d',
level, iptjet, ipt)
continue
# TODO: link datasel to fit stage
if datasel := fitcfg.get('datasel'):
hist_name = f'h_mass-ptjet-pthf_{datasel}'
if not (hsel := rfile.Get(hist_name)):
self.logger.critical("Failed to get histogram %s", hist_name)
h_invmass = project_hist(hsel, [0], cuts_proj)
roows = self.roows.get(ipt) if iptjet is None else self.roows_ptjet.get((iptjet, ipt))
for par in fitcfg.get('fix_params', []):
if var := roows.var(par):
var.setConstant(True)
Expand All @@ -472,10 +468,11 @@ def fit(self):
# roo_ws.Print()
# TODO: save snapshot per level
# roo_ws.saveSnapshot(level, None)
self.roows[ipt] = roo_ws
if iptjet is not None:
self.roows_ptjet[(iptjet, ipt)] = roo_ws
self.roo_ws_ptjet[level][iptjet][ipt] = roo_ws
else:
self.roows[ipt] = roo_ws
self.roo_ws[level][ipt] = roo_ws
# TODO: take parameter names from DB
if level in ('data', 'mc'):
Expand Down

0 comments on commit 83263b4

Please sign in to comment.