Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix residues #82

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 134 additions & 37 deletions design_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,90 @@
from aposteriori.data_prep.create_frame_data_set import DatasetMetadata


def customize_fixed_residues(
    pdb_to_sequence, pdb_to_real_sequence, chains_to_customize, res_to_fix
):
    """
    Fix selected residues of TIMED predictions back to the input structure.

    When the user gives the --res_to_fix and --chains_to_customize commands,
    the listed residues in each specified chain are reset to the wild-type
    (input structure) residues; the rest of the protein keeps the TIMED
    prediction. Modifies ``pdb_to_sequence`` in place and returns ``None``.

    Parameters
    ----------
    pdb_to_sequence: dict
        Sequence as predicted by TIMED, keyed by a pdb+chain identifier.
    pdb_to_real_sequence: dict
        WT sequence, keyed the same way.
    chains_to_customize: tuple
        User specified chains where fixing of residues will be applied.
    res_to_fix: tuple
        One tuple of 1-based residue numbers per chain in
        ``chains_to_customize``. This assumes that when the chain changes the
        residue numbering also restarts from 1.
    """
    # Pair each chain with its tuple of residue numbers to fix.
    mapping = {chain: res_to_fix[i] for i, chain in enumerate(chains_to_customize)}
    for chain, residue_numbers in mapping.items():
        for key in pdb_to_sequence:
            if not (key.endswith(chain) and key in pdb_to_real_sequence):
                continue
            predicted = list(pdb_to_sequence[key])
            real_sequence = pdb_to_real_sequence[key]
            # Clamp to the shorter of the two sequences so a mismatched WT
            # entry cannot cause an IndexError.
            limit = min(len(predicted), len(real_sequence))
            for num in residue_numbers:
                # Residue numbers are 1-based; out-of-range numbers are
                # silently skipped (same as the original bounds check).
                if 0 <= num - 1 < limit:
                    predicted[num - 1] = real_sequence[num - 1]
            # Rebuild the string once per key instead of once per residue.
            pdb_to_sequence[key] = "".join(predicted)


def customize_predicted_residues(
    pdb_to_sequence, pdb_to_real_sequence, chains_to_customize, res_to_predict
):
    """
    Keep TIMED predictions only at selected residues; revert the rest to WT.

    When the user gives the --res_to_predict and --chains_to_customize
    commands, the listed residues in each specified chain keep their TIMED
    prediction and every other residue of the protein (including whole
    unspecified chains) is converted back to the wild-type sequence.
    Modifies ``pdb_to_sequence`` in place and returns ``None``.

    Parameters
    ----------
    pdb_to_sequence: dict
        Sequence as predicted by TIMED, keyed by a pdb+chain identifier.
    pdb_to_real_sequence: dict
        WT sequence, keyed the same way.
    chains_to_customize: tuple
        User specified chains where predictions will be applied. Unspecified
        chains will be converted back to WT.
    res_to_predict: tuple
        One tuple of 1-based residue numbers per chain in
        ``chains_to_customize``. This assumes that when the chain changes the
        residue numbering also restarts from 1.
    """
    mapping = {chain: res_to_predict[i] for i, chain in enumerate(chains_to_customize)}
    for key in pdb_to_sequence:
        # Keys with no WT counterpart cannot be reverted, so leave them alone.
        if key not in pdb_to_real_sequence:
            continue
        # Resolve which customized chain (if any) this entry belongs to in a
        # single pass per key. Iterating per chain instead (as the original
        # did) reverted every non-matching key to WT, clobbering the kept
        # predictions of other chains that were also in chains_to_customize.
        chain = next((c for c in mapping if key.endswith(c)), None)
        residues_to_keep = set(mapping[chain]) if chain is not None else set()
        predicted = list(pdb_to_sequence[key])
        real_sequence = pdb_to_real_sequence[key]
        # 1-based residue numbers. Include the final residue (the original
        # range(1, len(sequence)) stopped one short of it) and clamp to the
        # shorter sequence to avoid an IndexError on length mismatch.
        for num in range(1, min(len(predicted), len(real_sequence)) + 1):
            if num not in residues_to_keep:
                predicted[num - 1] = real_sequence[num - 1]
        pdb_to_sequence[key] = "".join(predicted)


def rm_tree(pth: Path):
# Removes all files in a directory and the directory. From https://stackoverflow.com/questions/50186904/pathlib-recursively-remove-directory
pth = Path(pth)
Expand Down Expand Up @@ -207,17 +291,10 @@ def load_datasetmap(path_to_datasetmap: Path, is_old: bool = False) -> np.ndarra
path_to_datasetmap.suffix == ".txt"
), f"Expected Path {path_to_datasetmap} to be a .txt file but got {path_to_datasetmap.suffix}."
if is_old:
dataset_map = np.genfromtxt(
path_to_datasetmap,
delimiter=",",
dtype=str,
)
dataset_map = np.genfromtxt(path_to_datasetmap, delimiter=",", dtype=str)
else:
dataset_map = np.genfromtxt(
path_to_datasetmap,
delimiter=" ",
dtype=str,
skip_header=3,
path_to_datasetmap, delimiter=" ", dtype=str, skip_header=3
)
dataset_map = np.asarray(dataset_map)
# If list only contains 1 pdb, it fails to create a list of list [pdb_code, count]
Expand Down Expand Up @@ -485,8 +562,7 @@ def compress_rotamer_predictions_to_20(prediction_matrix: np.ndarray) -> np.ndar


def load_batch(
dataset_path: Path,
data_point_batch: t.List[t.Tuple],
dataset_path: Path, data_point_batch: t.List[t.Tuple]
) -> (np.ndarray, np.ndarray):
"""
Load batch from a dataset map.
Expand Down Expand Up @@ -531,9 +607,7 @@ def load_batch(


def convert_dataset_map_for_srb(
flat_dataset_map: list,
model_name: str,
path_to_output: Path = Path.cwd(),
flat_dataset_map: list, model_name: str, path_to_output: Path = Path.cwd()
):
"""
Converts datasetmap for compatibility with PDBench / Sequence recovery benchmark
Expand Down Expand Up @@ -593,7 +667,7 @@ def save_consensus_probs(


def save_dict_to_fasta(
pdb_to_sequence: dict, model_name: str, path_to_output: Path = Path.cwd(),
pdb_to_sequence: dict, model_name: str, path_to_output: Path = Path.cwd()
):
"""
Saves a dictionary of protein sequences to a fasta file.
Expand All @@ -615,6 +689,9 @@ def save_dict_to_fasta(

def extract_sequence_from_pred_matrix(
flat_dataset_map: t.List[t.Tuple],
res_to_predict: t.Tuple[t.Tuple],
res_to_fix: t.Tuple[t.Tuple],
chains_to_customize: t.Tuple[t.Tuple],
prediction_matrix: np.ndarray,
rotamers_categories: t.List[str],
old_datasetmap: bool = False,
Expand Down Expand Up @@ -664,21 +741,17 @@ def extract_sequence_from_pred_matrix(
chain = None
# Add support for different dataset maps:
if old_datasetmap:
pdb, chain, _, res = flat_dataset_map[i]
pdb_chain, chain, _, res = flat_dataset_map[i]
count = 1
else:
pdb, count = flat_dataset_map[i]
pdb_chain, count = flat_dataset_map[i]
count = int(count)
# TODO: this line is not elegant in the way it handles 4 letter codes as PDB codes. It might lead to problems later on
if len(pdb) == 4:
pdbchain = pdb[:4] + "A"
else:
pdbchain = pdb
pdb_chain += chain
# Prepare the dictionaries:
if pdbchain not in pdb_to_sequence:
pdb_to_sequence[pdbchain] = ""
pdb_to_real_sequence[pdbchain] = ""
pdb_to_probability[pdbchain] = []
if pdb_chain not in pdb_to_sequence:
pdb_to_sequence[pdb_chain] = ""
pdb_to_real_sequence[pdb_chain] = ""
pdb_to_probability[pdb_chain] = []
# Loop through map:
for n in range(previous_count, previous_count + count):
if old_datasetmap:
Expand All @@ -688,33 +761,57 @@ def extract_sequence_from_pred_matrix(

pred = list(prediction_matrix[idx])
curr_res = res_dic[max_idx[idx]]
pdb_to_probability[pdbchain].append(pred)
pdb_to_sequence[pdbchain] += curr_res
pdb_to_probability[pdb_chain].append(pred)
pdb_to_sequence[pdb_chain] += curr_res
if old_datasetmap:
pdb_to_real_sequence[pdbchain] += res_to_r_dic[res]
pdb_to_real_sequence[pdb_chain] += res_to_r_dic[res]
if not old_datasetmap:
previous_count += count

if chains_to_customize:
if res_to_fix and not res_to_predict:
customize_fixed_residues(
pdb_to_sequence, pdb_to_real_sequence, chains_to_customize, res_to_fix
)
elif res_to_predict and not res_to_fix:
customize_predicted_residues(
pdb_to_sequence,
pdb_to_real_sequence,
chains_to_customize,
res_to_predict,
)
elif not res_to_fix and not res_to_predict:
warnings.warn(
"No prompt was given to fix or predict residues. TIMED will make predictions for all the residues."
)
else:
warnings.warn(
"Both --res_to_fix and --res_to_predict flags were given. TIMED will make predictions for all the residues."
)

if is_consensus:
last_pdb = ""
# Sum up probabilities:
for pdb in pdb_to_sequence.keys():
curr_pdb = pdb.split("_")[0]
for pdb_chain in pdb_to_sequence.keys():
curr_pdb = pdb_chain.split("_")[0]
if last_pdb != curr_pdb:
pdb_to_consensus_prob[curr_pdb] = np.array(pdb_to_probability[pdb])
pdb_to_consensus_prob[curr_pdb] = np.array(
pdb_to_probability[pdb_chain]
)
last_pdb = curr_pdb
else:
pdb_to_consensus_prob[curr_pdb] = (
pdb_to_consensus_prob[curr_pdb] + np.array(pdb_to_probability[pdb])
pdb_to_consensus_prob[curr_pdb]
+ np.array(pdb_to_probability[pdb_chain])
) / 2
# Extract sequences from consensus probabilities:
for pdb in pdb_to_consensus_prob.keys():
pdb_to_consensus[pdb] = ""
curr_prob = pdb_to_consensus_prob[pdb]
for pdb_chain in pdb_to_consensus_prob.keys():
pdb_to_consensus[pdb_chain] = ""
curr_prob = pdb_to_consensus_prob[pdb_chain]
max_idx = np.argmax(curr_prob, axis=1)
for m in max_idx:
curr_res = res_dic[m]
pdb_to_consensus[pdb] += curr_res
pdb_to_consensus[pdb_chain] += curr_res

return (
pdb_to_sequence,
Expand Down
Loading