Skip to content

Commit

Permalink
added sequence clustering to pdb dataset (#88)
Browse files Browse the repository at this point in the history
* added sequence clustering to pdb dataset

* update CHANGELOG

* changed overwrite setting

* change split_ratios to train_val_test

* add overwrite cluster flag

* changed pdbmanager arg
  • Loading branch information
kierandidi authored Mar 22, 2024
1 parent 716a1b2 commit d5fbab7
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 12 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
* Improves support for datamodules with multiple test sets. Generalises this to support GO and FOLD. Also adds multiple seq ID.-based splits for GO. [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)
* Add redownload checks for already downloaded datasets and harmonise pdb download interface [#86](https://github.com/a-r-j/ProteinWorkshop/pull/86)
* Remove remaining errors from PDB dataset change
* Add option to create PDB datasets with sequence-based splits [#88](https://github.com/a-r-j/ProteinWorkshop/pull/88)

### Models

Expand Down
5 changes: 4 additions & 1 deletion proteinworkshop/config/dataset/pdb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,7 @@ datamodule:
remove_ligands: [] # Exclude specific ligands from any available protein-ligand complexes
remove_non_standard_residues: True # Include only proteins containing standard amino acid residues
remove_pdb_unavailable: True # Include only proteins that are available to download
split_sizes: [0.8, 0.1, 0.1] # Cross-validation ratios to use for train, val, and test splits
train_val_test: [0.8, 0.1, 0.1] # Cross-validation ratios to use for train, val, and test splits
split_type: "sequence_similarity" # Split sequences by sequence-similarity clustering; the other option is "random"
split_sequence_similiarity: 0.3 # Clustering at 30% sequence similarity (argument is ignored if split_type="random")
overwrite_sequence_clusters: False # If False, previous clusterings at the same sequence similarity are reused rather than overwritten
38 changes: 27 additions & 11 deletions proteinworkshop/datasets/pdb_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Callable, Iterable, List, Optional, Dict
from typing import Callable, Iterable, List, Optional, Dict, Literal

import hydra
import omegaconf
Expand Down Expand Up @@ -29,7 +29,10 @@ def __init__(
remove_ligands: List[str],
remove_non_standard_residues: bool,
remove_pdb_unavailable: bool,
split_sizes: List[float],
train_val_test: List[float],
split_type: Literal["sequence_similarity", "random"],
split_sequence_similiarity: int,
overwrite_sequence_clusters: bool
):
self.fraction = fraction
self.molecule_type = molecule_type
Expand All @@ -44,11 +47,16 @@ def __init__(
self.remove_pdb_unavailable = remove_pdb_unavailable
self.min_length = min_length
self.max_length = max_length
self.split_sizes = split_sizes
assert sum(train_val_test) == 1, f"train_val_test need to sum to 1, but sum to {sum(train_val_test)}"
self.train_val_test = train_val_test
self.split_type = split_type
self.split_sequence_similarity = split_sequence_similiarity
self.overwrite_sequence_clusters = overwrite_sequence_clusters
self.splits = ["train", "val", "test"]

def create_dataset(self):
log.info(f"Initializing PDBManager in {self.path}...")
pdb_manager = PDBManager(root_dir=self.path)
pdb_manager = PDBManager(root_dir=self.path, splits=self.splits, split_ratios=self.train_val_test)
num_chains = len(pdb_manager.df)
log.info(f"Starting with: {num_chains} chains")

Expand Down Expand Up @@ -109,13 +117,21 @@ def create_dataset(self):
pdb_manager.remove_unavailable_pdbs(update=True)
log.info(f"{len(pdb_manager.df)} chains remaining")

log.info(f"Splitting dataset into {self.split_sizes}...")
split_names = ["train", "val", "test"]
splits = pdb_manager.split_df_proportionally(
df=pdb_manager.df,
splits=split_names,
split_ratios=self.split_sizes,
)
if self.split_type == "random":
log.info(f"Splitting dataset via random split into {self.train_val_test}...")
splits = pdb_manager.split_df_proportionally(
df=pdb_manager.df,
splits=self.splits,
train_val_test=self.train_val_test,
)

elif self.split_type == "sequence_similarity":
log.info(f"Splitting dataset via sequence-similarity split into {self.train_val_test}...")
log.info(f"Using {self.split_sequence_similarity} sequence similarity for split")
pdb_manager.cluster(min_seq_id=self.split_sequence_similarity, update=True)
splits = pdb_manager.split_clusters(
pdb_manager.df, update=True, overwrite = self.overwrite_sequence_clusters)

log.info(splits["train"])
return splits

Expand Down

0 comments on commit d5fbab7

Please sign in to comment.