From 2619f4b10e59e6797ad336ee17009ad90f9ef9c7 Mon Sep 17 00:00:00 2001 From: Alex Morehead Date: Fri, 20 Sep 2024 19:30:29 -0500 Subject: [PATCH] Add full MSA pairing support (#275) * Update data_pipeline.py * Update msa_parsing.py * Update inputs.py --- alphafold3_pytorch/data/data_pipeline.py | 7 ++- alphafold3_pytorch/data/msa_parsing.py | 75 ++++++++++++------------ alphafold3_pytorch/inputs.py | 1 + 3 files changed, 46 insertions(+), 37 deletions(-) diff --git a/alphafold3_pytorch/data/data_pipeline.py b/alphafold3_pytorch/data/data_pipeline.py index 4a53a322..5340fcea 100644 --- a/alphafold3_pytorch/data/data_pipeline.py +++ b/alphafold3_pytorch/data/data_pipeline.py @@ -113,6 +113,7 @@ def make_msa_features( msas: Dict[str, msa_parsing.Msa], chain_id_to_residue: Dict[str, Dict[str, List[int]]], num_msa_one_hot: int, + tab_separated_alignment_headers: bool = False, ligand_chemtype_index: int = 3, ) -> List[Dict[str, np.ndarray]]: """ @@ -122,6 +123,7 @@ def make_msa_features( :param msas: The mapping of chain IDs to lists of MSAs for each chain. :param chain_id_to_residue: The mapping of chain IDs to residue information. :param num_msa_one_hot: The number of one-hot classes for MSA features. + :param tab_separated_alignment_headers: Whether the alignment headers are tab-separated. :param ligand_chemtype_index: The index of the ligand in the chemical type list. :return: The MSA chain feature dictionaries. """ @@ -215,7 +217,10 @@ def make_msa_features( deletion_matrix.append(msa_deletion_values) # Parse species ID for MSA pairing if possible. - species_id = msa_parsing.get_identifiers(msa.descriptions[sequence_index]).species_id + species_id = msa_parsing.get_identifiers( + description=msa.descriptions[sequence_index], + tab_separated_alignment_headers=tab_separated_alignment_headers, + ).species_id if sequence_index == 0: species_id = "-1" # Tag target sequence for filtering. 
diff --git a/alphafold3_pytorch/data/msa_parsing.py b/alphafold3_pytorch/data/msa_parsing.py index 1cf285ec..c1ac6fc3 100644 --- a/alphafold3_pytorch/data/msa_parsing.py +++ b/alphafold3_pytorch/data/msa_parsing.py @@ -69,20 +69,34 @@ def get_msa_type(msa_chem_type: int) -> MSA_TYPE: raise ValueError(f"Invalid MSA chemical type: {msa_chem_type}") +@typecheck +def _parse_species_identifier(description: str) -> Identifiers: + """Gets species from an MSA sequence identifier. + + The sequence identifier in this instance has a tab-separated format whose last + field is the species ID, except for the query identifier which is not linked to a species. + + :param description: a sequence identifier. + :return: An `Identifiers` instance whose species_id is the stripped last tab-separated field, + or a default `Identifiers` instance when the description contains no tab. + """ + split_description = description.split("\t") + if len(split_description) > 1: + return Identifiers(species_id=split_description[-1].strip()) + return Identifiers() + + @typecheck def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: - """Gets species from an msa sequence identifier. + """Gets species from an MSA sequence identifier. The sequence identifier has the format specified by _UNIPROT_TREMBL_ENTRY_NAME_PATTERN or _UNIPROT_SWISSPROT_ENTRY_NAME_PATTERN. An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE` - Args: - msa_sequence_identifier: a sequence identifier. - - Returns: - An `Identifiers` instance with species_id. These - can be empty in the case where no identifier was found. + :param msa_sequence_identifier: a sequence identifier. + :return: An `Identifiers` instance with species_id. These + can be empty in the case where no identifier was found. 
""" matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip()) if matches: @@ -94,7 +108,8 @@ def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers: def _extract_sequence_identifier(description: str) -> Optional[str]: """Extracts sequence identifier from description. - Returns None if no match. + :param description: a sequence description. + :return: The sequence identifier, or None if no match. """ split_description = description.split() if split_description: @@ -104,36 +119,24 @@ def _extract_sequence_identifier(description: str) -> Optional[str]: @typecheck -def _extract_sequence_accession_id(description: str) -> Optional[str]: - """Extracts sequence identifier from description. - - Returns None if no match. +def get_identifiers( + description: str, tab_separated_alignment_headers: bool = False +) -> Identifiers: + """Computes extra MSA features from the description. + + :param description: The description of the sequence. + :param tab_separated_alignment_headers: Whether the alignment headers are tab-separated. + :return: An `Identifiers` instance with species_id. These can be empty in the case + where no identifier was found. 
""" - split_description = description.split() - if split_description: - return split_description[0].split(">")[-1] - else: - return None - - -@typecheck -def get_identifiers(description: str) -> Identifiers: - """Computes extra MSA features from the description.""" - sequence_identifier = _extract_sequence_identifier(description) - if not_exists(sequence_identifier): - return Identifiers() - else: - return _parse_sequence_identifier(sequence_identifier) - - -@typecheck -def get_accession_id(description: str) -> str: - """Computes extra MSA features from the description.""" - sequence_accession_id = _extract_sequence_accession_id(description) - if not_exists(sequence_accession_id): - return "" + if tab_separated_alignment_headers: + return _parse_species_identifier(description) else: - return sequence_accession_id + sequence_identifier = _extract_sequence_identifier(description) + if not_exists(sequence_identifier): + return Identifiers() + else: + return _parse_sequence_identifier(sequence_identifier) @dataclasses.dataclass(frozen=True) diff --git a/alphafold3_pytorch/inputs.py b/alphafold3_pytorch/inputs.py index 4b23f6ae..767f2b55 100644 --- a/alphafold3_pytorch/inputs.py +++ b/alphafold3_pytorch/inputs.py @@ -2776,6 +2776,7 @@ def load_msa_from_msa_dir( msas, chain_id_to_residue, num_msa_one_hot=NUM_MSA_ONE_HOT, + tab_separated_alignment_headers=not distillation, ) unique_entity_ids = set(chain["entity_id"][0] for chain in chains)