diff --git a/src/lib.rs b/src/lib.rs index 28a0314a..8e5af543 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -208,10 +208,10 @@ fn manysearch>( }; let thrd = std::thread::spawn(move || { let mut writer = BufWriter::new(out); - writeln!(&mut writer, "query,query_md5,match,match_md5,containment").unwrap(); - for (query, query_md5, m, m_md5, overlap) in recv.into_iter() { - writeln!(&mut writer, "\"{}\",{},\"{}\",{},{}", - query, query_md5, m, m_md5, overlap).ok(); + writeln!(&mut writer, "query_name,query_md5,match_name,match_md5,containment,jaccard,intersect_hashes").unwrap(); + for (query, query_md5, m, m_md5, cont, jaccard, overlap) in recv.into_iter() { + writeln!(&mut writer, "\"{}\",{},\"{}\",{},{},{},{}", + query, query_md5, m, m_md5, cont, jaccard, overlap).ok(); } }); @@ -241,14 +241,21 @@ fn manysearch>( // search for matches & save containment. for q in queries.iter() { let overlap = q.minhash.count_common(&search_sm.minhash, false).unwrap() as f64; - let size = q.minhash.size() as f64; + let query_size = q.minhash.size() as f64; - let containment = overlap / size; + let mut merged = q.minhash.clone(); + merged.merge(&search_sm.minhash).ok(); + let total_size = merged.size() as f64; + + let containment = overlap / query_size; + let jaccard = overlap / total_size; if containment > threshold { results.push((q.name.clone(), q.md5sum.clone(), search_sm.name.clone(), search_sm.md5sum.clone(), + containment, + jaccard, overlap)) } } diff --git a/src/python/tests/test-data/1.combined.sig.gz b/src/python/tests/test-data/1.combined.sig.gz index 0d039fa7..97d0e1cd 100644 Binary files a/src/python/tests/test-data/1.combined.sig.gz and b/src/python/tests/test-data/1.combined.sig.gz differ diff --git a/src/python/tests/test_search.py b/src/python/tests/test_search.py index 6526948a..17fe36e7 100644 --- a/src/python/tests/test_search.py +++ b/src/python/tests/test_search.py @@ -45,6 +45,35 @@ def test_simple(runtmp): df = pandas.read_csv(output) assert len(df) == 5 + dd = 
df.to_dict(orient='index') + print(dd) + + for idx, row in dd.items(): + # identical? + if row['query_md5'] == row['match_md5']: + assert row['match_name'] == row['query_name'] + assert float(row['containment']) == 1.0 + assert float(row['jaccard']) == 1.0 + else: + # confirm hand-checked numbers + q = row['query_name'].split()[0] + m = row['match_name'].split()[0] + jaccard = float(row['jaccard']) + cont = float(row['containment']) + intersect_hashes = int(row['intersect_hashes']) + + jaccard = round(jaccard, 4) + cont = round(cont, 4) + print(q, m, f"{jaccard:.04}", f"{cont:.04}") + + if q == 'NC_011665.1' and m == 'NC_009661.1': + assert jaccard == 0.3207 + assert cont == 0.4828 + + if q == 'NC_009661.1' and m == 'NC_011665.1': + assert jaccard == 0.3207 + assert cont == 0.4885 + def test_simple_threshold(runtmp): # test with a simple threshold => only 3 results