Skip to content

Commit

Permalink
Fix per_ptjet fit
Browse files Browse the repository at this point in the history
  • Loading branch information
vkucera authored and qgp committed Sep 5, 2024
1 parent a6c5cce commit aa90ba9
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions machine_learning_hep/analysis/analyzer_jets.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,8 @@ def fit(self):
continue
roows = self.roows.get(ipt) if iptjet is None else self.roows_ptjet.get((iptjet, ipt))
if roows is None and level != self.fit_levels[0]:
self.logger.warning('missing previous fit result, skipping %s iptjet %s ipt %d',
level, iptjet, ipt)
continue
self.logger.critical('missing previous fit result, cannot fit %s iptjet %s ipt %d',
level, iptjet, ipt)
for par in fitcfg.get('fix_params', []):
if var := roows.var(par):
var.setConstant(True)
Expand All @@ -494,11 +493,16 @@ def fit(self):
# TODO: save snapshot per level
# roo_ws.saveSnapshot(level, None)
if iptjet is not None:
self.logger.debug("Setting roows_ptjet for %s iptjet %s ipt %d", level, iptjet, ipt)
self.roows_ptjet[(iptjet, ipt)] = roo_ws
self.roo_ws_ptjet[level][iptjet][ipt] = roo_ws
else:
self.logger.debug("Setting roows for %s iptjet %s ipt %d", level, iptjet, ipt)
self.roows[ipt] = roo_ws
self.roo_ws[level][ipt] = roo_ws
for jptjet in range(get_nbins(h, 1)):
self.roows_ptjet[(jptjet, ipt)] = roo_ws.Clone()
self.roo_ws_ptjet[level][jptjet][ipt] = roo_ws.Clone()
# TODO: take parameter names from DB
if level in ('data', 'mc'):
varname_mean = fitcfg.get('var_mean', 'mean')
Expand Down

0 comments on commit aa90ba9

Please sign in to comment.