Skip to content

Commit

Permalink
Merge pull request #92 from 3dem/jamaliki-patch-1
Browse files Browse the repository at this point in the history
Jamaliki patch 1
  • Loading branch information
jamaliki authored Nov 29, 2023
2 parents 8e0a567 + 67a9b96 commit 9271df3
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 30 deletions.
2 changes: 1 addition & 1 deletion model_angelo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
"""


__version__ = "1.0.11"
__version__ = "1.0.12"
46 changes: 38 additions & 8 deletions model_angelo/apps/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,13 +180,13 @@ def main(parsed_args):

# Try to open FASTA --------------------------------------------------------------------------------------
from model_angelo.utils.fasta_utils import read_fasta

# if not is_valid_fasta_ending(parsed_args.protein_fasta):
# raise RuntimeError(
# f"File {parsed_args.protein_fasta} is not a supported file format."
# )

new_protein_fasta_path = write_fasta_only_aa(parsed_args.protein_fasta)
try:
read_fasta(new_protein_fasta_path)
except Exception as e:
raise RuntimeError(
f"File {parsed_args.protein_fasta} is not a valid FASTA file."
) from e

# Run C-alpha inference ----------------------------------------------------------------------------------------
print("--------------------- Initial C-alpha prediction ---------------------")
Expand Down Expand Up @@ -256,28 +256,58 @@ def main(parsed_args):

pruned_file_src = gnn_output.replace("output.cif", "output_fixed_aa_pruned.cif")
raw_file_src = gnn_output.replace("output.cif", "output_fixed_aa.cif")

name = os.path.basename(parsed_args.output_dir)
pruned_file_dst = os.path.join(parsed_args.output_dir, f"{name}.cif")
raw_file_dst = os.path.join(parsed_args.output_dir, f"{name}_raw.cif")

os.replace(pruned_file_src, pruned_file_dst)
os.replace(raw_file_src, raw_file_dst)

# Entropy files
os.makedirs(
os.path.join(parsed_args.output_dir, "entropy_scores"),
exist_ok=True,
)
pruned_es_file_src = gnn_output.replace("output.cif", "output_fixed_aa_pruned_entropy_score.cif")
raw_es_file_src = gnn_output.replace("output.cif", "output_fixed_aa_entropy_score.cif")

pruned_es_file_dst = os.path.join(parsed_args.output_dir, "entropy_scores", f"{name}.cif")
raw_es_file_dst = os.path.join(parsed_args.output_dir, "entropy_scores", f"{name}_raw.cif")

os.replace(pruned_es_file_src, pruned_es_file_dst)
os.replace(raw_es_file_src, raw_es_file_dst)

if not parsed_args.keep_intermediate_results:
for directory in os.listdir(parsed_args.output_dir):
if directory.startswith("gnn_output_round_") or directory == "see_alpha_output":
shutil.rmtree(os.path.join(parsed_args.output_dir, directory))

print("-" * 70)
print("ModelAngelo build has been completed successfully!")
print("-" * 70)
print(f"You can find your output mmCIF file here: {pruned_file_dst}")
print("-" * 70)
print(
f"The raw output without pruning might be useful to show "
f"The raw output without pruning might be useful to show \n"
f"some parts of the map that may be modelled, \n"
f"but could not be automatically modelled. \n"
f"However, the amino acid classifications are generally \n"
f"not going to be correct. \n"
f"You can find that here: {raw_file_dst}"
)
print("-" * 70)
print(
f"(Experimental) We now have CIF files with entropy scores \n"
f"for each residue. These are useful for identifying regions \n"
f"of the map that have lower confidence in the residue type \n"
f"(i.e. amino-acid or nucleotide base) identity. \n"
f"These files are named the same as the pruned and raw files, \n"
f"however the bfactor column has the new entropy score instead \n"
f"of the ModelAngelo predicted confidence, which is backbone-based. \n"
f"These files are in: {os.path.join(parsed_args.output_dir, 'entropy_scores')}"
)
print("-" * 70)
print("Enjoy!")

write_relion_job_exit_status(
Expand Down
22 changes: 22 additions & 0 deletions model_angelo/apps/build_no_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,36 @@ def main(parsed_args):

os.replace(raw_file_src, raw_file_dst)

# Entropy file
raw_es_file_src = gnn_output.replace("output.cif", "output_entropy_score.cif")
raw_es_file_dst = os.path.join(parsed_args.output_dir, f"{name}_entropy_score.cif")

os.replace(raw_es_file_src, raw_es_file_dst)

shutil.rmtree(hmm_profiles_dst, ignore_errors=True)
os.replace(hmm_profiles_src, hmm_profiles_dst)

if not parsed_args.keep_intermediate_results:
for directory in os.listdir(parsed_args.output_dir):
if directory.startswith("gnn_output_round_") or directory == "see_alpha_output":
shutil.rmtree(os.path.join(parsed_args.output_dir, directory))

print("-" * 70)
print("ModelAngelo build_no_seq has been completed successfully!")
print("-" * 70)
print(f"You can find your output mmCIF file here: {raw_file_dst}")
print("-" * 70)
# Tell the user where the experimental entropy-score CIF was written.
# Adjacent string literals are joined into one string, so every line of the
# message except the last must end with "\n"; otherwise the final sentence
# is glued onto "backbone-based." with no separator.
print(
    f"(Experimental) We now have a CIF file output with entropy scores \n"
    f"for each residue. These are useful for identifying regions \n"
    f"of the map that have lower confidence in the residue type \n"
    f"(i.e. amino-acid or nucleotide base) identity. \n"
    f"This is the same model as the one above, but the bfactor column \n"
    f"has the new entropy score instead of the ModelAngelo predicted confidence,\n"
    f"which is backbone-based. \n"
    f"This file is here: {raw_es_file_dst}"
)
print("-" * 70)
print(
f"The HMM profiles are available in the directory: {hmm_profiles_dst}\n"
f"They are named according to the chains found in {raw_file_dst}\n"
Expand Down
61 changes: 51 additions & 10 deletions model_angelo/gnn/flood_fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,13 @@ def chains_to_atoms(
.numpy()
)

(chain_all_atoms, chain_atom_mask, chain_bfactors, chain_aa_probs,) = (
[],
[],
[],
[],
)
(
chain_all_atoms,
chain_atom_mask,
chain_bfactors,
chain_entropy_scores,
chain_aa_probs,
) = ([], [], [], [], [])
# Everything below is in the order of chains
for chain_id in range(len(chains)):
chain_id_backbone_affine = backbone_affine[chains[chain_id]]
Expand All @@ -107,11 +108,15 @@ def chains_to_atoms(
)
* 100
)
chain_entropy_scores.append(
final_results["entropy_score"][existence_mask][chains[chain_id]] * 100
)
chain_aa_probs.append(aa_probs[chains[chain_id]])
return (
chain_all_atoms,
chain_atom_mask,
chain_bfactors,
chain_entropy_scores,
chain_aa_probs,
)

Expand Down Expand Up @@ -182,6 +187,13 @@ def final_results_to_cif(
backbone_affine = backbone_affine[existence_mask]
final_results["aa_logits"][prot_mask][..., num_prot:] = -100
final_results["aa_logits"][~prot_mask][..., :num_prot] = -100
# Calculate entropy score
exp_logits = np.exp(final_results["aa_logits"])
logit_probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
aa_entropy = -np.sum(final_results["aa_logits"] * logit_probs, axis=-1)
final_results["entropy_score"] = local_confidence_score_sigmoid(
- aa_entropy, best_value=5.0, worst_value=2.0, mid_point=3.0,
)
final_results["aa_logits"] /= temperature

torsion_angles = select_torsion_angles(
Expand All @@ -196,6 +208,7 @@ def final_results_to_cif(
)
* 100
)
entropy_scores = final_results["entropy_score"][existence_mask] * 100

if refine:
protein.atomc_positions = all_atoms
Expand Down Expand Up @@ -239,6 +252,13 @@ def final_results_to_cif(
cif_path,
bfactors=[bfactors[c] for c in chains],
)
chain_atom14_to_cif(
[aatype[c] for c in chains],
[all_atoms[c] for c in chains],
[atom_mask[c] for c in chains],
cif_path.replace(".cif", "_entropy_score.cif"),
bfactors=[entropy_scores[c] for c in chains],
)

chain_aa_logits = [final_results["aa_logits"][existence_mask][c] for c in chains]
chain_prot_mask = [prot_mask[c] for c in chains]
Expand Down Expand Up @@ -307,6 +327,7 @@ def final_results_to_cif(
chain_all_atoms,
chain_atom_mask,
chain_bfactors,
chain_entropy_scores,
chain_aa_probs,
) = chains_to_atoms(
final_results, fix_chains_output, backbone_affine, existence_mask
Expand All @@ -321,6 +342,13 @@ def final_results_to_cif(
cif_path.replace(".cif", "_fixed_aa.cif"),
bfactors=chain_bfactors,
)
chain_atom14_to_cif(
fix_chains_output.best_match_output.new_sequences,
chain_all_atoms,
chain_atom_mask,
cif_path.replace(".cif", "_fixed_aa_entropy_score.cif"),
bfactors=chain_entropy_scores,
)

write_chain_report(
cif_path.replace(".cif", "_chain_report.csv"),
Expand All @@ -343,6 +371,7 @@ def final_results_to_cif(
chain_all_atoms,
chain_atom_mask,
chain_bfactors,
chain_entropy_scores,
chain_aa_probs,
) = chains_to_atoms(
final_results, fix_chains_output, backbone_affine, existence_mask
Expand All @@ -359,6 +388,18 @@ def final_results_to_cif(
if aggressive_pruning
else None,
)

chain_atom14_to_cif(
fix_chains_output.best_match_output.new_sequences,
chain_all_atoms,
chain_atom_mask,
cif_path.replace(".cif", "_fixed_aa_pruned_entropy_score.cif"),
bfactors=chain_entropy_scores,
sequence_idxs=fix_chains_output.best_match_output.sequence_idxs,
res_idxs=fix_chains_output.best_match_output.residue_idxs
if aggressive_pruning
else None,
)

write_chain_probabilities(
cif_path.replace(".cif", "_aa_probabilities.aap"),
Expand Down Expand Up @@ -434,8 +475,8 @@ def flood_fill(

b_factors_copy[idx] = -1

og_chain_starts = np.array([c[0] for c in chains])
og_chain_ends = np.array([c[-1] for c in chains])
og_chain_starts = np.array([c[0] for c in chains], dtype=np.int32)
og_chain_ends = np.array([c[-1] for c in chains], dtype=np.int32)

chain_starts = og_chain_starts.copy()
chain_ends = og_chain_ends.copy()
Expand Down Expand Up @@ -488,8 +529,8 @@ def flood_fill(
tmp_chains.append(new_chain)
chains = tmp_chains

chain_starts = np.array([c[0] for c in chains])
chain_ends = np.array([c[-1] for c in chains])
chain_starts = np.array([c[0] for c in chains], dtype=np.int32)
chain_ends = np.array([c[-1] for c in chains], dtype=np.int32)

spent_starts.add(chain_start_match)
spent_ends.add(chain_end_match)
Expand Down
6 changes: 0 additions & 6 deletions model_angelo/gnn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,12 +175,6 @@ def infer(args):
final_results = get_final_nn_results(collated_results)
output_path = os.path.join(args.output_dir, "output.cif")

# For debugging eyes only
pickle_dump(final_results, os.path.join(args.output_dir, "final_results.pkl"))
dump_protein_to_prot(protein, os.path.join(args.output_dir, "protein.prot"))
pickle_dump(rna_sequences, os.path.join(args.output_dir, "rna_sequences.pkl"))
pickle_dump(dna_sequences, os.path.join(args.output_dir, "dna_sequences.pkl"))

final_results_to_cif(
final_results,
protein=protein,
Expand Down
8 changes: 4 additions & 4 deletions model_angelo/utils/hmm_sequence_align.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,8 @@ def sort_chains(
unique_seqs = np.unique(match_to_sequence.sequence_idxs)

og_chain_lens = np.array([len(c) for c in chains])
og_chain_starts = np.array([c[0] for c in chains])
og_chain_ends = np.array([c[-1] for c in chains])
og_chain_starts = np.array([c[0] for c in chains], dtype=np.int32)
og_chain_ends = np.array([c[-1] for c in chains], dtype=np.int32)

chain_starts = og_chain_starts.copy()
chain_ends = og_chain_ends.copy()
Expand Down Expand Up @@ -452,8 +452,8 @@ def sort_chains(
tmp_chain_ids.append(new_chain_id)
new_chain_ids = tmp_chain_ids

chain_starts = np.array([c[0] for c in chains])
chain_ends = np.array([c[-1] for c in chains])
chain_starts = np.array([c[0] for c in chains], dtype=np.int32)
chain_ends = np.array([c[-1] for c in chains], dtype=np.int32)

spent_starts.add(chain_start_match)
spent_ends.add(chain_end_match)
Expand Down
2 changes: 1 addition & 1 deletion model_angelo/utils/save_pdb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

class ModelAngeloMMCIFIO(MMCIFIO):
def _save_dict(self, out_file):
label_seq_id = deepcopy(self.dic["_atom_site.label_seq_id"])
label_seq_id = deepcopy(self.dic["_atom_site.auth_seq_id"])
auth_seq_id = deepcopy(self.dic["_atom_site.auth_seq_id"])
self.dic["_atom_site.label_seq_id"] = label_seq_id
self.dic["_atom_site.auth_seq_id"] = auth_seq_id
Expand Down

0 comments on commit 9271df3

Please sign in to comment.