Skip to content

Commit

Permalink
update branch
Browse files: browse the repository at this point in the history
  • Loading branch information
fernandomeyer committed Mar 1, 2021
1 parent 460e115 commit ac98e60
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 21 deletions.
31 changes: 11 additions & 20 deletions opal.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,13 @@ def print_by_tool(output_dir, pd_metrics):
table.fillna('na').to_csv(os.path.join(output_dir, "by_tool", toolname + ".tsv"), sep='\t')


def compute_metrics(sample_metadata, profile, gs_pf_profile, gs_rank_to_taxid_to_percentage, rank_to_taxid_to_percentage,
normalize, branch_length_fun):
def compute_metrics(sample_metadata, profile, gs_pf_profile, gs_rank_to_taxid_to_percentage, rank_to_taxid_to_percentage):
# Unifrac
if isinstance(profile, PF.Profile):
pf_profile = profile
else:
pf_profile = PF.Profile(sample_metadata=sample_metadata, profile=profile, branch_length_fun=branch_length_fun)
unifrac = uf.compute_unifrac(gs_pf_profile, pf_profile, normalize)
pf_profile = PF.Profile(sample_metadata=sample_metadata, profile=profile)
unifrac = uf.compute_unifrac(gs_pf_profile, pf_profile)

# Shannon
shannon = sh.compute_shannon_index(rank_to_taxid_to_percentage)
Expand Down Expand Up @@ -149,22 +148,20 @@ def load_profiles(gold_standard_file, profiles_files, normalize):
return sample_ids_list, gs_samples_list, profiles_list_to_samples_list


def evaluate(gs_samples_list, profiles_list_to_samples_list, labels, normalize, filter_tail_percentage, branch_length_fun):
def evaluate(gs_samples_list, profiles_list_to_samples_list, labels, filter_tail_percentage):
gs_id_to_rank_to_taxid_to_percentage = {}
gs_id_to_pf_profile = {}
pd_metrics = pd.DataFrame()

for sample in gs_samples_list:
sample_id, sample_metadata, profile = sample
gs_id_to_rank_to_taxid_to_percentage[sample_id] = load_data.get_rank_to_taxid_to_percentage(profile)
gs_id_to_pf_profile[sample_id] = PF.Profile(sample_metadata=sample_metadata, profile=profile, branch_length_fun=branch_length_fun)
gs_id_to_pf_profile[sample_id] = PF.Profile(sample_metadata=sample_metadata, profile=profile)
unifrac, shannon, l1norm, binary_metrics, braycurtis = compute_metrics(sample_metadata,
gs_id_to_pf_profile[sample_id],
gs_id_to_pf_profile[sample_id],
gs_id_to_rank_to_taxid_to_percentage[sample_id],
gs_id_to_rank_to_taxid_to_percentage[sample_id],
normalize,
branch_length_fun)
gs_id_to_rank_to_taxid_to_percentage[sample_id])
pd_metrics = pd.concat([pd_metrics, reformat_pandas(sample_id, c.GS, braycurtis, shannon, binary_metrics, l1norm, unifrac)], ignore_index=True)
if filter_tail_percentage:
metrics_list = pd_metrics['metric'].unique().tolist()
Expand All @@ -189,20 +186,17 @@ def evaluate(gs_samples_list, profiles_list_to_samples_list, labels, normalize,

unifrac, shannon, l1norm, binary_metrics, braycurtis = compute_metrics(sample_metadata, profile, gs_pf_profile,
gs_rank_to_taxid_to_percentage,
rank_to_taxid_to_percentage,
normalize,
branch_length_fun)
rank_to_taxid_to_percentage)
rename_as_unfiltered = True if filter_tail_percentage else False
pd_metrics = pd.concat([pd_metrics, reformat_pandas(sample_id, label, braycurtis, shannon, binary_metrics, l1norm, unifrac, rename_as_unfiltered)], ignore_index=True)

if filter_tail_percentage:
rank_to_taxid_to_percentage_filtered = \
load_data.get_rank_to_taxid_to_percentage_filtered(rank_to_taxid_to_percentage, filter_tail_percentage)
unifrac, shannon, l1norm, binary_metrics, braycurtis = compute_metrics(sample_metadata, profile, gs_pf_profile,
profile_filtered = [prediction for prediction in profile if prediction.taxid in rank_to_taxid_to_percentage_filtered[prediction.rank]]
unifrac, shannon, l1norm, binary_metrics, braycurtis = compute_metrics(sample_metadata, profile_filtered, gs_pf_profile,
gs_rank_to_taxid_to_percentage,
rank_to_taxid_to_percentage_filtered,
normalize,
branch_length_fun)
rank_to_taxid_to_percentage_filtered)
pd_metrics = pd.concat([pd_metrics, reformat_pandas(sample_id, label, braycurtis, shannon, binary_metrics, l1norm, unifrac)], ignore_index=True)

one_profile_assessed = True
Expand Down Expand Up @@ -319,7 +313,6 @@ def main():
group2 = parser.add_argument_group('optional arguments')
group2.add_argument('-n', '--normalize', help='Normalize samples', action='store_true')
group2.add_argument('-f', '--filter', help='Filter out the predictions with the smallest relative abundances summing up to [FILTER]%% within a rank (affects only precision, default: 0)', type=float)
group2.add_argument('-b', '--branch_length_function', help='UniFrac tree branch length function (default: "lambda x: 1/x", x=tree depth)', required=False, default='lambda x: 1/x')
group2.add_argument('-p', '--plot_abundances', help='Plot abundances in the gold standard (can take some minutes)', action='store_true')
group2.add_argument('-l', '--labels', help='Comma-separated profiles names', required=False)
group2.add_argument('-t', '--time', help='Comma-separated runtimes in hours', required=False)
Expand Down Expand Up @@ -356,9 +349,7 @@ def main():
pd_metrics = evaluate(gs_samples_list,
profiles_list_to_samples_list,
labels,
args.normalize,
args.filter,
uf.get_branch_length_function(args.branch_length_function))
args.filter)
time_list, memory_list = get_time_memory(args.time, args.memory, args.profiles_files)
if time_list or memory_list:
pd_metrics = concat_time_memory(labels, time_list, memory_list, pd_metrics)
Expand Down
2 changes: 1 addition & 1 deletion version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.0.9'
__version__ = '1.0.10'

0 comments on commit ac98e60

Please sign in to comment.