Skip to content

Commit

Permalink
Merge pull request #39 from ratschlab/master
Browse files Browse the repository at this point in the history
Support larger alphabets and k via generic kmer_t
  • Loading branch information
jermp authored Aug 26, 2024
2 parents 9158671 + e32e502 commit 35492d8
Show file tree
Hide file tree
Showing 35 changed files with 1,049 additions and 879 deletions.
8 changes: 0 additions & 8 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -64,18 +64,10 @@ set(Z_LIB_SOURCES
include/gz/zip_stream.cpp
)

set(SSHASH_SOURCES
include/dictionary.cpp
include/info.cpp
include/dump.cpp
include/statistics.cpp
include/builder/build.cpp
)

# Create a static lib
add_library(sshash_static STATIC
${Z_LIB_SOURCES}
${SSHASH_SOURCES}
)

add_executable(sshash src/sshash.cpp)
Expand Down
108 changes: 36 additions & 72 deletions include/bit_vector_iterator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

namespace sshash {

template <class kmer_t>
struct bit_vector_iterator {
bit_vector_iterator() : m_bv(nullptr) {}

Expand All @@ -17,80 +18,61 @@ struct bit_vector_iterator {
}

inline kmer_t read(uint64_t l) {
assert(l <= constants::uint_kmer_bits);
assert(l <= kmer_t::uint_kmer_bits);
if (m_avail < l) fill_buf();
kmer_t val = 0;
if (l != constants::uint_kmer_bits) {
val = m_buf & ((kmer_t(1) << l) - 1);
} else {
val = m_buf;
}
kmer_t val = m_buf;
val.take(l);
return val;
}

inline kmer_t read_reverse(uint64_t l) {
assert(l <= constants::uint_kmer_bits);
assert(l <= kmer_t::uint_kmer_bits);
if (m_avail < l) fill_buf_reverse();
kmer_t val = 0;
if (l != constants::uint_kmer_bits) {
uint64_t shift = (l >= 64) ? (constants::uint_kmer_bits - l) : 64;
val = m_buf >> shift;
} else {
val = m_buf;
}
kmer_t val = m_buf;
val.drop(kmer_t::uint_kmer_bits - l);
return val;
}

inline void eat(uint64_t l) {
assert(l <= constants::uint_kmer_bits);
assert(l <= kmer_t::uint_kmer_bits);
if (m_avail < l) fill_buf();
if (l != constants::uint_kmer_bits) m_buf >>= l;
m_buf.drop(l);
m_avail -= l;
m_pos += l;
}

inline void eat_reverse(uint64_t l) {
assert(l <= constants::uint_kmer_bits);
assert(l <= kmer_t::uint_kmer_bits);
if (m_avail < l) fill_buf_reverse();
if (l != constants::uint_kmer_bits) m_buf <<= l;
m_buf.pad(l);
m_avail -= l;
m_pos -= l;
}

inline kmer_t read_and_advance_by_two(uint64_t l) {
assert(l <= constants::uint_kmer_bits);
inline kmer_t read_and_advance_by_char(uint64_t l) {
assert(l <= kmer_t::uint_kmer_bits);
if (m_avail < l) fill_buf();
kmer_t val = 0;
if (l != constants::uint_kmer_bits) {
val = m_buf & ((kmer_t(1) << l) - 1);
m_buf >>= 2;
} else {
val = m_buf;
}
m_avail -= 2;
m_pos += 2;
kmer_t val = m_buf;
val.take(l);
m_buf.drop_char();
m_avail -= kmer_t::bits_per_char;
m_pos += kmer_t::bits_per_char;
return val;
}

inline kmer_t get_next_two_bits() {
if (m_avail < 2) fill_buf();
kmer_t val = m_buf & 3;
m_buf >>= 2;
m_avail -= 2;
m_pos += 2;
return val;
inline uint64_t get_next_char() {
if (m_avail < kmer_t::bits_per_char) fill_buf();
m_avail -= kmer_t::bits_per_char;
m_pos += kmer_t::bits_per_char;
return m_buf.pop_char();
}

inline kmer_t take(uint64_t l) {
assert(l <= constants::uint_kmer_bits);
assert(l <= kmer_t::uint_kmer_bits);
if (m_avail < l) fill_buf();
kmer_t val = 0;
if (l != constants::uint_kmer_bits) {
val = m_buf & ((kmer_t(1) << l) - 1);
m_buf >>= l;
} else {
val = m_buf;
}
kmer_t val = m_buf;
val.take(l);
m_buf.drop(l);
m_avail -= l;
m_pos += l;
return val;
Expand All @@ -100,38 +82,20 @@ struct bit_vector_iterator {

private:
inline void fill_buf() {
if constexpr (constants::uint_kmer_bits == 64) {
m_buf = m_bv->get_word64(m_pos);
} else {
assert(constants::uint_kmer_bits == 128);
m_buf = static_cast<kmer_t>(m_bv->get_word64(m_pos));
m_buf += static_cast<kmer_t>(m_bv->get_word64(m_pos + 64)) << 64;
static_assert(kmer_t::uint_kmer_bits % 64 == 0);
for (int i = kmer_t::uint_kmer_bits - 64; i >= 0; i -= 64) {
if (m_pos + i < m_bv->size()) { m_buf.append64(m_bv->get_word64(m_pos + i)); }
}
m_avail = constants::uint_kmer_bits;
m_avail = kmer_t::uint_kmer_bits;
}

inline void fill_buf_reverse() {
if constexpr (constants::uint_kmer_bits == 64) {
if (m_pos < 64) {
m_buf = m_bv->get_word64(0);
m_avail = m_pos;
m_buf <<= (64 - m_pos);
return;
}
m_buf = m_bv->get_word64(m_pos - 64);
} else {
assert(constants::uint_kmer_bits == 128);
if (m_pos < 128) {
m_buf = static_cast<kmer_t>(m_bv->get_word64(0)) << 64;
m_buf += static_cast<kmer_t>(m_bv->get_word64(64));
m_avail = m_pos;
m_buf <<= (128 - m_pos);
return;
}
m_buf = static_cast<kmer_t>(m_bv->get_word64(m_pos - 128)) << 64;
m_buf += static_cast<kmer_t>(m_bv->get_word64(m_pos - 64));
static_assert(kmer_t::uint_kmer_bits % 64 == 0);
for (int i = kmer_t::uint_kmer_bits; i > 0; i -= 64) {
m_buf.append64(m_bv->get_word64(std::max<uint64_t>(m_pos, kmer_t::uint_kmer_bits) - i));
}
m_avail = constants::uint_kmer_bits;
m_avail = std::min<uint64_t>(m_pos, kmer_t::uint_kmer_bits);
m_buf.pad(kmer_t::uint_kmer_bits - m_avail);
}

pthash::bit_vector const* m_bv;
Expand Down
61 changes: 61 additions & 0 deletions include/bitpack.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#pragma once

namespace sshash {
// Fixed-width unsigned bit container laid out as a full binary tree of the
// given height with Int values in its leaves (2^height leaves in total).
// `a` is the low half and `b` is the high half: the uint64_t constructor
// stores its argument in `a`, and a right shift moves bits from `b` into `a`.
template <typename Int, uint16_t height>
struct bitpack {
    static_assert(height > 0);
    using halfpack = std::conditional_t<height == 1, Int, bitpack<Int, height - 1>>;
    // Width of one half, in bits.
    static constexpr uint16_t hsize = 8 * sizeof(halfpack);
    halfpack a, b;

    // Leaves both halves default-initialized (indeterminate for built-in Int).
    bitpack() {}
    // Places x in the low half and zeroes the high half.
    bitpack(uint64_t x) : a(x), b(0) {}
    bitpack(halfpack a, halfpack b) : a(a), b(b) {}
    // Returns the low half converted (recursively) to uint64_t.
    explicit operator uint64_t() const { return static_cast<uint64_t>(a); }

    bool operator==(bitpack const& t) const { return a == t.a && b == t.b; }
    bool operator!=(bitpack const& t) const { return a != t.a || b != t.b; }
    // Lexicographic order on (a, b): the LOW half is compared first, so this
    // is a strict weak ordering but not the numeric order of the packed value.
    bool operator<(bitpack const& t) const {
        if (a < t.a) return true;
        if (t.a < a) return false;
        return b < t.b;
    }

    // Logical right shift; shift must be in [0, 2 * hsize).
    bitpack& operator>>=(uint16_t shift) {
        // A zero shift must not reach the expression below: it would evaluate
        // `b << hsize`, i.e. shift a leaf by its full width (undefined).
        if (shift == 0) return *this;
        if (shift < hsize) {
            a = (a >> shift) | (b << (hsize - shift));
            b >>= shift;
        } else {
            a = b >> (shift - hsize);
            b = 0;
        }
        return *this;
    }
    // Left shift; shift must be in [0, 2 * hsize).
    bitpack& operator<<=(uint16_t shift) {
        // Same guard as operator>>=: avoids `a >> hsize` on a zero shift.
        if (shift == 0) return *this;
        if (shift < hsize) {
            b = (b << shift) | (a >> (hsize - shift));
            a <<= shift;
        } else {
            b = a << (shift - hsize);
            a = 0;
        }
        return *this;
    }
    bitpack operator<<(uint16_t shift) const {
        bitpack shifted(*this);
        shifted <<= shift;
        return shifted;
    }
    bitpack operator>>(uint16_t shift) const {
        bitpack shifted(*this);
        shifted >>= shift;
        return shifted;
    }

    bitpack& operator|=(bitpack const& t) {
        a |= t.a;
        b |= t.b;
        return *this;
    }
    bitpack& operator&=(bitpack const& t) {
        a &= t.a;
        b &= t.b;
        return *this;
    }
    bitpack operator|(bitpack const& t) const {
        bitpack merged(*this);
        merged |= t;
        return merged;
    }
    bitpack operator&(bitpack const& t) const {
        bitpack masked(*this);
        masked &= t;
        return masked;
    }

    // Bitwise complement of both halves. The explicit casts keep this
    // well-formed when Int is narrower than int and `~a` is promoted.
    bitpack operator~() const {
        return bitpack(static_cast<halfpack>(~a), static_cast<halfpack>(~b));
    }
};
} // namespace sshash
39 changes: 20 additions & 19 deletions include/buckets.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

namespace sshash {

template <class kmer_t>
struct buckets {
std::pair<lookup_result, uint64_t> offset_to_id(uint64_t offset, uint64_t k) const {
auto [pos, contig_begin, contig_end] = pieces.locate(offset);
Expand Down Expand Up @@ -56,14 +57,14 @@ struct buckets {

kmer_t contig_prefix(uint64_t contig_id, uint64_t k) const {
uint64_t contig_begin = pieces.access(contig_id);
bit_vector_iterator bv_it(strings, 2 * contig_begin);
return bv_it.read(2 * (k - 1));
bit_vector_iterator<kmer_t> bv_it(strings, kmer_t::bits_per_char * contig_begin);
return bv_it.read(kmer_t::bits_per_char * (k - 1));
}

kmer_t contig_suffix(uint64_t contig_id, uint64_t k) const {
uint64_t contig_end = pieces.access(contig_id + 1);
bit_vector_iterator bv_it(strings, 2 * (contig_end - k + 1));
return bv_it.read(2 * (k - 1));
bit_vector_iterator<kmer_t> bv_it(strings, kmer_t::bits_per_char * (contig_end - k + 1));
return bv_it.read(kmer_t::bits_per_char * (k - 1));
}

std::pair<uint64_t, uint64_t> locate_bucket(uint64_t bucket_id) const {
Expand Down Expand Up @@ -94,10 +95,10 @@ struct buckets {
uint64_t m) const {
uint64_t offset = offsets.access(super_kmer_id);
auto [res, contig_end] = offset_to_id(offset, k);
bit_vector_iterator bv_it(strings, 2 * offset);
bit_vector_iterator<kmer_t> bv_it(strings, kmer_t::bits_per_char * offset);
uint64_t window_size = std::min<uint64_t>(k - m + 1, contig_end - offset - k + 1);
for (uint64_t w = 0; w != window_size; ++w) {
kmer_t read_kmer = bv_it.read_and_advance_by_two(2 * k);
kmer_t read_kmer = bv_it.read_and_advance_by_char(kmer_t::bits_per_char * k);
if (read_kmer == target_kmer) {
res.kmer_id += w;
res.kmer_id_in_contig += w;
Expand All @@ -119,10 +120,10 @@ struct buckets {
for (uint64_t super_kmer_id = begin; super_kmer_id != end; ++super_kmer_id) {
uint64_t offset = offsets.access(super_kmer_id);
auto [res, contig_end] = offset_to_id(offset, k);
bit_vector_iterator bv_it(strings, 2 * offset);
bit_vector_iterator<kmer_t> bv_it(strings, kmer_t::bits_per_char * offset);
uint64_t window_size = std::min<uint64_t>(k - m + 1, contig_end - offset - k + 1);
for (uint64_t w = 0; w != window_size; ++w) {
kmer_t read_kmer = bv_it.read_and_advance_by_two(2 * k);
kmer_t read_kmer = bv_it.read_and_advance_by_char(kmer_t::bits_per_char * k);
if (read_kmer == target_kmer) {
res.kmer_id += w;
res.kmer_id_in_contig += w;
Expand Down Expand Up @@ -170,8 +171,8 @@ struct buckets {

void access(uint64_t kmer_id, char* string_kmer, uint64_t k) const {
uint64_t offset = id_to_offset(kmer_id, k);
bit_vector_iterator bv_it(strings, 2 * offset);
kmer_t read_kmer = bv_it.read(2 * k);
bit_vector_iterator<kmer_t> bv_it(strings, kmer_t::bits_per_char * offset);
kmer_t read_kmer = bv_it.read(kmer_t::bits_per_char * k);
util::uint_kmer_to_string(read_kmer, string_kmer, k);
}

Expand All @@ -186,7 +187,7 @@ struct buckets {
, m_end_kmer_id(end_kmer_id)
, m_k(k) //
{
m_bv_it = bit_vector_iterator(m_buckets->strings, -1);
m_bv_it = bit_vector_iterator<kmer_t>(m_buckets->strings, -1);
m_offset = m_buckets->id_to_offset(m_begin_kmer_id, k);
auto [pos, piece_end] = m_buckets->pieces.next_geq(m_offset);
if (piece_end == m_offset) pos += 1;
Expand All @@ -209,12 +210,12 @@ struct buckets {
util::uint_kmer_to_string(m_read_kmer, m_ret.second.data(), m_k);
} else {
memmove(m_ret.second.data(), m_ret.second.data() + 1, m_k - 1);
m_ret.second[m_k - 1] = util::uint64_to_char(m_last_two_bits);
m_ret.second[m_k - 1] = kmer_t::uint64_to_char(m_last_char);
}
m_clear = false;
m_read_kmer >>= 2;
m_last_two_bits = m_bv_it.get_next_two_bits();
m_read_kmer += m_last_two_bits << (2 * (m_k - 1));
m_read_kmer.drop_char();
m_last_char = m_bv_it.get_next_char();
m_read_kmer.kth_char_or(m_k - 1, m_last_char);
++m_begin_kmer_id;
++m_offset;
return m_ret;
Expand All @@ -230,18 +231,18 @@ struct buckets {
uint64_t m_k;
uint64_t m_offset;
uint64_t m_next_offset;
bit_vector_iterator m_bv_it;
bit_vector_iterator<kmer_t> m_bv_it;
ef_sequence<true>::iterator m_pieces_it;

kmer_t m_read_kmer;
uint64_t m_last_two_bits;
uint64_t m_last_char;
bool m_clear;

void next_piece() {
m_bv_it.at(2 * m_offset);
m_bv_it.at(kmer_t::bits_per_char * m_offset);
m_next_offset = m_pieces_it.next();
assert(m_next_offset > m_offset);
m_read_kmer = m_bv_it.take(2 * m_k);
m_read_kmer = m_bv_it.take(kmer_t::bits_per_char * m_k);
m_clear = true;
}
};
Expand Down
14 changes: 8 additions & 6 deletions include/builder/build.cpp → include/builder/build.impl
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,18 @@

namespace sshash {

void dictionary::build(std::string const& filename, build_configuration const& build_config) {
template <class kmer_t>
void dictionary<kmer_t>::build(std::string const& filename,
build_configuration const& build_config) {
/* Validate the build configuration. */
if (build_config.k == 0) throw std::runtime_error("k must be > 0");
if (build_config.k > constants::max_k) {
throw std::runtime_error("k must be less <= " + std::to_string(constants::max_k) +
if (build_config.k > kmer_t::max_k) {
throw std::runtime_error("k must be less <= " + std::to_string(kmer_t::max_k) +
" but got k = " + std::to_string(build_config.k));
}
if (build_config.m == 0) throw std::runtime_error("m must be > 0");
if (build_config.m > constants::max_m) {
throw std::runtime_error("m must be less <= " + std::to_string(constants::max_m) +
if (build_config.m > kmer_t::max_m) {
throw std::runtime_error("m must be less <= " + std::to_string(kmer_t::max_m) +
" but got m = " + std::to_string(build_config.m));
}
if (build_config.m > build_config.k) throw std::runtime_error("m must be <= k");
Expand All @@ -41,7 +43,7 @@ void dictionary::build(std::string const& filename, build_configuration const& b

/* step 1: parse the input file and build compact string pool ***/
timer.start();
parse_data data = parse_file(filename, build_config);
parse_data<kmer_t> data = parse_file<kmer_t>(filename, build_config);
m_size = data.num_kmers;
timer.stop();
timings.push_back(timer.elapsed());
Expand Down
Loading

0 comments on commit 35492d8

Please sign in to comment.