Skip to content

Commit

Permalink
Improve packaging, logging, style (#838)
Browse files Browse the repository at this point in the history
* Update logger and add quiet mode

* Improve packaging and some cleanup
  • Loading branch information
qgp authored Dec 1, 2023
1 parent f1dc0be commit fb38e77
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 19 deletions.
18 changes: 18 additions & 0 deletions machine_learning_hep/__main__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#############################################################################
## © Copyright CERN 2023. All rights not expressly granted are reserved. ##
## ##
## This program is free software: you can redistribute it and/or modify it ##
## under the terms of the GNU General Public License as published by the ##
## Free Software Foundation, either version 3 of the License, or (at your ##
## option) any later version. This program is distributed in the hope that ##
## it will be useful, but WITHOUT ANY WARRANTY; without even the implied ##
## warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. ##
## See the GNU General Public License for more details. ##
## You should have received a copy of the GNU General Public License ##
## along with this program. if not, see <https://www.gnu.org/licenses/>. ##
#############################################################################

import sys
from machine_learning_hep.steer_analysis import main

sys.exit(main())
16 changes: 7 additions & 9 deletions machine_learning_hep/logger.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#############################################################################
## © Copyright CERN 2018. All rights not expressly granted are reserved. ##
## © Copyright CERN 2023. All rights not expressly granted are reserved. ##
## Author: [email protected] ##
## This program is free software: you can redistribute it and/or modify it ##
## under the terms of the GNU General Public License as published by the ##
Expand Down Expand Up @@ -55,8 +55,9 @@ class MLLoggerFormatter(logging.Formatter):
reset = '\x1b[0m'

# Define default format string
def __init__(self, fmt='%(levelname)s in %(pathname)s:%(lineno)d:\n ↳ %(message)s',
def __init__(self, fmt=None,
datefmt=None, style='%', color=False):
fmt = fmt or '%(levelname)s in %(pathname)s:%(lineno)d:\n ↳ %(message)s'
logging.Formatter.__init__(self, fmt, datefmt, style)
self.color = color

Expand Down Expand Up @@ -86,7 +87,7 @@ def format(self, record):
return logging.Formatter.format(self, cached_record)


def configure_logger(debug, logfile=None):
def configure_logger(debug, logfile=None, quiet=False):
"""
Basic configuration adding a custom formatted StreamHandler and turning on
debug info if requested.
Expand All @@ -95,14 +96,11 @@ def configure_logger(debug, logfile=None):
if logger.hasHandlers():
return

# Turn on debug info only on request
if debug:
logger.setLevel(logging.DEBUG)
else:
logger.setLevel(logging.INFO)
logger.setLevel(logging.DEBUG if debug else logging.INFO)

sh = logging.StreamHandler()
formatter = MLLoggerFormatter(color=lambda : getattr(sh.stream, 'isatty', None))
formatter = MLLoggerFormatter(color=lambda : getattr(sh.stream, 'isatty', None),
fmt = '%(levelname)s ➞ %(message)s' if quiet else None)

sh.setFormatter(formatter)
logger.addHandler(sh)
Expand Down
12 changes: 5 additions & 7 deletions machine_learning_hep/steer_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,6 @@
import yaml
from pkg_resources import resource_stream

# To set batch mode immediately
from ROOT import gROOT # pylint: disable=import-error, no-name-in-module

from machine_learning_hep.multiprocesser import MultiProcesser
from machine_learning_hep.processer import Processer
from machine_learning_hep.processerdhadrons import ProcesserDhadrons
Expand Down Expand Up @@ -55,9 +52,6 @@
def do_entire_analysis(data_config: dict, data_param: dict, data_param_overwrite: dict, # pylint: disable=too-many-locals, too-many-statements, too-many-branches
data_model: dict, run_param: dict, clean: bool):

# Disable any graphical stuff. No TCanvases opened and shown by default
gROOT.SetBatch(True)

logger = get_logger()
logger.info("Do analysis chain")

Expand Down Expand Up @@ -491,6 +485,7 @@ def main():

parser = argparse.ArgumentParser()
parser.add_argument("--debug", action="store_true", help="activate debug log level")
parser.add_argument("--quiet", '-q', action="store_true", help="quiet logging")
parser.add_argument("--log-file", dest="log_file", help="file to print the log to")
parser.add_argument("--run-config", "-r", dest="run_config",
help="the run configuration to be used")
Expand All @@ -509,7 +504,7 @@ def main():

args = parser.parse_args()

configure_logger(args.debug, args.log_file)
configure_logger(args.debug, args.log_file, args.quiet)

# Extract which database and run config to be used
pkg_data = "machine_learning_hep.data"
Expand All @@ -526,3 +521,6 @@ def main():
# Run the chain
do_entire_analysis(run_config, db_analysis, db_analysis_overwrite, db_ml_models, db_run_list,
args.clean)

if __name__ == '__main__':
main()
7 changes: 4 additions & 3 deletions machine_learning_hep/templates_xgboost.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#############################################################################
## © Copyright CERN 2018. All rights not expressly granted are reserved. ##
## © Copyright CERN 2023. All rights not expressly granted are reserved. ##
## Author: [email protected] ##
## This program is free software: you can redistribute it and/or modify it ##
## under the terms of the GNU General Public License as published by the ##
Expand All @@ -25,7 +25,7 @@

def xgboost_classifier(model_config): # pylint: disable=W0613
return XGBClassifier(verbosity=1,
n_gpus=0,
# n_gpus=0,
**model_config)


Expand Down Expand Up @@ -55,7 +55,8 @@ def yield_model_(self, model_config, space):

def save_model_(self, model, out_dir):
out_filename = join(out_dir, "xgboost_classifier.sav")
pickle.dump(model, open(out_filename, 'wb'), protocol=4)
with open(out_filename, 'wb') as outfile:
pickle.dump(model, outfile, protocol=4)
out_filename = join(out_dir, "xgboost_classifier.model")
model.save_model(out_filename)

Expand Down

0 comments on commit fb38e77

Please sign in to comment.