diff --git a/src/pp5/contacts.py b/src/pp5/contacts.py index faa2693..2cf162e 100644 --- a/src/pp5/contacts.py +++ b/src/pp5/contacts.py @@ -274,7 +274,7 @@ def _resolve_atom_contacts( continue # Ignore any contacts with water - if tgt_resname == "HOH": + if tgt_resname == "HOH": # TODO: also check for hydrogen 1HJE continue src_chain = src_res.get_parent() @@ -286,23 +286,23 @@ def _resolve_atom_contacts( contact_dist = src_atom - tgt_atom contact_dists.append(contact_dist) - # Key uniquely identifying the contact - contact_key = (tgt_chain_id, tgt_resname, tgt_seq_idx, tgt_altloc) + # Key uniquely identifying the contact target + contact_tgt_key = (tgt_chain_id, tgt_resname, tgt_seq_idx, tgt_altloc) # Check if contact is a ligand (check hetero flag) if tgt_hetflag.startswith("H_"): - res_contacts_non_aa.setdefault(contact_key, []) - res_contacts_non_aa[contact_key].append(contact_dist) + res_contacts_non_aa.setdefault(contact_tgt_key, []) + res_contacts_non_aa[contact_tgt_key].append(contact_dist) # Check if contact is out of chain elif src_chain != tgt_chain: - res_contacts_ooc.setdefault(contact_key, []) - res_contacts_ooc[contact_key].append(contact_dist) + res_contacts_ooc.setdefault(contact_tgt_key, []) + res_contacts_ooc[contact_tgt_key].append(contact_dist) # Regular AA contact in current chain else: - res_contacts_aas.setdefault(contact_key, []) - res_contacts_aas[contact_key].append(contact_dist) + res_contacts_aas.setdefault(contact_tgt_key, []) + res_contacts_aas[contact_tgt_key].append(contact_dist) # Calculate sequence distance (only in-chain) sequence_dists.append(abs(tgt_seq_idx - src_seq_idx)) @@ -315,37 +315,37 @@ def _resolve_atom_contacts( if sequence_dists: contact_smin, contact_smax = min(sequence_dists), max(sequence_dists) - def _format_unique_contacts( - _contacts: Dict[tuple, List[float]] - ) -> Sequence[str]: + def _aggregate(_contacts: Dict[tuple, List[float]]) -> Dict[tuple, float]: """ - Formats multiple contacts as a list of strings, by merging the contacts - which have a unique key. For each unique contact key, calculates the - average distance, and uses that to format the contact. - :param _contacts: The contacts to format. - :return: A list of formatted contacts. + For each unique contact target key, aggregates the distances by taking + the minimum. + + :param _contacts: The contacts to format: {tgt_key: [dist1, dist2, ...]} + :return: The aggregated contacts: {tgt_key: min_dist} """ _contacts_to_dist: Dict[tuple, float] = {} - _formatted_contacts = [] + for _tgt_key, _dists in _contacts.items(): + _contacts_to_dist[_tgt_key] = np.min(_dists) + return _contacts_to_dist - # Aggregate distances - for _key, _dists in _contacts.items(): - _mean_dist = np.mean(_dists) - _contacts_to_dist[_key] = _mean_dist + def _format(_contacts: Dict[tuple, float]) -> Sequence[str]: + """ + Formats contacts as a list of strings, by merging the contacts + which have a unique target key. + :param _contacts: The contacts to format: {tgt_key: dist} + :return: A list of formatted contacts. + """ # Sort by distance - _contacts_to_dist = dict( - sorted(_contacts_to_dist.items(), key=lambda x: x[1]) - ) - - for _key, _dist in _contacts_to_dist.items(): - _tgt_chain_id, _tgt_resname, _tgt_seq_idx, _tgt_altloc = _key + _contacts = dict(sorted(_contacts.items(), key=lambda x: x[1])) + _formatted_contacts = [] + for _tgt_key, _dist in _contacts.items(): + _tgt_chain_id, _tgt_resname, _tgt_seq_idx, _tgt_altloc = _tgt_key _formatted_contacts.append( format_residue_contact( _tgt_chain_id, _tgt_resname, _tgt_seq_idx, _tgt_altloc, _dist ) ) - return tuple(_formatted_contacts) return ResidueContacts( @@ -353,9 +353,9 @@ def _format_unique_contacts( contact_count=contact_count, contact_types="proximal", # use arpeggio name, but not meaningful here contact_smax=contact_smax, - contact_ooc=_format_unique_contacts(res_contacts_ooc), - contact_non_aa=_format_unique_contacts(res_contacts_non_aa), - contact_aas=_format_unique_contacts(res_contacts_aas), + contact_ooc=_format(_aggregate(res_contacts_ooc)), + contact_non_aa=_format(_aggregate(res_contacts_non_aa)), + contact_aas=_format(_aggregate(res_contacts_aas)), )