compute confusion table, support compressed files
fernandomeyer committed Sep 23, 2023
1 parent 28975f9 commit 0eef537
Showing 4 changed files with 96 additions and 23 deletions.
84 changes: 67 additions & 17 deletions cami_opal/evaluate.py
@@ -75,19 +75,22 @@ def reformat_pandas(sample_id, label, braycurtis, shannon, binary_metrics, l1nor
return pd.DataFrame()

# convert Unifrac
pd_unifrac = pd.DataFrame(index=[sample_id], data=[unifrac], columns=[c.UNIFRAC, c.UNW_UNIFRAC]).stack().reset_index()
pd_unifrac = pd.DataFrame(index=[sample_id], data=[unifrac],
columns=[c.UNIFRAC, c.UNW_UNIFRAC]).stack().reset_index()
pd_unifrac.columns = ['sample', 'metric', 'value']
pd_unifrac['rank'] = np.nan
pd_unifrac['tool'] = label

# convert Unifrac CAMI
pd_unifrac_cami = pd.DataFrame(index=[sample_id], data=[unifrac_cami], columns=[c.UNIFRAC_CAMI, c.UNW_UNIFRAC_CAMI]).stack().reset_index()
pd_unifrac_cami = pd.DataFrame(index=[sample_id], data=[unifrac_cami],
columns=[c.UNIFRAC_CAMI, c.UNW_UNIFRAC_CAMI]).stack().reset_index()
pd_unifrac_cami.columns = ['sample', 'metric', 'value']
pd_unifrac_cami['rank'] = np.nan
pd_unifrac_cami['tool'] = label

# convert Shannon
pd_shannon = pd.DataFrame([shannon[rank].get_pretty_dict() for rank in shannon.keys()]).set_index('rank').stack().reset_index().rename(columns={'level_1': 'metric', 0: 'value'})
pd_shannon = pd.DataFrame([shannon[rank].get_pretty_dict() for rank in shannon.keys()]).set_index(
'rank').stack().reset_index().rename(columns={'level_1': 'metric', 0: 'value'})
pd_shannon['metric'].replace(['diversity', 'equitability'], [c.SHANNON_DIVERSITY, c.SHANNON_EQUIT], inplace=True)
pd_shannon['sample'] = sample_id
pd_shannon['tool'] = label
@@ -99,7 +102,9 @@ def reformat_pandas(sample_id, label, braycurtis, shannon, binary_metrics, l1nor
pd_l1norm['metric'] = c.L1NORM

# convert Binary metrics
pd_binary_metrics = pd.DataFrame([binary_metrics[rank].get_pretty_dict() for rank in binary_metrics.keys()]).set_index('rank').stack().reset_index().rename(columns={'level_1': 'metric', 0: 'value'})
pd_binary_metrics = pd.DataFrame(
[binary_metrics[rank].get_pretty_dict() for rank in binary_metrics.keys()]).set_index(
'rank').stack().reset_index().rename(columns={'level_1': 'metric', 0: 'value'})
pd_binary_metrics['metric'].replace(['fp', 'tp', 'fn', 'jaccard', 'precision', 'recall', 'f1'],
[c.FP, c.TP, c.FN, c.JACCARD, c.PRECISION, c.RECALL, c.F1_SCORE],
inplace=True)
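
The set_index('rank').stack().reset_index() chain used above (and for the Shannon and binary metrics) reshapes each wide per-rank table into the long rank/metric/value layout stored in pd_metrics. A standalone sketch of that reshape with made-up Shannon values, not part of the commit:

import pandas as pd

rows = [{'rank': 'genus', 'diversity': 2.1, 'equitability': 0.8}]
pd_long = pd.DataFrame(rows).set_index('rank').stack().reset_index().rename(
    columns={'level_1': 'metric', 0: 'value'})
#     rank        metric  value
# 0  genus     diversity    2.1
# 1  genus  equitability    0.8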
@@ -128,18 +133,38 @@ def reformat_pandas(sample_id, label, braycurtis, shannon, binary_metrics, l1nor
pd_ntaxa = pd_ntaxa[pd_ntaxa['rank'].isin(c.ALL_RANKS)]
pd_ntaxa['value'] = pd_ntaxa['value']

pd_formatted = pd.concat([pd_braycurtis, pd_shannon, pd_binary_metrics, pd_l1norm, pd_unifrac, pd_unifrac_cami, pd_sum, pd_ntaxa], ignore_index=True, sort=False)
pd_formatted = pd.concat(
[pd_braycurtis, pd_shannon, pd_binary_metrics, pd_l1norm, pd_unifrac, pd_unifrac_cami, pd_sum, pd_ntaxa],
ignore_index=True, sort=False)

if rename_as_unfiltered:
metrics_list = pd_formatted['metric'].unique().tolist()
pd_formatted['metric'].replace(metrics_list, [metric + c.UNFILTERED_SUF for metric in metrics_list], inplace=True)
pd_formatted['metric'].replace(metrics_list, [metric + c.UNFILTERED_SUF for metric in metrics_list],
inplace=True)

return pd_formatted


def get_confusion_df(gs_rank_to_taxid_to_percentage, rank_to_taxid_to_percentage):
def dict_to_pandas(pct_dict):
return pd.DataFrame.from_dict(pct_dict).stack().reset_index().dropna(subset=[0]).rename(
columns={'level_0': 'taxid', 'level_1': 'rank', 0: 'pct'})

df_gs = dict_to_pandas(gs_rank_to_taxid_to_percentage)
df_pred = dict_to_pandas(rank_to_taxid_to_percentage)
df_pred['classification'] = df_pred['taxid'].isin(df_gs['taxid'])
df_pred['classification'] = np.where(df_pred['classification'], 'TP', 'FP')
df_pred = pd.merge(df_pred, df_gs, on='taxid', sort=False, how='outer')
df_pred['classification'] = df_pred['classification'].fillna('FN')
df_pred['rank_x'] = np.where(df_pred['rank_x'].isna(), df_pred['rank_y'], df_pred['rank_x'])
df_pred = df_pred.rename(columns={'rank_x': 'rank', 'pct_x': 'pct', 'pct_y': 'pct_gs'}).drop('rank_y', axis=1)
return df_pred
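
The new get_confusion_df labels every predicted taxon as TP or FP by membership in the gold standard, then recovers FNs as the gold-standard taxa left unmatched by the outer merge. A minimal sketch of that logic on hypothetical single-rank profiles, not part of the commit:

import numpy as np
import pandas as pd

gs = {'genus': {'1': 60.0, '2': 40.0}}    # gold standard: rank -> taxid -> %
pred = {'genus': {'1': 55.0, '3': 45.0}}  # prediction misses taxid 2, invents taxid 3

def to_long(pct_dict):  # same reshape as dict_to_pandas above
    return pd.DataFrame.from_dict(pct_dict).stack().reset_index().rename(
        columns={'level_0': 'taxid', 'level_1': 'rank', 0: 'pct'})

df_gs, df_pred = to_long(gs), to_long(pred)
df_pred['classification'] = np.where(df_pred['taxid'].isin(df_gs['taxid']), 'TP', 'FP')
df = pd.merge(df_pred, df_gs, on='taxid', how='outer')    # gs-only taxa survive the merge
df['classification'] = df['classification'].fillna('FN')  # ...and become false negatives
# taxid 1 -> TP, taxid 3 -> FP, taxid 2 -> FN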


def evaluate_gs(gs_samples_list, filter_tail_percentage, branch_length_fun, normalized_unifrac,
gs_id_to_rank_to_taxid_to_percentage, gs_id_to_pf_profile, gs_id_to_pf_profile_cami, skip_gs=False):
pd_metrics = pd.DataFrame()
pd_confusion = pd.DataFrame()
for sample in gs_samples_list:
sample_id, sample_metadata, profile = sample
gs_id_to_rank_to_taxid_to_percentage[sample_id] = load_data.get_rank_to_taxid_to_percentage(profile)
@@ -153,13 +178,23 @@ def evaluate_gs(gs_samples_list, filter_tail_percentage, branch_length_fun, norm
gs_id_to_rank_to_taxid_to_percentage[sample_id],
gs_id_to_rank_to_taxid_to_percentage[sample_id],
branch_length_fun, normalized_unifrac)
pd_metrics = pd.concat([pd_metrics, reformat_pandas(sample_id, c.GS, braycurtis, shannon, binary_metrics, l1norm, unifrac, unifrac_cami, rank_to_sum, rank_to_ntaxa)], ignore_index=True)
pd_metrics = pd.concat([pd_metrics,
reformat_pandas(sample_id, c.GS, braycurtis, shannon, binary_metrics, l1norm,
unifrac, unifrac_cami, rank_to_sum, rank_to_ntaxa)],
ignore_index=True)

pd_sample_confusion = get_confusion_df(gs_id_to_rank_to_taxid_to_percentage[sample_id], gs_id_to_rank_to_taxid_to_percentage[sample_id])
pd_sample_confusion['sample'] = sample_id
pd_sample_confusion['tool'] = c.GS
pd_confusion = pd.concat([pd_confusion, pd_sample_confusion], ignore_index=True)

if filter_tail_percentage and not skip_gs:
metrics_list = pd_metrics['metric'].unique().tolist()
pd_metrics_copy = pd_metrics.copy()
pd_metrics_copy['metric'].replace(metrics_list, [metric + c.UNFILTERED_SUF for metric in metrics_list], inplace=True)
pd_metrics_copy['metric'].replace(metrics_list, [metric + c.UNFILTERED_SUF for metric in metrics_list],
inplace=True)
pd_metrics = pd.concat([pd_metrics, pd_metrics_copy], ignore_index=True)
return pd_metrics
return pd_metrics, pd_confusion


def evaluate_main(gs_samples_list, profiles_list_to_samples_list, labels, filter_tail_percentage, branch_length,
@@ -169,8 +204,9 @@ def evaluate_main(gs_samples_list, profiles_list_to_samples_list, labels, filter
gs_id_to_pf_profile_cami = {}
branch_length_fun = PF.Profile.get_branch_length_function(branch_length)

pd_metrics = evaluate_gs(gs_samples_list, filter_tail_percentage, branch_length_fun, normalized_unifrac,
gs_id_to_rank_to_taxid_to_percentage, gs_id_to_pf_profile, gs_id_to_pf_profile_cami, skip_gs)
pd_metrics, pd_confusion = evaluate_gs(gs_samples_list, filter_tail_percentage, branch_length_fun, normalized_unifrac,
gs_id_to_rank_to_taxid_to_percentage, gs_id_to_pf_profile, gs_id_to_pf_profile_cami,
skip_gs)

one_profile_assessed = False
for samples_list, label in zip(profiles_list_to_samples_list, labels):
@@ -183,7 +219,9 @@ def evaluate_main(gs_samples_list, profiles_list_to_samples_list, labels, filter
gs_pf_profile = gs_id_to_pf_profile[sample_id]
gs_pf_profile_cami = gs_id_to_pf_profile_cami[sample_id]
else:
logging.getLogger('opal').warning("Skipping assessment of {} for sample {}. Make sure the SampleID of the gold standard and the profile are identical.\n".format(label, sample_id))
logging.getLogger('opal').warning(
"Skipping assessment of {} for sample {}. Make sure the SampleID of the gold standard and the profile are identical.\n".format(
label, sample_id))
continue

rank_to_taxid_to_percentage = load_data.get_rank_to_taxid_to_percentage(profile)
@@ -195,24 +233,36 @@ def evaluate_main(gs_samples_list, profiles_list_to_samples_list, labels, filter
rank_to_taxid_to_percentage,
branch_length_fun, normalized_unifrac)
rename_as_unfiltered = True if filter_tail_percentage else False
pd_metrics = pd.concat([pd_metrics, reformat_pandas(sample_id, label, braycurtis, shannon, binary_metrics, l1norm, unifrac, unifrac_cami, rank_to_sum, rank_to_ntaxa, rename_as_unfiltered)], ignore_index=True)
pd_metrics = pd.concat([pd_metrics,
reformat_pandas(sample_id, label, braycurtis, shannon, binary_metrics, l1norm,
unifrac, unifrac_cami, rank_to_sum, rank_to_ntaxa,
rename_as_unfiltered)], ignore_index=True)
pd_sample_confusion = get_confusion_df(gs_rank_to_taxid_to_percentage, rank_to_taxid_to_percentage)
pd_sample_confusion['sample'] = sample_id
pd_sample_confusion['tool'] = label
pd_confusion = pd.concat([pd_confusion, pd_sample_confusion], ignore_index=True)

if filter_tail_percentage:
rank_to_taxid_to_percentage_filtered = \
load_data.get_rank_to_taxid_to_percentage_filtered(rank_to_taxid_to_percentage, filter_tail_percentage)
profile_filtered = [prediction for prediction in profile if prediction.taxid in rank_to_taxid_to_percentage_filtered[prediction.rank]]
load_data.get_rank_to_taxid_to_percentage_filtered(rank_to_taxid_to_percentage,
filter_tail_percentage)
profile_filtered = [prediction for prediction in profile if
prediction.taxid in rank_to_taxid_to_percentage_filtered[prediction.rank]]
unifrac, unifrac_cami, shannon, l1norm, binary_metrics, braycurtis, rank_to_sum, rank_to_ntaxa = \
compute_metrics(sample_metadata, profile_filtered, gs_pf_profile,
profile_filtered, gs_pf_profile_cami,
gs_rank_to_taxid_to_percentage,
rank_to_taxid_to_percentage_filtered,
branch_length_fun, normalized_unifrac)
pd_metrics = pd.concat([pd_metrics, reformat_pandas(sample_id, label, braycurtis, shannon, binary_metrics, l1norm, unifrac, unifrac_cami, rank_to_sum, rank_to_ntaxa)], ignore_index=True)
pd_metrics = pd.concat([pd_metrics,
reformat_pandas(sample_id, label, braycurtis, shannon, binary_metrics, l1norm,
unifrac, unifrac_cami, rank_to_sum, rank_to_ntaxa)],
ignore_index=True)

one_profile_assessed = True

if not one_profile_assessed:
logging.getLogger('opal').critical("No profile could be evaluated.")
exit(1)

return pd_metrics
return pd_metrics, pd_confusion
4 changes: 2 additions & 2 deletions cami_opal/utils/ProfilingToolsCAMI.py
@@ -44,7 +44,7 @@ def parse_file(self):
ancestor = tax_path[-2]
_data[tax_id]["branch_length"] = 1
i = -3
while ancestor is "" or ancestor == tax_id: # if it's a blank or repeated, go up until finding ancestor
while ancestor == "" or ancestor == tax_id: # if it's a blank or repeated, go up until finding ancestor
ancestor = tax_path[i]
_data[tax_id]["branch_length"] += 1
i -= 1
@@ -115,7 +115,7 @@ def _delete_missing(self):
ancestor = tax_path[-2]
_data[key]["branch_length"] = 1
i = -3
while ancestor is "" or ancestor == key: # if it's a blank or repeated, go up until finding ancestor
while ancestor == "" or ancestor == key: # if it's a blank or repeated, go up until finding ancestor
if i < -len(tax_path): # Path is all the way full with bad tax_ids, connect to root
_data[key]["branch_length"] += 1
ancestor = "-1"
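
Both while-loop edits in this file are the same one-character bug fix: "is" tests object identity, not string contents, so ancestor is "" only held when CPython happened to return the interned empty string, and comparing against a literal with "is" raises a SyntaxWarning on Python 3.8+. A hypothetical illustration, not from the repository:

s = 'tax_' + str(42)   # built at run time, not necessarily interned
s == 'tax_42'          # True: compares contents, which is what the loop needs
s is 'tax_42'          # identity check; implementation-dependent, often False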
27 changes: 24 additions & 3 deletions cami_opal/utils/load_data.py
@@ -1,7 +1,10 @@
#!/usr/bin/env python

import os
import logging
import mimetypes
import gzip
import io
import tarfile
import zipfile
from collections import defaultdict


@@ -94,6 +97,24 @@ def normalize_samples(samples_list):
prediction.percentage = (prediction.percentage / sum_per_rank[prediction.rank]) * 100.0


def open_generic(file):
file_type, file_encoding = mimetypes.guess_type(file)

if file_encoding == 'gzip':
if file_type == 'application/x-tar': # .tar.gz
tar = tarfile.open(file, 'r:gz')
f = tar.extractfile(tar.getmembers()[0])
return io.TextIOWrapper(f)
else: # .gz
return gzip.open(file, 'rt')
if file_type == 'application/zip': # .zip
f = zipfile.ZipFile(file, 'r')
f = f.open(f.namelist()[0])
return io.TextIOWrapper(f)
else:
return open(file, 'rt')
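
open_generic dispatches on mimetypes.guess_type, which splits a filename into a MIME type and an encoding. For the formats handled here, a stock Python installation reports the following (shown for illustration, not part of the commit); note that only the first member of a tar or zip archive is opened:

import mimetypes
mimetypes.guess_type('profile.tsv')     # ('text/tab-separated-values', None)   -> plain open()
mimetypes.guess_type('profile.tsv.gz')  # ('text/tab-separated-values', 'gzip') -> gzip.open()
mimetypes.guess_type('profile.tar.gz')  # ('application/x-tar', 'gzip')         -> first tar member
mimetypes.guess_type('profile.zip')     # ('application/zip', None)             -> first zip member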


def open_profile_from_tsv(file_path, normalize):
header = {}
column_name_to_index = {}
@@ -103,7 +124,7 @@ def open_profile_from_tsv(file_path, normalize):
reading_data = False
got_column_indices = False

with open(file_path) as read_handler:
with open_generic(file_path) as read_handler:
for line in read_handler:
if len(line.strip()) == 0 or line.startswith("#"):
continue
4 changes: 3 additions & 1 deletion opal.py
@@ -192,19 +192,21 @@ def main():
logger.info('done')

logger.info('Computing metrics...')
pd_metrics = evaluate.evaluate_main(gs_samples_list,
pd_metrics, pd_confusion = evaluate.evaluate_main(gs_samples_list,
profiles_list_to_samples_list,
labels,
args.filter,
args.branch_length_function,
args.normalized_unifrac)

time_list, memory_list = get_time_memory(args.time, args.memory, args.profiles_files)
if time_list or memory_list:
pd_metrics = concat_time_memory(labels, time_list, memory_list, pd_metrics)
logger.info('done')

logger.info('Saving computed metrics...')
pd_metrics[['tool', 'rank', 'metric', 'sample', 'value']].fillna('na').to_csv(os.path.join(output_dir, 'results.tsv'), sep='\t', index=False)
pd_confusion.to_csv(os.path.join(output_dir, 'confusion.tsv'), sep='\t', index=False)
print_by_tool(output_dir, pd_metrics)
print_by_rank(output_dir, labels, pd_metrics)
logger.info('done')
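
The new confusion.tsv holds one row per taxon with the columns built in get_confusion_df (taxid, rank, pct, classification, pct_gs) plus the sample and tool columns added per profile. A hypothetical way to consume it, assuming OPAL wrote its results to a directory named output_dir:

import pandas as pd

pd_confusion = pd.read_csv('output_dir/confusion.tsv', sep='\t')
# e.g. count false negatives per tool and rank
print(pd_confusion[pd_confusion['classification'] == 'FN'].groupby(['tool', 'rank']).size())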
