Skip to content

Commit

Permalink
chore: scopey fix
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Nov 5, 2024
1 parent 3e9fdb4 commit e365767
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 71 deletions.
96 changes: 31 additions & 65 deletions src/ape/pytest/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Iterable, Iterator, Mapping
from dataclasses import dataclass, field
from fnmatch import fnmatch
from functools import cached_property
from functools import cached_property, singledispatchmethod
from typing import TYPE_CHECKING, ClassVar, Optional

import pytest
Expand All @@ -13,6 +13,7 @@

from ape.exceptions import BlockNotFoundError, ChainError, ProviderNotConnectedError
from ape.logging import logger
from ape.pytest.utils import Scope
from ape.utils.basemodel import ManagerAccessMixin
from ape.utils.rpc import allow_disconnected

Expand All @@ -23,14 +24,13 @@
from ape.managers.networks import NetworkManager
from ape.managers.project import ProjectManager
from ape.pytest.config import ConfigWrapper
from ape.pytest.utils import Scope
from ape.types.vm import SnapshotID


@dataclass()
class FixtureRebase:
return_scope: "Scope"
invalid_fixtures: dict["Scope", list[str]]
return_scope: Scope
invalid_fixtures: dict[Scope, list[str]]


class FixtureManager(ManagerAccessMixin):
Expand Down Expand Up @@ -106,7 +106,7 @@ def cache_fixtures(self, item) -> "FixtureMap":

return fixture_map

def get_fixture_scope(self, fixture_name: str) -> Optional["Scope"]:
def get_fixture_scope(self, fixture_name: str) -> Optional[Scope]:
return self._fixture_name_to_info.get(fixture_name, {}).get("scope")

def is_stateful(self, name: str) -> Optional[bool]:
Expand Down Expand Up @@ -160,7 +160,7 @@ def add_fixture_info(self, name: str, **info):
def _get_cached_fixtures(self, nodeid: str) -> Optional["FixtureMap"]:
return self._nodeid_to_fixture_map.get(nodeid)

def rebase(self, scope: "Scope", fixtures: "FixtureMap"):
def rebase(self, scope: Scope, fixtures: "FixtureMap"):
if not (rebase := self._get_rebase(scope)):
# Rebase avoided: nothing would change.
return
Expand Down Expand Up @@ -199,7 +199,7 @@ def rebase(self, scope: "Scope", fixtures: "FixtureMap"):
log = f"{log} invalidated-fixtures='{', '.join(invalidated)}'"
self.isolation_manager._records.append(log)

def _get_rebase(self, scope: "Scope") -> Optional[FixtureRebase]:
def _get_rebase(self, scope: Scope) -> Optional[FixtureRebase]:
# Check for fixtures that are now invalid. For example, imagine a session
# fixture comes into play after the module snapshot has been set.
# Once we restore the module's state and move to the next module,
Expand Down Expand Up @@ -232,10 +232,8 @@ def _get_rebase(self, scope: "Scope") -> Optional[FixtureRebase]:
)


class FixtureMap(dict["Scope", list[str]]):
class FixtureMap(dict[Scope, list[str]]):
def __init__(self, item):
from ape.pytest.utils import Scope

self._item = item
self._parametrized_names: Optional[list[str]] = None
super().__init__(
Expand Down Expand Up @@ -269,8 +267,6 @@ def names(self) -> list[str]:
Outputs in correct order for item.fixturenames.
Also, injects isolation fixtures if needed.
"""
from ape.pytest.utils import Scope

result = []
for scope, ls in self.items():
# NOTE: For function scoped, we always add the isolation fixture.
Expand Down Expand Up @@ -317,21 +313,15 @@ def parametrized(self) -> dict[str, list]:
def _arg2fixturedefs(self) -> Mapping:
return self._item.session._fixturemanager._arg2fixturedefs

@singledispatchmethod
def __setitem__(self, key, value):
# NOTE: Not using singledispatchmethod because it requires
# types at runtime.
if isinstance(key, Scope):
self.__setitem_scope(key, value)
elif isinstance(key, str):
self.__setitem_str(key, value)
elif isinstance(key, int):
self.__setitem_int(key, value)

raise NotImplementedError(type(key))

@__setitem__.register
def __setitem_int(self, key: int, value: list[str]):
super().__setitem__(Scope(key), value)

@__setitem__.register
def __setitem_str(self, key: str, value: list[str]):
for scope in Scope:
if f"{scope}" == key:
Expand All @@ -340,38 +330,28 @@ def __setitem_str(self, key: str, value: list[str]):

raise KeyError(key)

def __setitem_scope(self, key: "Scope", value: list[str]):
@__setitem__.register
def __setitem_scope(self, key: Scope, value: list[str]):
super().__setitem__(key, value)

@singledispatchmethod
def __getitem__(self, key):
# NOTE: Not using singledispatchmethod because it requires
# types at runtime.
from ape.pytest.utils import Scope

if isinstance(key, Scope):
return self.__getitem_scope(key)
elif isinstance(key, str):
return self.__getitem_str(key)
elif isinstance(key, int):
return self.__getitem_int(key)

raise NotImplementedError(type(key))

@__getitem__.register
def __getitem_int(self, key: int) -> list[str]:
from ape.pytest.utils import Scope

return super().__getitem__(Scope(key))

@__getitem__.register
def __getitem_str(self, key: str) -> list[str]:
from ape.pytest.utils import Scope

for scope in Scope:
if f"{scope}" == key:
return super().__getitem__(scope)

raise KeyError(key)

def __getitem_scope(self, key: "Scope") -> list[str]:
@__getitem__.register
def __getitem_scope(self, key: Scope) -> list[str]:
return super().__getitem__(key)

def get_info(self, name: str) -> list:
Expand Down Expand Up @@ -485,32 +465,22 @@ def Contract(self):

@pytest.fixture(scope="session")
def _session_isolation(self) -> Iterator[None]:
from ape.pytest.utils import Scope

yield from self.isolation_manager.isolation(Scope.SESSION)

@pytest.fixture(scope="package")
def _package_isolation(self) -> Iterator[None]:
from ape.pytest.utils import Scope

yield from self.isolation_manager.isolation(Scope.PACKAGE)

@pytest.fixture(scope="module")
def _module_isolation(self) -> Iterator[None]:
from ape.pytest.utils import Scope

yield from self.isolation_manager.isolation(Scope.MODULE)

@pytest.fixture(scope="class")
def _class_isolation(self) -> Iterator[None]:
from ape.pytest.utils import Scope

yield from self.isolation_manager.isolation(Scope.CLASS)

@pytest.fixture(scope="function")
def _function_isolation(self) -> Iterator[None]:
from ape.pytest.utils import Scope

yield from self.isolation_manager.isolation(Scope.FUNCTION)


Expand All @@ -520,7 +490,7 @@ class Snapshot:
All the data necessary for accurately supporting isolation.
"""

scope: "Scope"
scope: Scope
"""Corresponds to fixture scope."""

identifier: Optional["SnapshotID"] = None
Expand All @@ -537,10 +507,8 @@ def append_fixtures(self, fixtures: Iterable[str]):
self.fixtures.append(fixture)


class SnapshotRegistry(dict["Scope", Snapshot]):
class SnapshotRegistry(dict[Scope, Snapshot]):
def __init__(self):
from ape.pytest.utils import Scope

super().__init__(
{
Scope.SESSION: Snapshot(Scope.SESSION),
Expand All @@ -551,22 +519,20 @@ def __init__(self):
}
)

def get_snapshot_id(self, scope: "Scope") -> Optional["SnapshotID"]:
def get_snapshot_id(self, scope: Scope) -> Optional["SnapshotID"]:
return self[scope].identifier

def set_snapshot_id(self, scope: "Scope", snapshot_id: "SnapshotID"):
def set_snapshot_id(self, scope: Scope, snapshot_id: "SnapshotID"):
self[scope].identifier = snapshot_id

def clear_snapshot_id(self, scope: "Scope"):
def clear_snapshot_id(self, scope: Scope):
self[scope].identifier = None

def next_snapshots(self, scope: "Scope") -> Iterator[Snapshot]:
from ape.pytest.utils import Scope

def next_snapshots(self, scope: Scope) -> Iterator[Snapshot]:
for scope_value in range(scope + 1, Scope.FUNCTION + 1):
yield self[scope_value] # type: ignore

def extend_fixtures(self, scope: "Scope", fixtures: Iterable[str]):
def extend_fixtures(self, scope: Scope, fixtures: Iterable[str]):
self[scope].fixtures.extend(fixtures)


Expand All @@ -587,16 +553,16 @@ def _track_transactions(self) -> bool:
and (self.config_wrapper.track_gas or self.config_wrapper.track_coverage)
)

def get_snapshot(self, scope: "Scope") -> Snapshot:
def get_snapshot(self, scope: Scope) -> Snapshot:
return self.snapshots[scope]

def extend_fixtures(self, scope: "Scope", fixtures: Iterable[str]):
def extend_fixtures(self, scope: Scope, fixtures: Iterable[str]):
self.snapshots.extend_fixtures(scope, fixtures)

def next_snapshots(self, scope: "Scope") -> Iterator[Snapshot]:
def next_snapshots(self, scope: Scope) -> Iterator[Snapshot]:
yield from self.snapshots.next_snapshots(scope)

def isolation(self, scope: "Scope") -> Iterator[None]:
def isolation(self, scope: Scope) -> Iterator[None]:
"""
Isolation logic used to implement isolation fixtures for each pytest scope.
When tracing support is available, will also assist in capturing receipts.
Expand All @@ -623,7 +589,7 @@ def isolation(self, scope: "Scope") -> Iterator[None]:

self.restore(scope)

def set_snapshot(self, scope: "Scope"):
def set_snapshot(self, scope: Scope):
# Also can be used to re-set snapshot.
if not self.supported:
return
Expand Down Expand Up @@ -654,7 +620,7 @@ def take_snapshot(self) -> Optional["SnapshotID"]:
return None

@allow_disconnected
def restore(self, scope: "Scope"):
def restore(self, scope: Scope):
snapshot_id = self.snapshots.get_snapshot_id(scope)
if snapshot_id is None:
return
Expand Down
10 changes: 4 additions & 6 deletions src/ape/pytest/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
import click
import pytest
from _pytest._code.code import Traceback as PytestTraceback
from _pytest.reports import TestReport
from rich import print as rich_print

from ape.exceptions import ConfigError, ProviderNotConnectedError
from ape.logging import LogLevel
from ape.pytest.utils import Scope
from ape.utils.basemodel import ManagerAccessMixin

if TYPE_CHECKING:
from _pytest.reports import TestReport

from ape.api.networks import ProviderContextManager
from ape.pytest.config import ConfigWrapper
from ape.pytest.coverage import CoverageTracker
Expand Down Expand Up @@ -171,8 +173,6 @@ def pytest_runtest_setup(self, item):
self._setup_isolation(item)

def _setup_isolation(self, item):
from ape.pytest.utils import Scope

fixtures = self.fixture_manager.get_fixtures(item)
for scope in (Scope.SESSION, Scope.PACKAGE, Scope.MODULE, Scope.CLASS):
if not (
Expand Down Expand Up @@ -273,8 +273,6 @@ def pytest_fixture_post_finalizer(self, fixturedef, request):
self.fixture_manager.add_fixture_info(fixture_name, teardown_block=block_number)

def _track_fixture_blocks(self, fixture_name: str) -> bool:
from ape.pytest.utils import Scope

if not self.fixture_manager.is_custom(fixture_name):
return False

Expand Down Expand Up @@ -312,7 +310,7 @@ def _connect(self):
self._provider_context.push_provider()
self._provider_is_connected = True

def pytest_runtest_logreport(self, report: TestReport):
def pytest_runtest_logreport(self, report: "TestReport"):
if self.config_wrapper.verbosity >= 3:
self.isolation_manager.show_records()

Expand Down

0 comments on commit e365767

Please sign in to comment.