style: pre-commit fixes
pre-commit-ci[bot] committed Sep 30, 2024
1 parent a9ff472 commit 7c42472
Showing 2 changed files with 48 additions and 35 deletions.
79 changes: 47 additions & 32 deletions src/HH4b/boosted/ValidateBDT.py
@@ -5,25 +5,26 @@
from __future__ import annotations

import argparse
-import os
+import importlib
+import logging
import sys
from pathlib import Path
-import xgboost as xgb

import numpy as np
import pandas as pd
+import xgboost as xgb
from sklearn.metrics import auc, roc_curve
-import importlib

import HH4b.utils as utils
-from HH4b import plotting, hh_vars
+from HH4b import hh_vars
+from HH4b.log_utils import log_config
from HH4b.utils import get_var_mapping

-import logging
-from HH4b.log_utils import log_config
log_config["root"]["level"] = "INFO"
logging.config.dictConfig(log_config)
logger = logging.getLogger("ValidateBDT")


def load_events(path_to_dir, year, jet_coll_pnet, jet_coll_mass, bdt_models):
logger.info(f"Load {year}")

@@ -52,18 +52,17 @@ def load_events(path_to_dir, year, jet_coll_pnet, jet_coll_mass, bdt_models):
year: {
"qcd": [
"QCD_HT-1000to1200",
#"QCD_HT-1200to1500",
#"QCD_HT-1500to2000",
#"QCD_HT-2000",
#"QCD_HT-400to600",
#"QCD_HT-600to800",
#"QCD_HT-800to1000",
# "QCD_HT-1200to1500",
# "QCD_HT-1500to2000",
# "QCD_HT-2000",
# "QCD_HT-400to600",
# "QCD_HT-600to800",
# "QCD_HT-800to1000",
],
"ttbar": [
"TTto4Q",
],
},

}
sample_dirs_sig = {
year: {
@@ -92,19 +92,19 @@ def load_events(path_to_dir, year, jet_coll_pnet, jet_coll_mass, bdt_models):
(f"{jet_collection}Msd", num_jets),
(f"{jet_collection}Eta", num_jets),
(f"{jet_collection}Phi", num_jets),
(f"{jet_collection}PNetPXbbLegacy", num_jets), # Legacy PNet
(f"{jet_collection}PNetPXbbLegacy", num_jets), # Legacy PNet
(f"{jet_collection}PNetPQCDbLegacy", num_jets),
(f"{jet_collection}PNetPQCDbbLegacy", num_jets),
(f"{jet_collection}PNetPQCD0HFLegacy", num_jets),
(f"{jet_collection}PNetMassLegacy", num_jets),
(f"{jet_collection}PNetTXbbLegacy", num_jets),
(f"{jet_collection}PNetTXbb", num_jets), # 103X PNet
(f"{jet_collection}PNetTXbb", num_jets), # 103X PNet
(f"{jet_collection}PNetMass", num_jets),
(f"{jet_collection}PNetQCD0HF", num_jets),
(f"{jet_collection}PNetQCD1HF", num_jets),
(f"{jet_collection}PNetQCD2HF", num_jets),
(f"{jet_collection}ParTmassVis", num_jets), # GloParT
(f"{jet_collection}ParTTXbb", num_jets),
(f"{jet_collection}ParTmassVis", num_jets), # GloParT
(f"{jet_collection}ParTTXbb", num_jets),
(f"{jet_collection}ParTPXbb", num_jets),
(f"{jet_collection}ParTPQCD0HF", num_jets),
(f"{jet_collection}ParTPQCD1HF", num_jets),
@@ -119,7 +119,7 @@ def load_events(path_to_dir, year, jet_coll_pnet, jet_coll_mass, bdt_models):
],
]

-# dictionary that will contain all information (from all samples)
+# dictionary that will contain all information (from all samples)
events_dict = {
# this function will load files (only the columns selected), apply filters and compute a weight per event
**utils.load_samples(
@@ -169,7 +169,9 @@ def apply_cuts(events_dict, txbb_str, mass_str):

def get_bdt(events_dict, bdt_model, bdt_model_name, bdt_config, jlabel=""):
bdt_model = xgb.XGBClassifier()
-bdt_model.load_model(fname=f"../boosted/bdt_trainings_run3/{bdt_model_name}/trained_bdt.model")
+bdt_model.load_model(
+    fname=f"../boosted/bdt_trainings_run3/{bdt_model_name}/trained_bdt.model"
+)
make_bdt_dataframe = importlib.import_module(
f".{bdt_config}", package="HH4b.boosted.bdt_trainings_run3"
)
@@ -185,9 +187,7 @@ def get_bdt(events_dict, bdt_model, bdt_model_name, bdt_config, jlabel=""):
elif preds.shape[1] == 4: # multi-class BDT with ggF HH, VBF HH, QCD, ttbar classes
bg_tot = np.sum(preds[:, 2:], axis=1)
bdt_score = preds[:, 0] / (preds[:, 0] + bg_tot)
-bdt_score_vbf = preds[:, 1] / (
-    preds[:, 1] + preds[:, 2] + preds[:, 3]
-)
+bdt_score_vbf = preds[:, 1] / (preds[:, 1] + preds[:, 2] + preds[:, 3])
return bdt_score, bdt_score_vbf

events_dict = apply_cuts(events_dict, txbb_str, mass_str)
@@ -198,9 +198,15 @@ def get_bdt(events_dict, bdt_model, bdt_model_name, bdt_config, jlabel=""):
bdt_config = bdt_models[bdt_model]["config"]
bdt_model_name = bdt_models[bdt_model]["model_name"]
for key in events_dict:
-bdt_score, bdt_score_vbf = get_bdt(events_dict[key], bdt_model, bdt_model_name, bdt_config)
-events_dict[key][f"bdtscore_{bdt_model}"] = bdt_score if bdt_score is not None else np.ones(events_dict[key]["weight"])
-events_dict[key][f"bdtscoreVBF_{bdt_model}"] = bdt_score if bdt_score is not None else np.ones(events_dict[key]["weight"])
+bdt_score, bdt_score_vbf = get_bdt(
+    events_dict[key], bdt_model, bdt_model_name, bdt_config
+)
+events_dict[key][f"bdtscore_{bdt_model}"] = (
+    bdt_score if bdt_score is not None else np.ones(events_dict[key]["weight"])
+)
+events_dict[key][f"bdtscoreVBF_{bdt_model}"] = (
+    bdt_score if bdt_score is not None else np.ones(events_dict[key]["weight"])
+)
bdt_scores.extend([f"bdtscore_{bdt_model}", f"bdtscoreVBF_{bdt_model}"])

return {key: events_dict[key][bdt_scores] for key in events_dict}
@@ -260,7 +266,7 @@ def get_roc(
def main(args):
out_dir = Path(f"./bdt_comparison/{args.out_dir}/")
out_dir.mkdir(exist_ok=True, parents=True)

bdt_models = {
"v5_PNetLegacy": {
"config": "v5",
@@ -269,15 +275,24 @@ def main(args):
"v5_ParT": {
"config": "v5_glopartv2",
"model_name": "24Sep27_v5_glopartv2",
-}
+},
}

-bdt_dict = {year: load_events(
-    args.data_path, year, jet_coll_pnet="ParTTXbb", jet_coll_mass="ParTmassVis", bdt_models=bdt_models
-) for year in args.year}
+bdt_dict = {
+    year: load_events(
+        args.data_path,
+        year,
+        jet_coll_pnet="ParTTXbb",
+        jet_coll_mass="ParTmassVis",
+        bdt_models=bdt_models,
+    )
+    for year in args.year
+}
processes = ["qcd", "ttbar", "hh4b"]
-bdt_dict_combined = {key: pd.concat([bdt_dict[year][key] for year in bdt_dict]) for key in processes}
+bdt_dict_combined = {
+    key: pd.concat([bdt_dict[year][key] for year in bdt_dict]) for key in processes
+}


print(bdt_dict_combined)

rocs = {
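
For reference, the four-class discriminants assembled in get_bdt above are plain ratios of class probabilities. Below is a minimal standalone sketch, with a random preds array standing in for the classifier's predict_proba output and the class ordering (ggF HH, VBF HH, QCD, ttbar) taken from the comment in the diff; it is an illustration, not the repository's code.

import numpy as np

rng = np.random.default_rng(0)
preds = rng.dirichlet(np.ones(4), size=5)  # 5 events x 4 class probabilities, each row sums to 1

bg_tot = np.sum(preds[:, 2:], axis=1)  # QCD + ttbar background probability
bdt_score = preds[:, 0] / (preds[:, 0] + bg_tot)  # ggF HH vs. background discriminant
bdt_score_vbf = preds[:, 1] / (preds[:, 1] + preds[:, 2] + preds[:, 3])  # VBF HH vs. background

print(bdt_score)
print(bdt_score_vbf)
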
4 changes: 1 addition & 3 deletions src/HH4b/boosted/bdt_trainings_run3/v5.py
@@ -92,9 +92,7 @@ def bdt_dataframe(events, key_map=lambda x: x):
key_map("H1Xbb"): events[key_map("bbFatJetPNetPXbbLegacy")].to_numpy()[:, 0],
key_map("H1QCDb"): events[key_map("bbFatJetPNetPQCDbLegacy")].to_numpy()[:, 0],
key_map("H1QCDbb"): events[key_map("bbFatJetPNetPQCDbbLegacy")].to_numpy()[:, 0],
key_map("H1QCDothers"): events[key_map("bbFatJetPNetPQCD0HFLegacy")].to_numpy()[
:, 0
],
key_map("H1QCDothers"): events[key_map("bbFatJetPNetPQCD0HFLegacy")].to_numpy()[:, 0],
# ratios
key_map("H1Pt_HHmass"): h1.pt / hh.mass,
key_map("H2Pt_HHmass"): h2.pt / hh.mass,
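
The reformatted line above follows the leading-jet selection pattern used throughout bdt_dataframe: pick a per-jet column group, convert it to a NumPy array, and keep jet index 0. An illustrative sketch follows; the two-jet MultiIndex layout and the identity key_map are assumptions made for the example, not taken from the repository.

import numpy as np
import pandas as pd


def key_map(x):
    # identity mapping, matching the default key_map in bdt_dataframe
    return x


# Toy events table: one column group holding a value per jet (jets 0 and 1).
events = pd.DataFrame(
    np.random.default_rng(0).random((4, 2)),
    columns=pd.MultiIndex.from_product([["bbFatJetPNetPQCD0HFLegacy"], [0, 1]]),
)

# Same idiom as the H1QCDothers entry: take the group, keep the leading jet (index 0).
h1_qcd_others = events[key_map("bbFatJetPNetPQCD0HFLegacy")].to_numpy()[:, 0]
print(h1_qcd_others)
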
