diff --git a/Cargo.toml b/Cargo.toml index c4f37f69..3acef4ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,8 @@ log = "0.4.14" env_logger = "0.9.0" simple-error = "0.3.0" anyhow = "1.0.75" +zip = "0.6" +tempfile = "3.8" [dev-dependencies] assert_cmd = "2.0.4" diff --git a/src/lib.rs b/src/lib.rs index ca04ffa5..199b47e7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,8 +8,12 @@ use std::fs::File; use std::io::{BufRead, BufReader, BufWriter, Write}; use std::path::{Path, PathBuf}; +use zip::read::ZipArchive; +use tempfile::tempdir; + use std::sync::atomic; use std::sync::atomic::AtomicUsize; +use std::io::Read; use std::collections::BinaryHeap; @@ -17,6 +21,7 @@ use std::cmp::{PartialOrd, Ordering}; use anyhow::{Result, anyhow}; + #[macro_use] extern crate simple_error; @@ -759,6 +764,30 @@ fn build_template(ksize: u8, scaled: usize) -> Sketch { Sketch::MinHash(template_mh) } +fn read_signatures_from_zip>( + zip_path: P, +) -> Result<(Vec, tempfile::TempDir), Box> { + let mut signature_paths = Vec::new(); + let temp_dir = tempdir()?; + let zip_file = File::open(&zip_path)?; + let mut zip_archive = ZipArchive::new(zip_file)?; + + for i in 0..zip_archive.len() { + let mut file = zip_archive.by_index(i)?; + let mut sig = Vec::new(); + file.read_to_end(&mut sig)?; + + let file_name = Path::new(file.name()).file_name().unwrap().to_str().unwrap(); + if file_name.ends_with(".sig") || file_name.ends_with(".sig.gz") { + let new_path = temp_dir.path().join(file_name); + let mut new_file = File::create(&new_path)?; + new_file.write_all(&sig)?; + signature_paths.push(new_path); + } + } + println!("wrote {} signatures to temp dir", signature_paths.len()); + Ok((signature_paths, temp_dir)) +} fn index>( siglist: P, @@ -766,11 +795,22 @@ fn index>( output: P, save_paths: bool, colors: bool, - ) -> Result<(), Box> { - +) -> Result<(), Box> { + let mut temp_dir = None; info!("Loading siglist"); - let index_sigs = load_sketchlist_filenames(&siglist)?; + + let index_sigs: Vec; + + if siglist.as_ref().extension().map(|ext| ext == "zip").unwrap_or(false) { + let (paths, tempdir) = read_signatures_from_zip(&siglist)?; + temp_dir = Some(tempdir); + index_sigs = paths; + } else { + index_sigs = load_sketchlist_filenames(&siglist)?; + } + info!("Loaded {} sig paths in siglist", index_sigs.len()); + println!("Loaded {} sig paths in siglist", index_sigs.len()); // Create or open the RevIndex database with the provided output path and colors flag let db = RevIndex::create(output.as_ref(), colors); @@ -778,10 +818,13 @@ fn index>( // Index the signatures using the loaded template, threshold, and save_paths option db.index(index_sigs, &template, 0.0, save_paths); + if let Some(temp_dir) = temp_dir { + temp_dir.close()?; + } + Ok(()) } - fn check>(index: P, quick: bool) -> Result<(), Box> { info!("Opening DB"); let db = RevIndex::open(index.as_ref(), true); diff --git a/src/python/pyo3_branchwater/__init__.py b/src/python/pyo3_branchwater/__init__.py index 3a96a4d6..f36c2577 100755 --- a/src/python/pyo3_branchwater/__init__.py +++ b/src/python/pyo3_branchwater/__init__.py @@ -173,7 +173,8 @@ def main(self, args): num_threads = set_thread_pool(args.cores) - notify(f"indexing all sketches in '{args.siglist}' using {num_threads} threads") + notify(f"indexing all sketches in '{args.siglist}'") + super().main(args) status = pyo3_branchwater.do_index(args.siglist, args.ksize, @@ -221,13 +222,12 @@ def main(self, args): # help='scaled factor at which to do comparisons') # p.add_argument('--save-paths', action='store_true', # help='save paths to signatures into index. Default: save full sig into index') - # p.add_argument('-c', '--cores', default=0, type=int, - # help='number of cores to use (default is all available)') +# p.add_argument('-c', '--cores', default=0, type=int, +# help='number of cores to use (default is all available)') # def main(self, args): # notify(f"ksize: {args.ksize} / scaled: {args.scaled} / threshold: {args.threshold}") # num_threads = set_thread_pool(args.cores) - # notify(f"updating index with all sketches in '{args.siglist}'") # super().main(args) # status = pyo3_branchwater.do_update(args.siglist,