Skip to content

Commit

Permalink
set up ruff pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
nebfield committed Jan 15, 2024
1 parent 1e33d24 commit 605cfd2
Show file tree
Hide file tree
Showing 9 changed files with 139 additions and 94 deletions.
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.13
hooks:
- id: ruff
args: [ --fix ]
- id: ruff-format
3 changes: 1 addition & 2 deletions pgscatalog.calclib/src/pgscatalog/calclib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from pgscatalog.calclib.testclass import TestClass

# be explicit about public interfaces
__all__ = ["TestClass"
]
__all__ = ["TestClass"]
2 changes: 1 addition & 1 deletion pgscatalog.calclib/src/pgscatalog/calclib/testclass.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
class TestClass:
pass
pass
2 changes: 1 addition & 1 deletion pgscatalog.corelib/src/pgscatalog/corelib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pgscatalog.corelib.scorefiles import ScoringFile, GenomeBuild, ScoringFiles
from pgscatalog.corelib import config

__all__ = ["ScoringFile", "ScoringFiles", "GenomeBuild", "config"]
__all__ = ["ScoringFile", "ScoringFiles", "GenomeBuild", "config"]
67 changes: 41 additions & 26 deletions pgscatalog.corelib/src/pgscatalog/corelib/catalogapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


class CatalogCategory(enum.Enum):
""" The main categories in the PGS Catalog. Enumeration values don't mean anything.
"""The main categories in the PGS Catalog. Enumeration values don't mean anything.
>>> CatalogCategory.SCORE
<CatalogCategory.SCORE: 1>
Expand All @@ -20,7 +20,7 @@ class CatalogCategory(enum.Enum):


class CatalogQuery:
""" Efficiently batch query the PGS Catalog API using trait (EFO), score (PGS ID),
"""Efficiently batch query the PGS Catalog API using trait (EFO), score (PGS ID),
or publication identifier (PGP ID).
>>> CatalogQuery(accession="PGS000001")
Expand All @@ -40,6 +40,7 @@ class CatalogQuery:
>>> CatalogQuery(accession="EFO_0001645")
CatalogQuery(accession='EFO_0001645', category=CatalogCategory.TRAIT, include_children=False)
"""

_rest_url_root = "https://www.pgscatalog.org/rest"

def __init__(self, *, accession, include_children=False, **kwargs):
Expand All @@ -61,11 +62,13 @@ def __init__(self, *, accession, include_children=False, **kwargs):
raise ValueError(f"Bad category: {self.category!r}")

def __repr__(self):
return (f"{type(self).__name__}(accession={repr(self.accession)}, category="
f"{self.category}, include_children={self.include_children})")
return (
f"{type(self).__name__}(accession={repr(self.accession)}, category="
f"{self.category}, include_children={self.include_children})"
)

def infer_category(self) -> CatalogCategory:
""" Inspect an accession and guess the Catalog category
"""Inspect an accession and guess the Catalog category
Assume lists of accessions only contain PGS IDs:
>>> CatalogQuery(accession=["PGS000001", "PGS000002"]).infer_category()
Expand All @@ -87,12 +90,14 @@ def infer_category(self) -> CatalogCategory:
case list() if all([x.startswith("PGS") for x in accession]):
category = CatalogCategory.SCORE
case list():
raise ValueError(f"Invalid accession in list: {accession!r}. Lists must only contain PGS IDs.")
raise ValueError(
f"Invalid accession in list: {accession!r}. Lists must only contain PGS IDs."
)
case str() if accession.startswith("PGS"):
category = CatalogCategory.SCORE
case str() if accession.startswith("PGP"):
category = CatalogCategory.PUBLICATION
case str() if '_' in accession:
case str() if "_" in accession:
# simple check for structured text like EFO_ACCESSION, HP_ACCESSION, etc
category = CatalogCategory.TRAIT
case _:
Expand Down Expand Up @@ -138,25 +143,26 @@ def get_query_url(self):
chunked_accession = ",".join(chunk)
urls.append(
f"{self._rest_url_root}/score/search?pgs_ids="
f"{chunked_accession}")
f"{chunked_accession}"
)
return urls
case CatalogCategory.PUBLICATION, str():
return f"{self._rest_url_root}/publication/{self.accession}"
case _:
raise ValueError(
f"Invalid CatalogCategory and accession type: {self.category!r}, "
f"type({self.accession!r})")
f"type({self.accession!r})"
)

def _chunk_accessions(self):
size = 50 # /rest/score/{pgs_id} limit when searching multiple IDs
# using a dict to get unique elements instead of a set to preserve order
accessions = self.accession
return (accessions[pos: pos + size] for pos in
range(0, len(accessions), size))
return (accessions[pos : pos + size] for pos in range(0, len(accessions), size))

@retry(stop=stop_after_attempt(5))
def score_query(self):
""" Query the PGS Catalog API and return ScoreQueryResult
"""Query the PGS Catalog API and return ScoreQueryResult
Information about a single score is returned as a dict:
>>> CatalogQuery(accession="PGS000001").score_query() # doctest: +ELLIPSIS
Expand All @@ -176,7 +182,7 @@ def score_query(self):

for url in self.get_query_url():
r = httpx.get(url, timeout=5, headers=config.API_HEADER).json()
results += r['results']
results += r["results"]

# return the same type as the accession input to be consistent
match self.accession:
Expand All @@ -190,33 +196,37 @@ def score_query(self):
case CatalogCategory.PUBLICATION:
url = self.get_query_url()
r = httpx.get(url, timeout=5, headers=config.API_HEADER).json()
pgs_ids = [score for scores in list(r["associated_pgs_ids"].values())
for score in scores]
pgs_ids = [
score
for scores in list(r["associated_pgs_ids"].values())
for score in scores
]
return CatalogQuery(accession=pgs_ids).score_query()
case CatalogCategory.TRAIT:
url = self.get_query_url()
r = httpx.get(url, timeout=5, headers=config.API_HEADER).json()
pgs_ids = r['associated_pgs_ids']
pgs_ids = r["associated_pgs_ids"]
if self.include_children:
pgs_ids.extend(r['child_associated_pgs_ids'])
pgs_ids.extend(r["child_associated_pgs_ids"])
return CatalogQuery(accession=pgs_ids).score_query()


class ScoreQueryResult:
""" Class that holds score metadata with methods to extract important fields """
def __init__(self, *,
pgs_id, ftp_url, ftp_grch37_url, ftp_grch38_url, license
):
"""Class that holds score metadata with methods to extract important fields"""

def __init__(self, *, pgs_id, ftp_url, ftp_grch37_url, ftp_grch38_url, license):
self.pgs_id = pgs_id
self.ftp_url = ftp_url
self.ftp_grch37_url = ftp_grch37_url
self.ftp_grch38_url = ftp_grch38_url
self.license = license

def __repr__(self):
return (f"{type(self).__name__}(pgs_id={self.pgs_id!r}, ftp_url={self.ftp_url!r}, "
f"ftp_grch37_url={self.ftp_grch37_url!r},ftp_grch38_url={self.ftp_grch38_url!r},"
f"license={self.license!r})")
return (
f"{type(self).__name__}(pgs_id={self.pgs_id!r}, ftp_url={self.ftp_url!r}, "
f"ftp_grch37_url={self.ftp_grch37_url!r},ftp_grch38_url={self.ftp_grch38_url!r},"
f"license={self.license!r})"
)

@classmethod
def from_query(cls, result_response):
Expand Down Expand Up @@ -244,8 +254,13 @@ def from_query(cls, result_response):
ftp_grch37_url = harmonised_urls["GRCh37"]["positions"]
ftp_grch38_url = harmonised_urls["GRCh38"]["positions"]
license = result_response["license"]
return cls(pgs_id=pgs_id, ftp_url=ftp_url, ftp_grch37_url=ftp_grch37_url, ftp_grch38_url=ftp_grch38_url,
license=license)
return cls(
pgs_id=pgs_id,
ftp_url=ftp_url,
ftp_grch37_url=ftp_grch37_url,
ftp_grch38_url=ftp_grch38_url,
license=license,
)

def get_download_url(self, genome_build=None):
"""
Expand Down
2 changes: 1 addition & 1 deletion pgscatalog.corelib/src/pgscatalog/corelib/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
_package_version = f"{importlib.metadata.version('pgscatalog.corelib')}"
_package_string = f"pgscatalog.corelib/{_package_version}"

API_HEADER = {"user-agent": _package_string}
API_HEADER = {"user-agent": _package_string}
78 changes: 49 additions & 29 deletions pgscatalog.corelib/src/pgscatalog/corelib/scorefiles.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
""" This module contains classes that compose a ScoringFile: a file in the
PGS Catalog that contains a list of genetic variants and their effect weights.
Scoring files are used to calculate PGS for new target genomes. """
import collections
import gzip
import hashlib
import itertools
Expand Down Expand Up @@ -37,6 +36,7 @@ class ScoringFileHeader:
>>> header.genome_build
GenomeBuild.GRCh37
"""

# slots are used here because we want a controlled vocabulary
# random extra attributes would be bad without thinking about them
__slots__ = (
Expand Down Expand Up @@ -65,22 +65,22 @@ class ScoringFileHeader:
)

def __init__(
self,
*,
pgs_name,
genome_build,
pgs_id=None,
pgp_id=None,
variants_number=None,
trait_reported=None,
trait_efo=None,
trait_mapped=None,
weight_type=None,
citation=None,
HmPOS_build=None,
HmPOS_date=None,
format_version=None,
license=None,
self,
*,
pgs_name,
genome_build,
pgs_id=None,
pgp_id=None,
variants_number=None,
trait_reported=None,
trait_efo=None,
trait_mapped=None,
weight_type=None,
citation=None,
HmPOS_build=None,
HmPOS_date=None,
format_version=None,
license=None,
):
"""kwargs are forced because this is a complicated init and from_path() is
almost always the correct thing to do.
Expand Down Expand Up @@ -143,12 +143,13 @@ def from_path(cls, path):


class ScoringFile:
""" Represents a single scoring file.
"""Represents a single scoring file.
Can also be constructed with a ScoreQueryResult to avoid hitting the API during instantiation
"""

def __init__(self, identifier, target_build=None, query_result=None):
if query_result is None:
self._identifier = identifier
Expand Down Expand Up @@ -193,15 +194,20 @@ def _init_from_accession(self, accession, target_build):
pass # just a normal ScoreQueryResult, continue
else:
# this class can only instantiate and represent one scoring file
raise ValueError(f"Can't create a ScoringFile with accession: {accession!r}. "
"Only PGS ids are supported. Try ScoringFiles()")
raise ValueError(
f"Can't create a ScoringFile with accession: {accession!r}. "
"Only PGS ids are supported. Try ScoringFiles()"
)

self.pgs_id = score.pgs_id
self.header = None
self.catalog_response = score
self.path = score.get_download_url(target_build)

@retry(stop=tenacity.stop_after_attempt(5), retry=tenacity.retry_if_exception_type(httpx.RequestError))
@retry(
stop=tenacity.stop_after_attempt(5),
retry=tenacity.retry_if_exception_type(httpx.RequestError),
)
def download(self, directory, overwrite=False):
"""
Download a ScoringFile to a specified directory with checksum validation
Expand Down Expand Up @@ -238,15 +244,17 @@ def download(self, directory, overwrite=False):

if (calc := md5.hexdigest()) != (remote := checksum.split()[0]):
# will attempt to download again (see decorator)
raise httpx.RequestError(f"Calculated checksum {calc} doesn't match {remote}")
raise httpx.RequestError(
f"Calculated checksum {calc} doesn't match {remote}"
)
else:
os.rename(f.name, out_path)
except httpx.UnsupportedProtocol:
raise ValueError(f"Can't download a local file: {self.path!r}")


class ScoringFiles:
""" This class provides methods to work with multiple ScoringFile objects.
"""This class provides methods to work with multiple ScoringFile objects.
You can use publications or trait accessions to instantiate:
>>> pub = ScoringFiles("PGP000001")
Expand Down Expand Up @@ -280,9 +288,14 @@ class ScoringFiles:
ScoringFile('PGS000002')
ScoringFile('PGS000003')
"""

def __init__(self, *args, target_build=None):
# flatten args to provide a more flexible interface
flargs = list(itertools.chain.from_iterable(arg if isinstance(arg, list) else [arg] for arg in args))
flargs = list(
itertools.chain.from_iterable(
arg if isinstance(arg, list) else [arg] for arg in args
)
)
scorefiles = []
pgs_batch = []
for arg in flargs:
Expand All @@ -291,7 +304,12 @@ def __init__(self, *args, target_build=None):
raise NotImplementedError
case str() if arg.startswith("PGP") or "_" in arg:
pgp_scorefiles = CatalogQuery(accession=arg).score_query()
scorefiles.extend([ScoringFile(x.pgs_id, target_build=target_build) for x in pgp_scorefiles])
scorefiles.extend(
[
ScoringFile(x.pgs_id, target_build=target_build)
for x in pgp_scorefiles
]
)
case str() if arg.startswith("PGS"):
pgs_batch.append(arg)
case str():
Expand All @@ -300,7 +318,9 @@ def __init__(self, *args, target_build=None):
raise TypeError

# build scorefiles from a batch query of PGS IDs to avoid smashing the API
batched_queries = CatalogQuery(accession=pgs_batch, target_build=target_build).score_query()
batched_queries = CatalogQuery(
accession=pgs_batch, target_build=target_build
).score_query()
batched_scores = [ScoringFile(x) for x in batched_queries]
scorefiles.extend(batched_scores)

Expand All @@ -324,9 +344,9 @@ def elements(self):
return self._elements

def combine(self):
""" Combining multiple scoring files yields ScoreVariants in a consistent genome build and data format.
"""Combining multiple scoring files yields ScoreVariants in a consistent genome build and data format.
This process takes care of data munging and some quality control steps. """
This process takes care of data munging and some quality control steps."""
raise NotImplementedError


Expand Down Expand Up @@ -365,4 +385,4 @@ def auto_open(filepath, mode="rt"):
if test_f.read(2) == b"\x1f\x8b":
return gzip.open(filepath, mode)
else:
return open(filepath, mode)
return open(filepath, mode)
Loading

0 comments on commit 605cfd2

Please sign in to comment.