Skip to content

Commit

Permalink
column aggregation with k-mer filtering based on their counts (#320)
Browse files Browse the repository at this point in the history
  • Loading branch information
karasikov authored Dec 16, 2021
1 parent a1e41fb commit ce96cef
Show file tree
Hide file tree
Showing 6 changed files with 232 additions and 21 deletions.
69 changes: 69 additions & 0 deletions metagraph/integration_tests/test_transform_anno.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

NUM_THREADS = 4


class TestColumnOperations(TestingBase):
@classmethod
def setUpClass(cls):
Expand Down Expand Up @@ -70,6 +71,74 @@ def test_overlap(self):
self.assertEqual(res.returncode, 0)
self.assertEqual(156421, len(res.stdout.decode()))

def _check_aggregation_min(self, min_count, expected_density):
command = f'{METAGRAPH} transform_anno {self.annotation} -p {NUM_THREADS} \
--aggregate-columns --min-count {min_count} -o aggregated'

res = subprocess.run(command.split(), stdout=PIPE)
self.assertEqual(res.returncode, 0)

res = self._get_stats(f'-a aggregated{anno_file_extension[self.anno_repr]}')
self.assertEqual(res.returncode, 0)
out = res.stdout.decode().split('\n')[2:]
self.assertEqual('labels: 1', out[0])
self.assertEqual('objects: 46960', out[1])
self.assertEqual(f'density: {expected_density}', out[2])
self.assertEqual(f'representation: {self.anno_repr}', out[3])

def test_aggregate_columns(self):
self._check_aggregation_min(0, 1)
self._check_aggregation_min(1, 1)
self._check_aggregation_min(5, 0.0715077)
self._check_aggregation_min(10, 0.00344974)
self._check_aggregation_min(20, 0)

def _check_aggregation_min_max_value(self, min_count, max_value, expected_density):
command = f'{METAGRAPH} transform_anno {self.annotation} -p {NUM_THREADS} \
--aggregate-columns --min-count {min_count} --max-value {max_value} -o aggregated'

res = subprocess.run(command.split(), stdout=PIPE)
self.assertEqual(res.returncode, 0)

res = self._get_stats(f'-a aggregated{anno_file_extension[self.anno_repr]}')
self.assertEqual(res.returncode, 0)
out = res.stdout.decode().split('\n')[2:]
self.assertEqual('labels: 1', out[0])
self.assertEqual('objects: 46960', out[1])
self.assertEqual(f'density: {expected_density}', out[2])
self.assertEqual(f'representation: {self.anno_repr}', out[3])

def test_aggregate_columns_filtered(self):
self._check_aggregation_min_max_value(0, 0, 0)
self._check_aggregation_min_max_value(1, 0, 0)
self._check_aggregation_min_max_value(2, 0, 0)
self._check_aggregation_min_max_value(3, 0, 0)
self._check_aggregation_min_max_value(5, 0, 0)

self._check_aggregation_min_max_value(0, 1, 0.99704)
self._check_aggregation_min_max_value(1, 1, 0.99704)
self._check_aggregation_min_max_value(2, 1, 0.392994)
self._check_aggregation_min_max_value(3, 1, 0.183305)
self._check_aggregation_min_max_value(5, 1, 0.0715077)

self._check_aggregation_min_max_value(0, 2, 0.998807)
self._check_aggregation_min_max_value(1, 2, 0.998807)
self._check_aggregation_min_max_value(2, 2, 0.394825)
self._check_aggregation_min_max_value(3, 2, 0.183986)
self._check_aggregation_min_max_value(5, 2, 0.0715077)

self._check_aggregation_min_max_value(0, 5, 0.998999)
self._check_aggregation_min_max_value(1, 5, 0.998999)
self._check_aggregation_min_max_value(2, 5, 0.395315)
self._check_aggregation_min_max_value(3, 5, 0.184817)
self._check_aggregation_min_max_value(5, 5, 0.0715077)

self._check_aggregation_min_max_value(0, 1000, 1)
self._check_aggregation_min_max_value(1, 1000, 1)
self._check_aggregation_min_max_value(2, 1000, 0.395336)
self._check_aggregation_min_max_value(3, 1000, 0.184817)
self._check_aggregation_min_max_value(5, 1000, 0.0715077)


if __name__ == '__main__':
unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,97 @@ ::load_column_values(const std::vector<std::string> &filenames,
exit(1);
}

template <typename Label>
void ColumnCompressed<Label>
::load_columns_and_values(const std::vector<std::string> &filenames,
const ColumnsValuesCallback &callback,
size_t num_threads) {
std::atomic<bool> error_occurred = false;

std::vector<uint64_t> offsets(filenames.size(), 0);

// load labels
#pragma omp parallel for num_threads(num_threads) schedule(dynamic)
for (size_t i = 1; i < filenames.size(); ++i) {
auto fname = make_suffix(filenames[i - 1], kExtension);
try {
offsets[i] = read_num_labels(fname);
} catch (...) {
logger->error("Can't load label encoder from {}", fname);
error_occurred = true;
}
}

if (error_occurred)
exit(1);

// compute global offsets (partial sums)
std::partial_sum(offsets.begin(), offsets.end(), offsets.begin());

// load annotations
#pragma omp parallel for num_threads(num_threads) schedule(dynamic)
for (size_t i = 0; i < filenames.size(); ++i) {
const auto &filename = make_suffix(filenames[i], kExtension);
logger->trace("Loading columns from {}", filename);
try {
std::ifstream in(filename, std::ios::binary);
if (!in)
throw std::ifstream::failure("can't open file");

std::ignore = load_number(in);

LabelEncoder<Label> label_encoder_load;
if (!label_encoder_load.load(in))
throw std::ifstream::failure("can't load label encoder");

if (!label_encoder_load.size()) {
logger->warn("No columns in {}", filename);
continue;
}

const auto &values_fname
= utils::remove_suffix(filename, ColumnCompressed<>::kExtension)
+ ColumnCompressed<>::kCountExtension;

std::ifstream values_in(values_fname, std::ios::binary);
if (!values_in)
throw std::ifstream::failure("can't open file " + values_fname);

for (size_t c = 0; c < label_encoder_load.size(); ++c) {
auto column = std::make_unique<bit_vector_smart>();

if (!column->load(in))
throw std::ifstream::failure("can't load next column");

sdsl::int_vector<> column_values;
try {
column_values.load(values_in);
} catch (...) {
logger->error("Can't load column values from {} for column {}",
values_fname, c);
throw;
}
if (column_values.size() != column->num_set_bits())
throw std::ifstream::failure("inconsistent size of the value vector");

callback(offsets[i] + c,
label_encoder_load.decode(c),
std::move(column), std::move(column_values));
}

} catch (const std::exception &e) {
logger->error("Caught exception when loading values for {}: {}", filename, e.what());
error_occurred = true;
} catch (...) {
logger->error("Unknown exception when loading values for {}", filename);
error_occurred = true;
}
}

if (error_occurred)
exit(1);
}

template <typename Label>
void ColumnCompressed<Label>::insert_rows(const std::vector<Index> &rows) {
assert(std::is_sorted(rows.begin(), rows.end()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,14 @@ class ColumnCompressed : public MultiLabelEncoded<Label> {
const ValuesCallback &callback,
size_t num_threads = 1);

using ColumnsValuesCallback = std::function<void(uint64_t offset,
const Label &,
std::unique_ptr<bit_vector>&&,
sdsl::int_vector<>&&)>;
static void load_columns_and_values(const std::vector<std::string> &filenames,
const ColumnsValuesCallback &callback,
size_t num_threads = 1);

// Dump columns to separate files in human-readable format
bool dump_columns(const std::string &prefix, size_t num_threads = 1) const;

Expand Down
7 changes: 7 additions & 0 deletions metagraph/src/cli/config/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ Config::Config(int argc, char *argv[]) {
min_fraction = std::stod(get_value(i++));
} else if (!strcmp(argv[i], "--max-fraction")) {
max_fraction = std::stod(get_value(i++));
} else if (!strcmp(argv[i], "--min-value")) {
min_value = atoi(get_value(i++));
} else if (!strcmp(argv[i], "--max-value")) {
max_value = atoi(get_value(i++));
} else if (!strcmp(argv[i], "--mem-cap-gb")) {
memory_available = atof(get_value(i++));
} else if (!strcmp(argv[i], "--dump-text-anno")) {
Expand Down Expand Up @@ -1145,9 +1149,12 @@ void Config::print_usage(const std::string &prog_name, IdentityType identity) {

// fprintf(stderr, "\t-o --outfile-base [STR] basename of output file []\n");
fprintf(stderr, "\t --aggregate-columns \t\taggregate annotation columns into a bitmask (new column) [off]\n");
fprintf(stderr, "\t \t\t\tFormula: min-count <= \\sum_i 1{min-value <= c_i <= max-value} <= max-count\n");
fprintf(stderr, "\t --anno-label [STR]\t\tname of the aggregated output column [mask]\n");
fprintf(stderr, "\t --min-value [INT] \t\tmin value for filtering [1]\n");
fprintf(stderr, "\t --min-count [INT] \t\texclude k-mers appearing in fewer than this number of columns [1]\n");
fprintf(stderr, "\t --min-fraction [FLOAT] \texclude k-mers appearing in fewer than this fraction of columns [0.0]\n");
fprintf(stderr, "\t --max-value [INT] \t\tmax value for filtering [inf]\n");
fprintf(stderr, "\t --max-count [INT] \t\texclude k-mers appearing in more than this number of columns [inf]\n");
fprintf(stderr, "\t --max-fraction [FLOAT] \texclude k-mers appearing in more than this fraction of columns [1.0]\n");
fprintf(stderr, "\t --compute-overlap [STR] \tcompute the number of shared bits in columns of this annotation and ANNOTATOR [off]\n");
Expand Down
2 changes: 2 additions & 0 deletions metagraph/src/cli/config/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ class Config {
double memory_available = 1;
unsigned int min_count = 1;
unsigned int max_count = std::numeric_limits<unsigned int>::max();
unsigned int min_value = 1;
unsigned int max_value = std::numeric_limits<unsigned int>::max();
unsigned int num_top_labels = -1;
unsigned int genome_binsize_anno = 1000;
unsigned int arity_brwt = 2;
Expand Down
76 changes: 55 additions & 21 deletions metagraph/src/cli/transform_annotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,9 @@ int transform_annotation(Config *config) {
exit(1);
}

const bool filter_values = config->min_value > 1
|| config->max_value
< std::numeric_limits<unsigned int>::max();
const uint64_t min_cols
= std::max((uint64_t)std::ceil(num_columns * config->min_fraction),
(uint64_t)config->min_count);
Expand All @@ -511,36 +514,67 @@ int transform_annotation(Config *config) {
(uint64_t)config->max_count);

// TODO: set width to log2(max_cols + 1) but make sure atomic
// increments can't lead to overflow
// increments don't overflow
sdsl::int_vector<> sum(0, 0, sdsl::bits::hi(num_columns) + 1);
ProgressBar progress_bar(num_columns, "Intersect columns",
std::cerr, !get_verbose());
ProgressBar progress_bar(num_columns, "Intersect columns", std::cerr, !get_verbose());
ThreadPool thread_pool(get_num_threads(), 1);
std::mutex mu_sum;
std::mutex mu;
auto on_column = [&](uint64_t, const auto &, auto&& col) {
std::lock_guard<std::mutex> lock(mu);

if (!sum.size()) {
sum.resize(col->size());
} else if (sum.size() != col->size()) {
logger->error("Input columns have inconsistent size ({} != {})",
sum.size(), col->size());
exit(1);
}
if (filter_values) {
auto on_column = [&](uint64_t, const std::string &,
std::unique_ptr<bit_vector>&& col,
sdsl::int_vector<>&& values) {
std::lock_guard<std::mutex> lock(mu);

thread_pool.enqueue([&,col{std::move(col)}]() {
col->call_ones([&](uint64_t i) {
atomic_fetch_and_add(sum, i, 1, mu_sum, __ATOMIC_RELAXED);
if (!sum.size()) {
sum.resize(col->size());
sdsl::util::set_to_value(sum, 0);
} else if (sum.size() != col->size()) {
logger->error("Input columns have inconsistent size ({} != {})",
sum.size(), col->size());
exit(1);
}

thread_pool.enqueue([&,col(std::move(col)),values(std::move(values))]() {
assert(col->num_set_bits() == values.size());
for (uint64_t r = 0; r < values.size(); ++r) {
if (values[r] >= config->min_value && values[r] <= config->max_value)
atomic_fetch_and_add(sum, col->select1(r + 1), 1, mu_sum, __ATOMIC_RELAXED);
}
++progress_bar;
});
++progress_bar;
});
};
};

if (!ColumnCompressed<>::merge_load(files, on_column, get_num_threads())) {
logger->error("Couldn't load annotations");
exit(1);
ColumnCompressed<>::load_columns_and_values(files, on_column, get_num_threads());

} else {
auto on_column = [&](uint64_t, const auto &, auto&& col) {
std::lock_guard<std::mutex> lock(mu);

if (!sum.size()) {
sum.resize(col->size());
sdsl::util::set_to_value(sum, 0);
} else if (sum.size() != col->size()) {
logger->error("Input columns have inconsistent size ({} != {})",
sum.size(), col->size());
exit(1);
}

thread_pool.enqueue([&,col{std::move(col)}]() {
col->call_ones([&](uint64_t i) {
atomic_fetch_and_add(sum, i, 1, mu_sum, __ATOMIC_RELAXED);
});
++progress_bar;
});
};

if (!ColumnCompressed<>::merge_load(files, on_column, get_num_threads())) {
logger->error("Couldn't load annotations");
exit(1);
}
}

thread_pool.join();
std::atomic_thread_fence(std::memory_order_acquire);

Expand Down

0 comments on commit ce96cef

Please sign in to comment.