Pdb dataset fix (#86)
* Fixed pdb dataset download

* Fixed pdb dataset init

* Added change to Changelog

* Removed commented lines

* Changed pdb configs and structure dir creation
kierandidi authored Mar 20, 2024
1 parent 22d6379 commit 380a738
Showing 3 changed files with 47 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -3,6 +3,7 @@
### Datasets
* Add stage-based conditions to `setup` in `ProteinDataModule` [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)
* Improves support for datamodules with multiple test sets. Generalises this to support GO and FOLD. Also adds multiple sequence identity-based splits for GO. [#72](https://github.com/a-r-j/ProteinWorkshop/pull/72)
* Add redownload checks for already downloaded datasets and harmonise the PDB download interface [#86](https://github.com/a-r-j/ProteinWorkshop/pull/86)

### Models

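The redownload check described in the #86 entry above boils down to filtering the requested PDB IDs against files already on disk before fetching anything. A minimal sketch of the pattern (the helper name is illustrative; the actual logic lives in PDBDataModule.download in the diff below):

import os
import pathlib
from typing import List

def filter_missing(structure_dir: pathlib.Path, pdb_ids: List[str], fmt: str = "mmtf.gz") -> List[str]:
    # Keep only the PDB IDs whose structure files are not yet on disk.
    return [pdb for pdb in pdb_ids if not os.path.exists(structure_dir / f"{pdb}.{fmt}")]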
1 change: 1 addition & 0 deletions proteinworkshop/config/dataset/pdb.yaml
@@ -4,6 +4,7 @@ datamodule:
  batch_size: 32 # Batch size for dataloader
  num_workers: 4 # Number of workers for dataloader
  pin_memory: True # Pin memory for dataloader
  structure_format: "mmtf.gz" # Structure format for files to be downloaded
  transforms: ${transforms} # Transforms to apply to dataset examples
  overwrite: False # Whether to overwrite existing dataset files

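The new structure_format and overwrite keys are forwarded to PDBDataModule when this config is instantiated. A hedged sketch of that flow, assuming the YAML is loaded directly and interpolations such as ${transforms} resolve (the repository's own entry points may differ):

import hydra
import omegaconf

cfg = omegaconf.OmegaConf.load("proteinworkshop/config/dataset/pdb.yaml")
datamodule = hydra.utils.instantiate(cfg)["datamodule"]  # mirrors the module's __main__ block
datamodule.download()  # skips structures already present in structure_dir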
63 changes: 45 additions & 18 deletions proteinworkshop/datasets/pdb_dataset.py
@@ -1,8 +1,10 @@
from typing import Callable, Iterable, List, Optional
from typing import Callable, Iterable, List, Optional, Dict

import hydra
import omegaconf
import os
import pandas as pd
import pathlib
from graphein.ml.datasets import PDBManager
from loguru import logger as log
from torch_geometric.data import Dataset
@@ -11,7 +13,6 @@
from proteinworkshop.datasets.base import ProteinDataModule, ProteinDataset
from proteinworkshop.datasets.utils import download_pdb_mmtf


class PDBData:
    def __init__(
        self,
@@ -65,10 +66,6 @@ def create_dataset(self):
        pdb_manager.molecule_type(self.molecule_type, update=True)
        log.info(f"{len(pdb_manager.df)} chains remaining")

# log.info(f"Removing chains experiment types not in selection: {self.experiment_types}...")
# pdb_manager.experiment_type(self.experiment_types, update=True)
# log.info(f"{len(pdb_manager.df)} chains remaining")

        log.info(
            f"Removing chains with oligomeric state not in selection: {self.oligomeric_min} - {self.oligomeric_max}..."
        )
@@ -119,9 +116,6 @@ def create_dataset(self):
            splits=split_names,
            split_ratios=self.split_sizes,
        )
        # log.info(splits["train"])
        # log.info(pdb_manager.df)
        # return pdb_manager.df
        log.info(splits["train"])
        return splits

@@ -130,21 +124,31 @@ class PDBDataModule(ProteinDataModule):
    def __init__(
        self,
        path: Optional[str] = None,
        structure_dir: Optional[str] = None,
        pdb_dataset: Optional[PDBData] = None,
        transforms: Optional[Iterable[Callable]] = None,
        in_memory: bool = False,
        batch_size: int = 32,
        num_workers: int = 0,
        pin_memory: bool = False,
        structure_format: str = "mmtf.gz",
        overwrite: bool = False,
    ):
        super().__init__()
        self.root = path
        self.dataset = pdb_dataset
        self.dataset.path = path
        self.format = "mmtf.gz"
        self.format = structure_format
        self.overwrite = overwrite

        if structure_dir is not None:
            self.structure_dir = pathlib.Path(structure_dir)
        else:
            self.structure_dir = pathlib.Path(self.root) / "raw"

        # Create the structure directory if it does not already exist
        self.structure_dir.mkdir(parents=True, exist_ok=True)

        self.in_memory = in_memory

        if transforms is not None:
@@ -157,19 +161,43 @@ def __init__(
        self.num_workers = num_workers
        self.pin_memory = pin_memory
        self.batch_size = batch_size

    def parse_dataset(self) -> pd.DataFrame:
        return self.dataset.create_dataset()


    def parse_dataset(self) -> Dict[str, pd.DataFrame]:
        if hasattr(self, "splits"):
            return getattr(self, "splits")

        splits = self.dataset.create_dataset()
        ids_to_exclude = self.exclude_pdbs()

        if ids_to_exclude is not None:
            for k, v in splits.items():
                log.info(f"Split {k} has {len(v)} chains before excluding failing PDBs")
                v["id"] = v["pdb"] + "_" + v["chain"].str.join("")
                splits[k] = v.loc[~v["id"].isin(ids_to_exclude)]
                log.info(
                    f"Split {k} has {len(splits[k])} chains after excluding failing PDBs"
                )
        self.splits = splits
        return splits

    def exclude_pdbs(self):
        pass

    def download(self):
        pdbs = self.parse_dataset()

        for k, v in pdbs:
            log.info(f"Downloading {k} PDBs")
            download_pdb_mmtf(pathlib.Path(self.root) / "raw", v.pdb.tolist())
        for k, v in pdbs.items():
            log.info(f"Downloading {k} PDBs to {self.structure_dir}")
            pdblist = v.pdb.tolist()
            pdblist = [
                pdb
                for pdb in pdblist
                if not os.path.exists(self.structure_dir / f"{pdb}.{self.format}")
            ]
            download_pdb_mmtf(self.structure_dir, pdblist)

    def parse_labels(self):
        raise NotImplementedError
@@ -223,7 +251,6 @@ def test_dataset(self) -> Dataset:


if __name__ == "__main__":
    import pathlib

    from proteinworkshop import constants

@@ -234,4 +261,4 @@ def test_dataset(self) -> Dataset:
    print(cfg)
    ds = hydra.utils.instantiate(cfg)["datamodule"]
    print(ds)
    ds.val_dataset()
    ds.val_dataset()
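For clarity, a small worked example of the id-based exclusion performed in parse_dataset, with made-up data:

import pandas as pd

split = pd.DataFrame({"pdb": ["1abc", "2xyz"], "chain": [["A"], ["A", "B"]]})
split["id"] = split["pdb"] + "_" + split["chain"].str.join("")  # "1abc_A", "2xyz_AB"
ids_to_exclude = {"1abc_A"}
print(split.loc[~split["id"].isin(ids_to_exclude)])  # only the 2xyz_AB row remains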
