diff --git a/src/lib.rs b/src/lib.rs index 8e5af543..b693030a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -155,7 +155,9 @@ fn manysearch>( ksize: u8, scaled: usize, output: Option

, + num_threads: usize, ) -> Result<()> { + // construct a MinHash template for loading. let max_hash = max_hash_for_scaled(scaled as u64); let template_mh = KmerMinHash::builder() @@ -199,7 +201,7 @@ fn manysearch>( eprintln!("Loaded {} sig paths to search.", search_sigs_paths.len()); // set up a multi-producer, single-consumer channel. - let (send, recv) = std::sync::mpsc::sync_channel(rayon::current_num_threads()); + let (send, recv) = std::sync::mpsc::sync_channel(num_threads); // & spawn a thread that is dedicated to printing to a buffered output let out: Box = match output { @@ -533,6 +535,7 @@ fn countergather + std::fmt::Debug + std::fmt::Display + Clone>( scaled: usize, gather_output: Option

, prefetch_output: Option

, + num_threads: usize, ) -> Result<()> { let max_hash = max_hash_for_scaled(scaled as u64); let template_mh = KmerMinHash::builder() @@ -615,6 +618,7 @@ fn multigather + std::fmt::Debug + Clone>( threshold_bp: usize, ksize: u8, scaled: usize, + num_threads: usize, ) -> Result<()> { let max_hash = max_hash_for_scaled(scaled as u64); let template_mh = KmerMinHash::builder() @@ -764,10 +768,11 @@ fn do_manysearch(querylist_path: String, threshold: f64, ksize: u8, scaled: usize, - output_path: String + output_path: String, + num_threads: usize, ) -> anyhow::Result { match manysearch(querylist_path, siglist_path, threshold, ksize, scaled, - Some(output_path)) { + Some(output_path), num_threads) { Ok(_) => Ok(0), Err(e) => { eprintln!("Error: {e}"); @@ -784,11 +789,13 @@ fn do_countergather(query_filename: String, scaled: usize, output_path_prefetch: Option, output_path_gather: Option, + num_threads: usize, ) -> anyhow::Result { match countergather(query_filename, siglist_path, threshold_bp, ksize, scaled, output_path_prefetch, - output_path_gather) { + output_path_gather, + num_threads) { Ok(_) => Ok(0), Err(e) => { eprintln!("Error: {e}"); @@ -802,10 +809,11 @@ fn do_multigather(query_filenames: String, siglist_path: String, threshold_bp: usize, ksize: u8, - scaled: usize + scaled: usize, + num_threads: usize, ) -> anyhow::Result { match multigather(query_filenames, siglist_path, threshold_bp, - ksize, scaled) { + ksize, scaled, num_threads) { Ok(_) => Ok(0), Err(e) => { eprintln!("Error: {e}"); @@ -815,8 +823,12 @@ fn do_multigather(query_filenames: String, } #[pyfunction] -fn get_num_threads() -> PyResult { - Ok(rayon::current_num_threads()) +fn set_global_thread_pool(num_threads: usize) -> PyResult<()> { + if let Ok(_) = std::panic::catch_unwind(|| rayon::ThreadPoolBuilder::new().num_threads(num_threads).build_global()) { + Ok(()) + } else { + Err(PyErr::new::("Could not set the number of threads. Global thread pool might already be initialized.")) + } } #[pymodule] @@ -824,6 +836,6 @@ fn pyo3_branchwater(_py: Python, m: &PyModule) -> PyResult<()> { m.add_function(wrap_pyfunction!(do_manysearch, m)?)?; m.add_function(wrap_pyfunction!(do_countergather, m)?)?; m.add_function(wrap_pyfunction!(do_multigather, m)?)?; - m.add_function(wrap_pyfunction!(get_num_threads, m)?)?; + m.add_function(wrap_pyfunction!(set_global_thread_pool, m)?)?; Ok(()) } diff --git a/src/python/pyo3_branchwater/__init__.py b/src/python/pyo3_branchwater/__init__.py index f37179d0..3ef7d84b 100755 --- a/src/python/pyo3_branchwater/__init__.py +++ b/src/python/pyo3_branchwater/__init__.py @@ -3,9 +3,20 @@ import argparse from sourmash.plugins import CommandLinePlugin from sourmash.logging import notify +import os from . import pyo3_branchwater + +def get_max_cores(): + if 'SLURM_CPUS_ON_NODE' in os.environ: + return int(os.environ['SLURM_CPUS_ON_NODE']) + elif 'SLURM_JOB_CPUS_PER_NODE' in os.environ: + return int(os.environ['SLURM_JOB_CPUS_PER_NODE'].split('x')[0]) # Assumes a simple format; won't handle complex scenarios + else: + return os.cpu_count() + + class Branchwater_Manysearch(CommandLinePlugin): command = 'manysearch' description = 'massively parallel sketch search' @@ -24,18 +35,31 @@ def __init__(self, p): help='k-mer size at which to select sketches') p.add_argument('-s', '--scaled', default=1000, type=int, help='scaled factor at which to do comparisons') + p.add_argument('-c', '--cores', default=0, type=int, + help='number of cores to use (default is all available)') def main(self, args): notify(f"ksize: {args.ksize} / scaled: {args.scaled} / threshold: {args.threshold}") - num_threads = pyo3_branchwater.get_num_threads() + + + avail_threads = get_max_cores() + num_threads = min(avail_threads, args.cores) if args.cores else avail_threads + if args.cores > avail_threads: + notify(f"warning: only {avail_threads} threads available, using {avail_threads} instead of {args.cores}") + + + pyo3_branchwater.set_global_thread_pool(args.cores) + notify(f"searching all sketches in '{args.query_paths}' against '{args.against_paths}' using {num_threads} threads") + super().main(args) status = pyo3_branchwater.do_manysearch(args.query_paths, args.against_paths, args.threshold, args.ksize, args.scaled, - args.output) + args.output, + num_threads) if status == 0: notify(f"...manysearch is done! results in '{args.output}'") return status @@ -59,10 +83,20 @@ def __init__(self, p): help='k-mer size at which to do comparisons (default: 31)') p.add_argument('-s', '--scaled', default=1000, type=int, help='scaled factor at which to do comparisons (default: 1000)') + p.add_argument('-c', '--cores', default=0, type=int, + help='number of cores to use (default is all available)') + def main(self, args): notify(f"ksize: {args.ksize} / scaled: {args.scaled} / threshold bp: {args.threshold_bp}") - num_threads = pyo3_branchwater.get_num_threads() + + avail_threads = get_max_cores() + num_threads = min(avail_threads, args.cores) if args.cores else avail_threads + if args.cores > avail_threads: + notify(f"warning: only {avail_threads} threads available, using {avail_threads} instead of {args.cores}") + + pyo3_branchwater.set_global_thread_pool(args.cores) + notify(f"gathering all sketches in '{args.query_sig}' against '{args.against_paths}' using {num_threads} threads") super().main(args) status = pyo3_branchwater.do_countergather(args.query_sig, @@ -71,7 +105,8 @@ def main(self, args): args.ksize, args.scaled, args.output_gather, - args.output_prefetch) + args.output_prefetch, + num_threads) if status == 0: notify(f"...fastgather is done! gather results in '{args.output_gather}'") if args.output_prefetch: @@ -93,17 +128,29 @@ def __init__(self, p): help='k-mer size at which to do comparisons (default: 31)') p.add_argument('-s', '--scaled', default=1000, type=int, help='scaled factor at which to do comparisons (default: 1000)') + p.add_argument('-c', '--cores', default=0, type=int, + help='number of cores to use (default is all available)') + def main(self, args): notify(f"ksize: {args.ksize} / scaled: {args.scaled} / threshold bp: {args.threshold_bp}") - num_threads = pyo3_branchwater.get_num_threads() + + avail_threads = get_max_cores() + num_threads = min(avail_threads, args.cores) if args.cores else avail_threads + if args.cores > avail_threads: + notify(f"warning: only {avail_threads} threads available, using {avail_threads} instead of {args.cores}") + + pyo3_branchwater.set_global_thread_pool(args.cores) + + notify(f"gathering all sketches in '{args.query_paths}' against '{args.against_paths}' using {num_threads} threads") super().main(args) status = pyo3_branchwater.do_multigather(args.query_paths, args.against_paths, int(args.threshold_bp), args.ksize, - args.scaled) + args.scaled, + num_threads) if status == 0: notify(f"...fastmultigather is done!") return status