Skip to content

Commit

Permalink
fix: scope order issue
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey committed Sep 12, 2024
1 parent dafed6a commit 67a1d25
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 131 deletions.
197 changes: 147 additions & 50 deletions src/ape/pytest/fixtures.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from collections.abc import Iterator
import inspect
from collections.abc import Iterable, Iterator
from dataclasses import dataclass, field
from fnmatch import fnmatch
from functools import cached_property
from typing import Optional
Expand All @@ -14,6 +16,7 @@
from ape.managers.networks import NetworkManager
from ape.managers.project import ProjectManager
from ape.pytest.config import ConfigWrapper
from ape.pytest.utils import Scope
from ape.types import SnapshotID
from ape.utils.basemodel import ManagerAccessMixin
from ape.utils.rpc import allow_disconnected
Expand All @@ -24,22 +27,11 @@ class PytestApeFixtures(ManagerAccessMixin):
# for fixtures, as they are used in output from the command
# `ape test -q --fixture` (`pytest -q --fixture`).

_supports_snapshot: bool = True
receipt_capture: "ReceiptCapture"

def __init__(self, config_wrapper: ConfigWrapper, receipt_capture: "ReceiptCapture"):
def __init__(self, config_wrapper: ConfigWrapper, isolation_manager: "IsolationManager"):
self.config_wrapper = config_wrapper
self.receipt_capture = receipt_capture

@cached_property
def _track_transactions(self) -> bool:
return (
self.network_manager.provider is not None
and self.provider.is_connected
and (self.config_wrapper.track_gas or self.config_wrapper.track_coverage)
)
self.isolation_manager = isolation_manager

@pytest.fixture(scope="session")
@pytest.fixture(scope=Scope.SESSION.value)
def accounts(self) -> list[TestAccountAPI]:
"""
A collection of pre-funded accounts.
Expand All @@ -53,48 +45,124 @@ def compilers(self):
"""
return self.compiler_manager

@pytest.fixture(scope="session")
@pytest.fixture(scope=Scope.SESSION.value)
def chain(self) -> ChainManager:
"""
Manipulate the blockchain, such as mine or change the pending timestamp.
"""
return self.chain_manager

@pytest.fixture(scope="session")
@pytest.fixture(scope=Scope.SESSION.value)
def networks(self) -> NetworkManager:
"""
Connect to other networks in your tests.
"""
return self.network_manager

@pytest.fixture(scope="session")
@pytest.fixture(scope=Scope.SESSION.value)
def project(self) -> ProjectManager:
"""
Access contract types and dependencies.
"""
return self.local_project

@pytest.fixture(scope="session")
@pytest.fixture(scope=Scope.SESSION.value)
def Contract(self):
"""
Instantiate a reference to an on-chain contract
using its address (same as ``ape.Contract``).
"""
return self.chain_manager.contracts.instance_at

def _isolation(self, request=None) -> Iterator[None]:
# isolation fixtures
@pytest.fixture(scope=Scope.SESSION.value)
def _session_isolation(self) -> Iterator[None]:
yield from self.isolation_manager.isolation(Scope.SESSION)

@pytest.fixture(scope=Scope.PACKAGE.value)
def _package_isolation(self) -> Iterator[None]:
yield from self.isolation_manager.isolation(Scope.PACKAGE)

@pytest.fixture(scope=Scope.MODULE.value)
def _module_isolation(self) -> Iterator[None]:
yield from self.isolation_manager.isolation(Scope.MODULE)

@pytest.fixture(scope=Scope.CLASS.value)
def _class_isolation(self) -> Iterator[None]:
yield from self.isolation_manager.isolation(Scope.CLASS)

@pytest.fixture(scope=Scope.FUNCTION.value)
def _function_isolation(self) -> Iterator[None]:
yield from self.isolation_manager.isolation(Scope.FUNCTION)


@dataclass
class Snapshot:
scope: "Scope" # Assuming 'Scope' is defined elsewhere
identifier: Optional[str] = None
fixtures: list = field(default_factory=list)


class IsolationManager(ManagerAccessMixin):
INVALID_KEY = "__invalid_snapshot__"

_supported: bool = True
_snapshot_registry: dict[Scope, Snapshot] = {
Scope.SESSION: Snapshot(Scope.SESSION),
Scope.PACKAGE: Snapshot(Scope.PACKAGE),
Scope.MODULE: Snapshot(Scope.MODULE),
Scope.CLASS: Snapshot(Scope.CLASS),
Scope.FUNCTION: Snapshot(Scope.FUNCTION),
}

def __init__(self, config_wrapper: ConfigWrapper, receipt_capture: "ReceiptCapture"):
self.config_wrapper = config_wrapper
self.receipt_capture = receipt_capture

@cached_property
def builtin_ape_fixtures(self) -> tuple[str, ...]:
return tuple(
[
n
for n, itm in inspect.getmembers(PytestApeFixtures)
if callable(itm) and not n.startswith("_")
]
)

@cached_property
def _track_transactions(self) -> bool:
return (
self.network_manager.provider is not None
and self.provider.is_connected
and (self.config_wrapper.track_gas or self.config_wrapper.track_coverage)
)

def update_fixtures(self, scope: Scope, fixtures: Iterable[str]):
snapshot = self._snapshot_registry[scope]
if not (
new_fixtures := [
p
for p in fixtures
if p not in snapshot.fixtures and p not in self.builtin_ape_fixtures
]
):
return

# If the snapshot is already set, we have to invalidate it.
# We need to replace the snapshot with one that happens after
# the new fixtures.
if snapshot is not None:
self._snapshot_registry[scope].identifier = self.INVALID_KEY

# Add or update peer-fixtures.
self._snapshot_registry[scope].fixtures.extend(new_fixtures)

def isolation(self, scope: Scope) -> Iterator[None]:
"""
Isolation logic used to implement isolation fixtures for each pytest scope.
When tracing support is available, will also assist in capturing receipts.
"""
snapshot_id = None

if self._supports_snapshot:
try:
snapshot_id = self._snapshot()
except BlockNotFoundError:
self._supports_snapshot = False

self._set_snapshot(scope)
if self._track_transactions:
did_yield = False
try:
Expand All @@ -106,33 +174,54 @@ def _isolation(self, request=None) -> Iterator[None]:
if not did_yield:
# Prevent double yielding.
yield

else:
yield

# NOTE: self._supports_snapshot may have gotten set to False
# NOTE: self._supported may have gotten set to False
# someplace else _after_ snapshotting succeeded.
if self._supports_snapshot and snapshot_id is not None:
self._restore(snapshot_id)
if not self._supported:
return

# isolation fixtures
@pytest.fixture(scope="session")
def _session_isolation(self, request) -> Iterator[None]:
yield from self._isolation(request)
self._restore(scope)

@pytest.fixture(scope="package")
def _package_isolation(self, request) -> Iterator[None]:
yield from self._isolation(request)
def _set_snapshot(self, scope: Scope):
# Also can be used to re-set snapshot.
if not self._supported:
return

@pytest.fixture(scope="module")
def _module_isolation(self, request) -> Iterator[None]:
yield from self._isolation(request)
# Here is something tricky: If a snapshot exists
# already at a lower-level, we must use that one.
# Like if a session comes in _after_ a module, have
# the session just use the module.
# Else, it falls apart.
snapshot_id = None
if scope is not Scope.FUNCTION:
lower_scopes: tuple[Scope, ...] = ()
if scope is Scope.SESSION:
lower_scopes = (Scope.PACKAGE, Scope.MODULE, Scope.CLASS, Scope.FUNCTION)
elif scope is Scope.PACKAGE:
lower_scopes = (Scope.MODULE, Scope.CLASS, Scope.FUNCTION)
elif scope is Scope.MODULE:
lower_scopes = (Scope.CLASS, Scope.FUNCTION)
elif scope is Scope.CLASS:
lower_scopes = (Scope.FUNCTION,)
for lower_scope in lower_scopes:
snapshot = self._snapshot_registry[lower_scope]
if snapshot.identifier is not None:
snapshot_id = snapshot.identifier
break

if snapshot_id is None:
try:
snapshot_id = self._take_snapshot()
except Exception:
self._supported = False

_class_isolation = pytest.fixture(_isolation, scope="class")
_function_isolation = pytest.fixture(_isolation, scope="function")
if snapshot_id is not None:
self._snapshot_registry[scope].identifier = snapshot_id

@allow_disconnected
def _snapshot(self) -> Optional[SnapshotID]:
def _take_snapshot(self) -> Optional[SnapshotID]:
try:
return self.chain_manager.snapshot()
except NotImplementedError:
Expand All @@ -141,14 +230,21 @@ def _snapshot(self) -> Optional[SnapshotID]:
"Tests will not be completely isolated."
)
# To avoid trying again
self._supports_snapshot = False
self._supported = False

return None

@allow_disconnected
def _restore(self, snapshot_id: SnapshotID):
if snapshot_id not in self.chain_manager._snapshots:
def _restore(self, scope: Scope):
snapshot_id = self._snapshot_registry[scope].identifier
if snapshot_id is None:
return

elif snapshot_id not in self.chain_manager._snapshots or snapshot_id == self.INVALID_KEY:
# Still clear out.
self._snapshot_registry[scope].identifier = None
return

try:
self.chain_manager.restore(snapshot_id)
except NotImplementedError:
Expand All @@ -157,11 +253,12 @@ def _restore(self, snapshot_id: SnapshotID):
"Tests will not be completely isolated."
)
# To avoid trying again
self._supports_snapshot = False
self._supported = False

self._snapshot_registry[scope].identifier = None


class ReceiptCapture(ManagerAccessMixin):
config_wrapper: ConfigWrapper
receipt_map: dict[str, dict[str, ReceiptAPI]] = {}
enter_blocks: list[int] = []

Expand Down
9 changes: 6 additions & 3 deletions src/ape/pytest/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from ape.exceptions import ConfigError
from ape.pytest.config import ConfigWrapper
from ape.pytest.coverage import CoverageTracker
from ape.pytest.fixtures import PytestApeFixtures, ReceiptCapture
from ape.pytest.fixtures import IsolationManager, PytestApeFixtures, ReceiptCapture
from ape.pytest.gas import GasTracker
from ape.pytest.runners import PytestApeRunner
from ape.utils.basemodel import ManagerAccessMixin
Expand Down Expand Up @@ -77,21 +77,24 @@ def is_module(v):

config_wrapper = ConfigWrapper(config)
receipt_capture = ReceiptCapture(config_wrapper)
isolation_manager = IsolationManager(config_wrapper, receipt_capture)
gas_tracker = GasTracker(config_wrapper)
coverage_tracker = CoverageTracker(config_wrapper)

# Enable verbose output if stdout capture is disabled
config.option.verbose = config.getoption("capture") == "no"

# Register the custom Ape test runner
runner = PytestApeRunner(config_wrapper, receipt_capture, gas_tracker, coverage_tracker)
runner = PytestApeRunner(
config_wrapper, isolation_manager, receipt_capture, gas_tracker, coverage_tracker
)
config.pluginmanager.register(runner, "ape-test")

# Inject runner for access to gas and coverage trackers.
ManagerAccessMixin._test_runner = runner

# Include custom fixtures for project, accounts etc.
fixtures = PytestApeFixtures(config_wrapper, receipt_capture)
fixtures = PytestApeFixtures(config_wrapper, isolation_manager)
config.pluginmanager.register(fixtures, "ape-fixtures")

# Add custom markers
Expand Down
Loading

0 comments on commit 67a1d25

Please sign in to comment.