diff --git a/pyproject.toml b/pyproject.toml
index 10c3fba..ca26517 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -48,7 +48,7 @@ markers = "fuzzing: Run Hypothesis fuzz test suite"
 line_length = 100
 force_grid_wrap = 0
 include_trailing_comma = true
-known_third_party = ["click", "dataclassy", "github", "hypothesis", "hypothesis_jsonschema", "pytest", "requests", "semantic_version", "setuptools"]
+known_third_party = ["click", "github", "hypothesis", "hypothesis_jsonschema", "pydantic", "pytest", "requests", "semantic_version", "setuptools"]
 known_first_party = ["tokenlists"]
 multi_line_output = 3
 use_parentheses = true
diff --git a/setup.py b/setup.py
index 7bd6bc8..dafae4f 100644
--- a/setup.py
+++ b/setup.py
@@ -63,7 +63,7 @@
     install_requires=[
         "importlib-metadata ; python_version<'3.8'",
         "click>=8.0.0",
-        "dataclassy>=0.10.3,<1.0",
+        "pydantic>=1.8.2,<2.0.0",
         "pyyaml>=5.4.1,<6",
         "semantic-version>=2.8.5,<3",
     ],
diff --git a/tests/functional/test_schema_fuzzing.py b/tests/functional/test_schema_fuzzing.py
index 99579d6..3b93118 100644
--- a/tests/functional/test_schema_fuzzing.py
+++ b/tests/functional/test_schema_fuzzing.py
@@ -2,14 +2,29 @@
 import requests  # type: ignore
 from hypothesis import HealthCheck, given, settings
 from hypothesis_jsonschema import from_schema
+from pydantic import ValidationError
 
 from tokenlists import TokenList
 
 TOKENLISTS_SCHEMA = "https://uniswap.org/tokenlist.schema.json"
 
 
+def clean_iso_timestamps(tl: dict) -> dict:
+    """
+    Timestamps can be in any format, and our processing handles them fine.
+    However, for testing purposes we want the output format to line up,
+    and unfortunately there is some ambiguity in ISO timestamp formats.
+    """
+    tl["timestamp"] = tl["timestamp"].replace("Z", "+00:00")
+    return tl
+
+
 @pytest.mark.fuzzing
 @given(token_list=from_schema(requests.get(TOKENLISTS_SCHEMA).json()))
 @settings(suppress_health_check=(HealthCheck.too_slow,))
 def test_schema(token_list):
-    assert TokenList.from_dict(token_list).to_dict() == token_list
+    try:
+        assert TokenList.parse_obj(token_list).dict() == clean_iso_timestamps(token_list)
+
+    except (ValidationError, ValueError):
+        pass  # Expect these kinds of errors
diff --git a/tests/functional/test_uniswap_examples.py b/tests/functional/test_uniswap_examples.py
index ad5efe3..bd69e85 100644
--- a/tests/functional/test_uniswap_examples.py
+++ b/tests/functional/test_uniswap_examples.py
@@ -3,6 +3,7 @@
 import github
 import pytest  # type: ignore
 import requests  # type: ignore
+from pydantic import ValidationError
 
 from tokenlists import TokenList
 
@@ -13,9 +14,15 @@
 @pytest.mark.parametrize(
-    "token_list_file",
-    UNISWAP_REPO.get_contents("test/schema"),  # type: ignore
+    "token_list_name",
+    [f.name for f in UNISWAP_REPO.get_contents("test/schema")],  # type: ignore
 )
-def test_uniswap_tokenlists(token_list_file):
-    token_list = requests.get(UNISWAP_RAW_URL + token_list_file.name).json()
-    assert TokenList.from_dict(token_list).to_dict() == token_list
+def test_uniswap_tokenlists(token_list_name):
+    token_list = requests.get(UNISWAP_RAW_URL + token_list_name).json()
+
+    if "invalid" not in token_list_name:
+        assert TokenList.parse_obj(token_list).dict() == token_list
+
+    else:
+        with pytest.raises((ValidationError, ValueError)):
+            TokenList.parse_obj(token_list).dict()
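A note on `clean_iso_timestamps` above: pydantic v1 happily parses a trailing `Z` offset, but Python's `datetime.isoformat()` only ever emits `+00:00`, so a schema-generated `...Z` timestamp would fail the round-trip comparison without normalization. A minimal standalone sketch of that asymmetry (illustration only, not part of the patch; the timestamp value is made up):

    from datetime import datetime, timezone

    # Both spellings denote the same instant...
    parsed = datetime.fromisoformat("2021-07-01T12:00:00+00:00")
    assert parsed == datetime(2021, 7, 1, 12, tzinfo=timezone.utc)

    # ...but isoformat() renders UTC as "+00:00", never "Z", so the raw
    # input string must be normalized before comparing round-tripped output
    raw = "2021-07-01T12:00:00Z"
    assert raw.replace("Z", "+00:00") == parsed.isoformat()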
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index 252532f..8a2939c 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -3,14 +3,14 @@
 import pytest  # type: ignore
 from click.testing import CliRunner
 
-from tokenlists import TokenListManager, _cli
+from tokenlists import _cli, config
 
 
 @pytest.fixture
 def runner(monkeypatch):
     runner = CliRunner()
     with runner.isolated_filesystem() as temp_dir:
-        monkeypatch.setattr(_cli, "TokenListManager", lambda: TokenListManager(Path(temp_dir)))
+        monkeypatch.setattr(config, "DEFAULT_CACHE_PATH", Path(temp_dir))
         yield runner
diff --git a/tests/integration/test_cli.py b/tests/integration/test_cli.py
index 3a0602b..1ef5707 100644
--- a/tests/integration/test_cli.py
+++ b/tests/integration/test_cli.py
@@ -1,6 +1,14 @@
+from tokenlists.version import version
+
 TEST_URI = "tokens.1inch.eth"
 
 
+def test_version(runner, cli):
+    result = runner.invoke(cli, ["--version"])
+    assert result.exit_code == 0
+    assert result.output.strip() == version
+
+
 def test_empty_list(runner, cli):
     result = runner.invoke(cli, ["list"])
     assert result.exit_code == 0
@@ -35,3 +43,12 @@ def test_remove(runner, cli):
     result = runner.invoke(cli, ["list"])
     assert result.exit_code == 0
     assert "No tokenlists exist" in result.output
+
+
+def test_default(runner, cli):
+    result = runner.invoke(cli, ["install", TEST_URI])
+    assert result.exit_code == 0
+
+    result = runner.invoke(cli, ["set-default", "1inch"])
+    assert result.exit_code == 0
+    assert "1inch" in result.output
diff --git a/tokenlists/_cli.py b/tokenlists/_cli.py
index da6b0ed..21fbfb5 100644
--- a/tokenlists/_cli.py
+++ b/tokenlists/_cli.py
@@ -16,6 +16,7 @@ def choices(self):
 
 
 @click.group()
+@click.version_option(message="%(version)s", package_name="tokenlists")
 def cli():
     """
     Utility for working with the `py-tokenlists` installed token lists
@@ -59,6 +60,8 @@ def set_default(name):
 
     manager.set_default_tokenlist(name)
 
+    click.echo(f"Default tokenlist is now: '{manager.default_tokenlist}'")
+
 
 @cli.command(short_help="Display the names and versions of all installed tokenlists")
 @click.option("--search", default="")
@@ -76,7 +79,7 @@ def list_tokens(search, tokenlist_name, chain_id):
         lambda t: pattern.match(t.symbol),
         manager.get_tokens(tokenlist_name, chain_id),
     ):
-        click.echo("{address} ({symbol})".format(**token_info.to_dict()))
+        click.echo("{address} ({symbol})".format(**token_info.dict()))
 
 
 @cli.command(short_help="Display the info for a particular token")
@@ -91,6 +94,10 @@ def token_info(symbol, tokenlist_name, chain_id, case_insensitive):
         raise click.ClickException("No tokenlists available!")
 
     token_info = manager.get_token_info(symbol, tokenlist_name, chain_id, case_insensitive)
+    token_info = token_info.dict()
+
+    if "tags" not in token_info:
+        token_info["tags"] = ""
 
     click.echo(
         """
@@ -101,6 +108,6 @@ def token_info(symbol, tokenlist_name, chain_id, case_insensitive):
       Decimals: {decimals}
           Tags: {tags}
     """.format(
-        tags=[], **token_info.to_dict()
+        **token_info
     )
     )
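The `conftest.py` fixture above can patch `config.DEFAULT_CACHE_PATH` directly because, after this refactor, `TokenListManager` (next file) looks the path up at construction time instead of capturing it in a class-level default at import time. A toy sketch of why the late lookup matters (`config` here is a stand-in namespace, not the real `tokenlists.config`):

    from pathlib import Path
    from types import SimpleNamespace

    config = SimpleNamespace(DEFAULT_CACHE_PATH=Path.home() / ".tokenlists")

    class Manager:
        def __init__(self):
            # Late lookup: sees whatever the attribute is *now*, including a
            # monkeypatch.setattr(config, "DEFAULT_CACHE_PATH", tmp) applied
            # by a test fixture before construction
            self.cache_folder = config.DEFAULT_CACHE_PATH

    config.DEFAULT_CACHE_PATH = Path("/tmp/test-cache")  # what monkeypatch does
    assert Manager().cache_folder == Path("/tmp/test-cache")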
diff --git a/tokenlists/manager.py b/tokenlists/manager.py
index 55052e0..c0775ea 100644
--- a/tokenlists/manager.py
+++ b/tokenlists/manager.py
@@ -1,34 +1,34 @@
-import json
-from pathlib import Path
-from typing import Dict, Iterator, List, Optional
+from typing import Iterator, List, Optional
 
 import requests  # type: ignore
-from dataclassy import dataclass
 
 from tokenlists import config
 from tokenlists.typing import ChainId, TokenInfo, TokenList, TokenSymbol
 
 
-@dataclass
 class TokenListManager:
-    cache_folder: Path = config.DEFAULT_CACHE_PATH
-    installed_tokenlists: Dict[str, TokenList] = {}
-    default_tokenlist: Optional[str] = config.DEFAULT_TOKENLIST
-
-    def __post_init__(self):
+    def __init__(self):
         # NOTE: Folder should always exist, even if empty
+        self.cache_folder = config.DEFAULT_CACHE_PATH
         self.cache_folder.mkdir(exist_ok=True)
 
         # Load all the ones cached on disk
+        self.installed_tokenlists = {}
         for path in self.cache_folder.glob("*.json"):
-            with path.open() as fp:
-                tokenlist = TokenList.from_dict(json.load(fp))
-                self.installed_tokenlists[tokenlist.name] = tokenlist
+            tokenlist = TokenList.parse_file(path)
+            self.installed_tokenlists[tokenlist.name] = tokenlist
+
+        self.default_tokenlist = config.DEFAULT_TOKENLIST
+        if not self.default_tokenlist:
+            # Default might be cached on disk (does not override config)
+            default_tokenlist_cachefile = self.cache_folder.joinpath(".default")
+
+            if default_tokenlist_cachefile.exists():
+                self.default_tokenlist = default_tokenlist_cachefile.read_text()
 
-        # Default might be cached on disk (does not override config)
-        default_tokenlist_cachefile = self.cache_folder.joinpath(".default")
-        if not self.default_tokenlist and default_tokenlist_cachefile.exists():
-            self.default_tokenlist = default_tokenlist_cachefile.read_text()
+            elif len(self.installed_tokenlists) > 0:
+                # Not cached on disk, use first installed list
+                self.default_tokenlist = next(iter(self.installed_tokenlists))
 
     def install_tokenlist(self, uri: str):
         # This supports ENS lists
@@ -36,13 +36,13 @@ def install_tokenlist(self, uri: str):
             uri = config.UNISWAP_ENS_TOKENLISTS_HOST.format(uri)
 
         # Load and store the tokenlist
-        tokenlist = TokenList.from_dict(requests.get(uri).json())
+        tokenlist = TokenList.parse_obj(requests.get(uri).json())
         self.installed_tokenlists[tokenlist.name] = tokenlist
 
         # Cache it on disk for later instances
+        self.cache_folder.mkdir(exist_ok=True)
         token_list_file = self.cache_folder.joinpath(f"{tokenlist.name}.json")
-        with token_list_file.open("w") as fp:
-            json.dump(tokenlist.to_dict(), fp)
+        token_list_file.write_text(tokenlist.json())
 
     def remove_tokenlist(self, tokenlist_name: str) -> None:
         tokenlist = self.installed_tokenlists[tokenlist_name]
@@ -62,6 +62,7 @@ def set_default_tokenlist(self, name: str) -> None:
         self.default_tokenlist = name
 
         # Cache it on disk too
+        self.cache_folder.mkdir(exist_ok=True)
         self.cache_folder.joinpath(".default").write_text(name)
 
     def available_tokenlists(self) -> List[str]:
@@ -85,7 +86,7 @@ def get_tokens(
         chain_id: ChainId = 1,  # Ethereum Mainnet
     ) -> Iterator[TokenInfo]:
         tokenlist = self.get_tokenlist(token_listname)
-        return filter(lambda t: t.chainId == chain_id, iter(tokenlist))
+        return filter(lambda t: t.chainId == chain_id, tokenlist.tokens)
 
     def get_token_info(
         self,
@@ -96,7 +97,7 @@ def get_token_info(
     ) -> TokenInfo:
         tokenlist = self.get_tokenlist(token_listname)
 
-        token_iter = filter(lambda t: t.chainId == chain_id, iter(tokenlist))
+        token_iter = filter(lambda t: t.chainId == chain_id, tokenlist.tokens)
         token_iter = (
             filter(lambda t: t.symbol == symbol, token_iter)
             if case_insensitive
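For reference, the pydantic v1 one-liners that replace the old `json.load`/`from_dict`/`to_dict` plumbing in the manager; a sketch with a throwaway model (`Thing` is illustrative, not part of tokenlists):

    from pydantic import BaseModel

    class Thing(BaseModel):
        name: str

    thing = Thing.parse_obj({"name": "1inch"})  # dict -> model (was `from_dict`)
    text = thing.json()                         # model -> JSON string (was `json.dump(to_dict())`)
    assert Thing.parse_raw(text) == thing       # JSON string -> model
    # Thing.parse_file(path) does the same from a file on disk, which is
    # what the manager now uses when loading the cache folder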
diff --git a/tokenlists/typing.py b/tokenlists/typing.py
index 0ac0fd1..8636e2f 100644
--- a/tokenlists/typing.py
+++ b/tokenlists/typing.py
@@ -1,14 +1,15 @@
-from copy import deepcopy
-from datetime import datetime as DateTime
-from typing import Dict, Iterator, List, Optional, Union
+from datetime import datetime
+from itertools import chain
+from typing import Dict, List, Optional
 
-import dataclassy as dc
+from pydantic import AnyUrl
+from pydantic import BaseModel as _BaseModel
+from pydantic import validator
 from semantic_version import Version  # type: ignore
 
 ChainId = int
 TagId = str
-URI = str
 
 TokenAddress = str
 TokenName = str
@@ -16,77 +17,125 @@
 TokenSymbol = str
 
 
-@dc.dataclass(frozen=True, slots=True)
-class TokenInfo:
+class BaseModel(_BaseModel):
+    def dict(self, *args, **kwargs):
+        if "exclude_unset" not in kwargs:
+            kwargs["exclude_unset"] = True
+
+        return super().dict(*args, **kwargs)
+
+    class Config:
+        frozen = True
+
+
+class TokenInfo(BaseModel):
     chainId: ChainId
     address: TokenAddress
     name: TokenName
     decimals: TokenDecimals
     symbol: TokenSymbol
-    logoURI: Optional[str] = None
+    logoURI: Optional[AnyUrl] = None
     tags: Optional[List[TagId]] = None
-    extensions: Optional[Dict[str, Union[str, int, bool]]] = None
-
-    @classmethod
-    def from_dict(cls, data: Dict) -> "TokenInfo":
-        data = deepcopy(data)
-        return cls(**data)  # type: ignore
-
-    def to_dict(self) -> Dict:
-        data = dc.asdict(self)
-        if self.logoURI is None:
-            del data["logoURI"]
-        if self.tags is None:
-            del data["tags"]
-        if self.extensions is None:
-            del data["extensions"]
-        return data
+    extensions: Optional[dict] = None
+
+    @validator("address")
+    def address_must_be_hex(cls, v: str):
+        if not v.startswith("0x") or set(v[2:]) - set("0123456789abcdefABCDEF") or len(v) % 2 != 0:
+            raise ValueError("Address is not hex")
+
+        address_bytes = bytes.fromhex(v[2:])  # NOTE: Skip `0x`
+
+        if len(address_bytes) != 20:
+            raise ValueError("Address is incorrect length")
+
+        return v
+
+    @validator("decimals")
+    def decimals_must_be_uint8(cls, v: TokenDecimals):
+        if not (0 <= v < 256):
+            raise ValueError(f"Invalid token decimals: {v}")
+
+        return v
 
-class Timestamp(DateTime):
-    def __init__(self, timestamp: str):
-        super().fromisoformat(timestamp)
+    @validator("extensions")
+    def extensions_must_contain_simple_types(cls, d: Optional[dict]) -> Optional[dict]:
+        if not d:
+            return d
 
+        # `extensions` is really `Dict[str, Union[str, int, bool, None]]`, but
+        # pydantic would mutate the entries if we declared it that way
+        for val in d.values():
+            if not isinstance(val, (str, int, bool)) and val is not None:
+                raise ValueError(f"Incorrect extension field value: {val}")
 
-@dc.dataclass(frozen=True, slots=True)
-class Tag:
+        return d
+
+
+class Tag(BaseModel):
     name: str
     description: str
 
 
-# NOTE: Not frozen as we may need to dynamically modify this
-@dc.dataclass(slots=True)
-class TokenList:
+class TokenListVersion(BaseModel, Version):
+    major: int
+    minor: int
+    patch: int
+
+    @validator("*")
+    def no_negative_version_numbers(cls, v: int):
+        if v < 0:
+            raise ValueError("Invalid version number")
+
+        return v
+
+    # NOTE: These properties are just here so that we can use `Version` properly
+    @property
+    def prerelease(self):
+        return None
+
+    @property
+    def build(self):
+        return None
+
+    def __str__(self) -> str:
+        # NOTE: This custom string function is necessary because
+        # `super(Version, self).__str__()` isn't working right
+        return f"{self.major}.{self.minor}.{self.patch}"
+
+
+class TokenList(BaseModel):
     name: str
-    timestamp: Timestamp
-    version: Version
+    timestamp: datetime
+    version: TokenListVersion
     tokens: List[TokenInfo]
     keywords: Optional[List[str]] = None
    tags: Optional[Dict[TagId, Tag]] = None
-    logoURI: Optional[URI] = None
-
-    def __iter__(self) -> Iterator[TokenInfo]:
-        return iter(self.tokens)
-
-    @classmethod
-    def from_dict(cls, data: Dict) -> "TokenList":
-        data = deepcopy(data)
-        data["version"] = Version(**data["version"])
-        data["tokens"] = [TokenInfo.from_dict(t) for t in data["tokens"]]
-        return cls(**data)  # type: ignore
-
-    def to_dict(self) -> Dict:
-        data = dc.asdict(self)
-        data["version"] = {
-            "major": self.version.major,
-            "minor": self.version.minor,
-            "patch": self.version.patch,
-        }
-        data["tokens"] = [t.to_dict() for t in self.tokens]
-        if self.keywords is None:
-            del data["keywords"]
-        if self.tags is None:
-            del data["tags"]
-        if self.logoURI is None or self.logoURI == "":
-            del data["logoURI"]
+    logoURI: Optional[AnyUrl] = None
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Pull all the tags from all the tokens, reference or not
+        all_tags = chain.from_iterable(
+            list(token.tags) if token.tags else [] for token in self.tokens
+        )
+
+        # Obtain the set of all enumerated reference tags, e.g. "1", "2", etc.
+        token_ref_tags = set(tag for tag in set(all_tags) if tag.isdigit())
+
+        # Compare the enumerated reference tags from the tokens to the tag set in this class
+        tokenlist_tags = set(self.tags) if self.tags else set()
+
+        if not token_ref_tags <= tokenlist_tags:
+            # We have an enumerated reference tag used by a token
+            # that is missing from our set of tags here
+            raise ValueError(
+                f"Missing reference tags in tokenlist: {token_ref_tags - tokenlist_tags}"
+            )
+
+    class Config:
+        # NOTE: Not frozen as we may need to dynamically modify this
+        frozen = False
+
+    def dict(self, *args, **kwargs) -> dict:
+        data = super().dict(*args, **kwargs)
+
+        # NOTE: This was the easiest way to make sure this property returns isoformat
+        data["timestamp"] = self.timestamp.isoformat()
+
         return data
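One behavior worth calling out in the new `BaseModel`: defaulting to `exclude_unset=True` keeps optional fields that were absent from the input out of the `dict()` output, which is what lets the tests compare round-tripped dictionaries for exact equality. A minimal sketch with a toy model (not the real `TokenInfo`):

    from typing import Optional

    from pydantic import BaseModel as _BaseModel

    class Token(_BaseModel):
        address: str
        logoURI: Optional[str] = None

        def dict(self, *args, **kwargs):
            kwargs.setdefault("exclude_unset", True)
            return super().dict(*args, **kwargs)

    # `logoURI` was never provided, so it stays out of the output and the
    # parse/serialize round-trip reproduces the input exactly
    assert Token.parse_obj({"address": "0x00"}).dict() == {"address": "0x00"}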