Skip to content

Commit

Permalink
Add full MSA pairing support (#275)
Browse files Browse the repository at this point in the history
* Update data_pipeline.py

* Update msa_parsing.py

* Update inputs.py
  • Loading branch information
amorehead authored Sep 21, 2024
1 parent 623b4d0 commit 2619f4b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 37 deletions.
7 changes: 6 additions & 1 deletion alphafold3_pytorch/data/data_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def make_msa_features(
msas: Dict[str, msa_parsing.Msa],
chain_id_to_residue: Dict[str, Dict[str, List[int]]],
num_msa_one_hot: int,
tab_separated_alignment_headers: bool = False,
ligand_chemtype_index: int = 3,
) -> List[Dict[str, np.ndarray]]:
"""
Expand All @@ -122,6 +123,7 @@ def make_msa_features(
:param msas: The mapping of chain IDs to lists of MSAs for each chain.
:param chain_id_to_residue: The mapping of chain IDs to residue information.
:param num_msa_one_hot: The number of one-hot classes for MSA features.
:param tab_separated_alignment_headers: Whether the alignment headers are tab-separated.
:param ligand_chemtype_index: The index of the ligand in the chemical type list.
:return: The MSA chain feature dictionaries.
"""
Expand Down Expand Up @@ -215,7 +217,10 @@ def make_msa_features(
deletion_matrix.append(msa_deletion_values)

# Parse species ID for MSA pairing if possible.
species_id = msa_parsing.get_identifiers(msa.descriptions[sequence_index]).species_id
species_id = msa_parsing.get_identifiers(
description=msa.descriptions[sequence_index],
tab_separated_alignment_headers=tab_separated_alignment_headers,
).species_id

if sequence_index == 0:
species_id = "-1" # Tag target sequence for filtering.
Expand Down
75 changes: 39 additions & 36 deletions alphafold3_pytorch/data/msa_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,20 +69,34 @@ def get_msa_type(msa_chem_type: int) -> MSA_TYPE:
raise ValueError(f"Invalid MSA chemical type: {msa_chem_type}")


@typecheck
def _parse_species_identifier(description: str) -> Identifiers:
    """Extract the species ID from a tab-separated MSA alignment header.

    The text following the last tab character in the header, with surrounding
    whitespace stripped, is used as the species ID. A header containing no tab
    (as is the case for the query identifier, which is not linked to a
    species) yields an `Identifiers` instance constructed with no arguments.

    :param description: a sequence identifier (alignment header).
    :return: An `Identifiers` instance with species_id. This can be empty in
        the case where no identifier was found.
    """
    _, tab, trailing_field = description.rpartition("\t")
    if not tab:
        # No tab in the header, so there is no species field to read.
        return Identifiers()
    return Identifiers(species_id=trailing_field.strip())


@typecheck
def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
"""Gets species from an msa sequence identifier.
"""Gets species from an MSA sequence identifier.
The sequence identifier has the format specified by
_UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN.
An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`
Args:
msa_sequence_identifier: a sequence identifier.
Returns:
An `Identifiers` instance with species_id. These
can be empty in the case where no identifier was found.
:param msa_sequence_identifier: a sequence identifier.
:return: An `Identifiers` instance with species_id. These
can be empty in the case where no identifier was found.
"""
matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
if matches:
Expand All @@ -94,7 +108,8 @@ def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
def _extract_sequence_identifier(description: str) -> Optional[str]:
"""Extracts sequence identifier from description.
Returns None if no match.
:param description: a sequence description.
:return: The sequence identifier.
"""
split_description = description.split()
if split_description:
Expand All @@ -104,36 +119,24 @@ def _extract_sequence_identifier(description: str) -> Optional[str]:


@typecheck
def _extract_sequence_accession_id(description: str) -> Optional[str]:
"""Extracts sequence identifier from description.
Returns None if no match.
def get_identifiers(
description: str, tab_separated_alignment_headers: bool = False
) -> Identifiers:
"""Computes extra MSA features from the description.
:param description: The description of the sequence.
:param tab_separated_alignment_headers: Whether the alignment headers are tab-separated.
:return: An `Identifiers` instance with species_id. These can be empty in the case
where no identifier was found.
"""
split_description = description.split()
if split_description:
return split_description[0].split(">")[-1]
else:
return None


@typecheck
def get_identifiers(description: str) -> Identifiers:
"""Computes extra MSA features from the description."""
sequence_identifier = _extract_sequence_identifier(description)
if not_exists(sequence_identifier):
return Identifiers()
else:
return _parse_sequence_identifier(sequence_identifier)


@typecheck
def get_accession_id(description: str) -> str:
"""Computes extra MSA features from the description."""
sequence_accession_id = _extract_sequence_accession_id(description)
if not_exists(sequence_accession_id):
return ""
if tab_separated_alignment_headers:
return _parse_species_identifier(description)
else:
return sequence_accession_id
sequence_identifier = _extract_sequence_identifier(description)
if not_exists(sequence_identifier):
return Identifiers()
else:
return _parse_sequence_identifier(sequence_identifier)


@dataclasses.dataclass(frozen=True)
Expand Down
1 change: 1 addition & 0 deletions alphafold3_pytorch/inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -2776,6 +2776,7 @@ def load_msa_from_msa_dir(
msas,
chain_id_to_residue,
num_msa_one_hot=NUM_MSA_ONE_HOT,
tab_separated_alignment_headers=not distillation,
)
unique_entity_ids = set(chain["entity_id"][0] for chain in chains)

Expand Down

0 comments on commit 2619f4b

Please sign in to comment.