Skip to content

Commit

Permalink
add -c --cores
Browse files Browse the repository at this point in the history
  • Loading branch information
mr-eyes committed Aug 25, 2023
1 parent ae0b26a commit 0c165b5
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 15 deletions.
30 changes: 21 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,9 @@ fn manysearch<P: AsRef<Path>>(
ksize: u8,
scaled: usize,
output: Option<P>,
num_threads: usize,
) -> Result<()> {

// construct a MinHash template for loading.
let max_hash = max_hash_for_scaled(scaled as u64);
let template_mh = KmerMinHash::builder()
Expand Down Expand Up @@ -199,7 +201,7 @@ fn manysearch<P: AsRef<Path>>(
eprintln!("Loaded {} sig paths to search.", search_sigs_paths.len());

// set up a multi-producer, single-consumer channel.
let (send, recv) = std::sync::mpsc::sync_channel(rayon::current_num_threads());
let (send, recv) = std::sync::mpsc::sync_channel(num_threads);

// & spawn a thread that is dedicated to printing to a buffered output
let out: Box<dyn Write + Send> = match output {
Expand Down Expand Up @@ -533,6 +535,7 @@ fn countergather<P: AsRef<Path> + std::fmt::Debug + std::fmt::Display + Clone>(
scaled: usize,
gather_output: Option<P>,
prefetch_output: Option<P>,
num_threads: usize,
) -> Result<()> {
let max_hash = max_hash_for_scaled(scaled as u64);
let template_mh = KmerMinHash::builder()
Expand Down Expand Up @@ -615,6 +618,7 @@ fn multigather<P: AsRef<Path> + std::fmt::Debug + Clone>(
threshold_bp: usize,
ksize: u8,
scaled: usize,
num_threads: usize,
) -> Result<()> {
let max_hash = max_hash_for_scaled(scaled as u64);
let template_mh = KmerMinHash::builder()
Expand Down Expand Up @@ -764,10 +768,11 @@ fn do_manysearch(querylist_path: String,
threshold: f64,
ksize: u8,
scaled: usize,
output_path: String
output_path: String,
num_threads: usize,
) -> anyhow::Result<u8> {
match manysearch(querylist_path, siglist_path, threshold, ksize, scaled,
Some(output_path)) {
Some(output_path), num_threads) {
Ok(_) => Ok(0),
Err(e) => {
eprintln!("Error: {e}");
Expand All @@ -784,11 +789,13 @@ fn do_countergather(query_filename: String,
scaled: usize,
output_path_prefetch: Option<String>,
output_path_gather: Option<String>,
num_threads: usize,
) -> anyhow::Result<u8> {
match countergather(query_filename, siglist_path, threshold_bp,
ksize, scaled,
output_path_prefetch,
output_path_gather) {
output_path_gather,
num_threads) {
Ok(_) => Ok(0),
Err(e) => {
eprintln!("Error: {e}");
Expand All @@ -802,10 +809,11 @@ fn do_multigather(query_filenames: String,
siglist_path: String,
threshold_bp: usize,
ksize: u8,
scaled: usize
scaled: usize,
num_threads: usize,
) -> anyhow::Result<u8> {
match multigather(query_filenames, siglist_path, threshold_bp,
ksize, scaled) {
ksize, scaled, num_threads) {
Ok(_) => Ok(0),
Err(e) => {
eprintln!("Error: {e}");
Expand All @@ -815,15 +823,19 @@ fn do_multigather(query_filenames: String,
}

#[pyfunction]
fn get_num_threads() -> PyResult<usize> {
Ok(rayon::current_num_threads())
fn set_global_thread_pool(num_threads: usize) -> PyResult<()> {
    // `build_global` returns Err when the global rayon pool has already been
    // initialized — it does not panic, so no `catch_unwind` is needed. The
    // previous `if let Ok(_) = catch_unwind(...)` matched `Ok(Err(..))` too,
    // silently swallowing the "already initialized" failure.
    rayon::ThreadPoolBuilder::new()
        .num_threads(num_threads)
        .build_global()
        .map_err(|_| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
                "Could not set the number of threads. Global thread pool might already be initialized.",
            )
        })
}

#[pymodule]
/// Python module entry point: registers the search/gather commands and the
/// thread-pool configuration hook. `get_num_threads` was removed in this
/// change in favor of `set_global_thread_pool`.
fn pyo3_branchwater(_py: Python, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(do_manysearch, m)?)?;
    m.add_function(wrap_pyfunction!(do_countergather, m)?)?;
    m.add_function(wrap_pyfunction!(do_multigather, m)?)?;
    m.add_function(wrap_pyfunction!(set_global_thread_pool, m)?)?;
    Ok(())
}
59 changes: 53 additions & 6 deletions src/python/pyo3_branchwater/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,20 @@
import argparse
from sourmash.plugins import CommandLinePlugin
from sourmash.logging import notify
import os

from . import pyo3_branchwater


def get_max_cores():
    """Return the number of CPU cores available to this process.

    Honors SLURM allocations when running under the scheduler; otherwise
    falls back to the machine's CPU count.
    """
    if 'SLURM_CPUS_ON_NODE' in os.environ:
        # SLURM sets this to a plain integer for the current node.
        return int(os.environ['SLURM_CPUS_ON_NODE'])
    if 'SLURM_JOB_CPUS_PER_NODE' in os.environ:
        # This variable can be a compressed list such as "72(x2),36".
        # The old `split('x')[0]` parse produced "72(" and crashed on
        # int(); instead take the leading digits of the first entry,
        # i.e. the core count of the first node in the allocation.
        value = os.environ['SLURM_JOB_CPUS_PER_NODE']
        digits = ''
        for ch in value:
            if not ch.isdigit():
                break
            digits += ch
        if digits:
            return int(digits)
        # Unparseable value: fall through to the local CPU count.
    return os.cpu_count()


class Branchwater_Manysearch(CommandLinePlugin):
command = 'manysearch'
description = 'massively parallel sketch search'
Expand All @@ -24,18 +35,31 @@ def __init__(self, p):
help='k-mer size at which to select sketches')
p.add_argument('-s', '--scaled', default=1000, type=int,
help='scaled factor at which to do comparisons')
p.add_argument('-c', '--cores', default=0, type=int,
help='number of cores to use (default is all available)')

def main(self, args):
    """Run manysearch, sizing the global rayon thread pool from --cores.

    Clamps the requested core count to what the node actually provides
    (see get_max_cores) before configuring the thread pool.
    """
    notify(f"ksize: {args.ksize} / scaled: {args.scaled} / threshold: {args.threshold}")

    avail_threads = get_max_cores()
    num_threads = min(avail_threads, args.cores) if args.cores else avail_threads
    if args.cores > avail_threads:
        notify(f"warning: only {avail_threads} threads available, using {avail_threads} instead of {args.cores}")

    # Pass the clamped count, not the raw --cores request: the original
    # passed args.cores here, contradicting the warning above and asking
    # rayon for more threads than are available (or 0 when unset).
    pyo3_branchwater.set_global_thread_pool(num_threads)

    notify(f"searching all sketches in '{args.query_paths}' against '{args.against_paths}' using {num_threads} threads")

    super().main(args)
    status = pyo3_branchwater.do_manysearch(args.query_paths,
                                            args.against_paths,
                                            args.threshold,
                                            args.ksize,
                                            args.scaled,
                                            args.output,
                                            num_threads)
    if status == 0:
        notify(f"...manysearch is done! results in '{args.output}'")
    return status
Expand All @@ -59,10 +83,20 @@ def __init__(self, p):
help='k-mer size at which to do comparisons (default: 31)')
p.add_argument('-s', '--scaled', default=1000, type=int,
help='scaled factor at which to do comparisons (default: 1000)')
p.add_argument('-c', '--cores', default=0, type=int,
help='number of cores to use (default is all available)')


def main(self, args):
notify(f"ksize: {args.ksize} / scaled: {args.scaled} / threshold bp: {args.threshold_bp}")
num_threads = pyo3_branchwater.get_num_threads()

avail_threads = get_max_cores()
num_threads = min(avail_threads, args.cores) if args.cores else avail_threads
if args.cores > avail_threads:
notify(f"warning: only {avail_threads} threads available, using {avail_threads} instead of {args.cores}")

pyo3_branchwater.set_global_thread_pool(args.cores)

notify(f"gathering all sketches in '{args.query_sig}' against '{args.against_paths}' using {num_threads} threads")
super().main(args)
status = pyo3_branchwater.do_countergather(args.query_sig,
Expand All @@ -71,7 +105,8 @@ def main(self, args):
args.ksize,
args.scaled,
args.output_gather,
args.output_prefetch)
args.output_prefetch,
num_threads)
if status == 0:
notify(f"...fastgather is done! gather results in '{args.output_gather}'")
if args.output_prefetch:
Expand All @@ -93,17 +128,29 @@ def __init__(self, p):
help='k-mer size at which to do comparisons (default: 31)')
p.add_argument('-s', '--scaled', default=1000, type=int,
help='scaled factor at which to do comparisons (default: 1000)')
p.add_argument('-c', '--cores', default=0, type=int,
help='number of cores to use (default is all available)')


def main(self, args):
    """Run fastmultigather, sizing the global rayon thread pool from --cores.

    Clamps the requested core count to what the node actually provides
    (see get_max_cores) before configuring the thread pool.
    """
    notify(f"ksize: {args.ksize} / scaled: {args.scaled} / threshold bp: {args.threshold_bp}")

    avail_threads = get_max_cores()
    num_threads = min(avail_threads, args.cores) if args.cores else avail_threads
    if args.cores > avail_threads:
        notify(f"warning: only {avail_threads} threads available, using {avail_threads} instead of {args.cores}")

    # Pass the clamped count, not the raw --cores request: the original
    # passed args.cores here, contradicting the warning above and asking
    # rayon for more threads than are available (or 0 when unset).
    pyo3_branchwater.set_global_thread_pool(num_threads)

    notify(f"gathering all sketches in '{args.query_paths}' against '{args.against_paths}' using {num_threads} threads")
    super().main(args)
    status = pyo3_branchwater.do_multigather(args.query_paths,
                                             args.against_paths,
                                             int(args.threshold_bp),
                                             args.ksize,
                                             args.scaled,
                                             num_threads)
    if status == 0:
        notify(f"...fastmultigather is done!")
    return status

0 comments on commit 0c165b5

Please sign in to comment.