Skip to content

Commit

Permalink
add overwrite cluster flag
Browse files Browse the repository at this point in the history
  • Loading branch information
kierandidi committed Mar 22, 2024
1 parent aea2b4c commit 361eb5d
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
1 change: 1 addition & 0 deletions proteinworkshop/config/dataset/pdb.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,4 @@ datamodule:
train_val_test: [0.8, 0.1, 0.1] # Cross-validation ratios to use for train, val, and test splits
split_type: "sequence_similarity" # Split sequences by sequence similarity clustering, other option is "random"
split_sequence_similiarity: 0.3 # Clustering at 30% sequence similarity (argument is ignored if split_type="random")
overwrite_sequence_clusters: False # Previous clusterings at same sequence similarity are reused and not overwritten
7 changes: 5 additions & 2 deletions proteinworkshop/datasets/pdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ def __init__(
remove_pdb_unavailable: bool,
train_val_test: List[float],
split_type: Literal["sequence_similarity", "random"],
split_sequence_similiarity: int
split_sequence_similiarity: int,
overwrite_sequence_clusters: bool
):
self.fraction = fraction
self.molecule_type = molecule_type
Expand All @@ -50,6 +51,7 @@ def __init__(
self.train_val_test = train_val_test
self.split_type = split_type
self.split_sequence_similarity = split_sequence_similiarity
self.overwrite_sequence_clusters = overwrite_sequence_clusters
self.splits = ["train", "val", "test"]

def create_dataset(self):
Expand Down Expand Up @@ -127,7 +129,8 @@ def create_dataset(self):
log.info(f"Splitting dataset via sequence-similarity split into {self.train_val_test}...")
log.info(f"Using {self.split_sequence_similarity} sequence similarity for split")
pdb_manager.cluster(min_seq_id=self.split_sequence_similarity, update=True)
splits = pdb_manager.split_clusters(pdb_manager.df, update=True)
splits = pdb_manager.split_clusters(
pdb_manager.df, update=True, overwrite = self.overwrite_sequence_clusters)

log.info(splits["train"])
return splits
Expand Down

0 comments on commit 361eb5d

Please sign in to comment.