Skip to content

Commit

Permalink
fix: issue where call coverage didn't increment (#2105)
Browse files Browse the repository at this point in the history
  • Loading branch information
antazoey authored Jun 1, 2024
1 parent c893851 commit 5665aba
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 22 deletions.
47 changes: 32 additions & 15 deletions src/ape/pytest/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ethpm_types.source import ContractSource

from ape.logging import logger
from ape.managers import ProjectManager
from ape.pytest.config import ConfigWrapper
from ape.types import (
ContractFunctionPath,
Expand All @@ -22,8 +23,8 @@


class CoverageData(ManagerAccessMixin):
def __init__(self, base_path: Path, sources: Iterable[ContractSource]):
self.base_path = base_path
def __init__(self, project: ProjectManager, sources: Iterable[ContractSource]):
self.project = project
self.sources = list(sources)
self._report: Optional[CoverageReport] = None
self._init_coverage_profile() # Inits self._report.
Expand All @@ -43,7 +44,7 @@ def _init_coverage_profile(
self,
) -> CoverageReport:
# source_id -> pc(s) -> times hit
project_coverage = CoverageProject(name=self.config_manager.name or "__local__")
project_coverage = CoverageProject(name=self.project.name or "__local__")

for src in self.sources:
source_cov = project_coverage.include(src)
Expand All @@ -60,11 +61,11 @@ def _init_coverage_profile(
timestamp = get_current_timestamp_ms()
report = CoverageReport(
projects=[project_coverage],
source_folders=[self.local_project.contracts_folder],
source_folders=[self.project.contracts_folder],
timestamp=timestamp,
)

# Remove emptys.
# Remove empties.
for project in report.projects:
project.sources = [x for x in project.sources if len(x.statements) > 0]

Expand All @@ -74,7 +75,11 @@ def _init_coverage_profile(
def cover(
self, src_path: Path, pcs: Iterable[int], inc_fn_hits: bool = True
) -> tuple[set[int], list[str]]:
source_id = str(get_relative_path(src_path.absolute(), self.base_path))
if hasattr(self.project, "path"):
source_id = str(get_relative_path(src_path.absolute(), self.project.path))
else:
source_id = str(src_path)

if source_id not in self.report.sources:
# The source is not tracked for coverage.
return set(), []
Expand Down Expand Up @@ -120,14 +125,27 @@ def cover(


class CoverageTracker(ManagerAccessMixin):
def __init__(self, config_wrapper: ConfigWrapper):
def __init__(
self,
config_wrapper: ConfigWrapper,
project: Optional[ProjectManager] = None,
output_path: Optional[Path] = None,
):
self.config_wrapper = config_wrapper
sources = self.local_project._contract_sources
self._project = project or self.local_project

if output_path:
self._output_path = output_path
elif hasattr(self._project, "manifest_path"):
# Local project.
self._output_path = self._project.manifest_path.parent
else:
self._output_path = Path.cwd()

sources = self._project._contract_sources

self.data: Optional[CoverageData] = (
CoverageData(self.local_project.path, sources)
if self.config_wrapper.track_coverage
else None
CoverageData(self._project, sources) if self.config_wrapper.track_coverage else None
)

@property
Expand Down Expand Up @@ -180,7 +198,7 @@ def cover(
for src in project.sources:
# NOTE: We will allow this check to skip if there is no source is the
# traceback. This helps increment methods that are missing from the source map.
path = self.local_project.contracts_folder / src.source_id
path = self._project.path / src.source_id
if source_path is not None and path != source_path:
continue

Expand Down Expand Up @@ -279,7 +297,6 @@ def show_session_coverage(self) -> bool:

# Reports are set in ape-config.yaml.
reports = self.config_wrapper.ape_test_config.coverage.reports
out_folder = self.local_project.manifest_path.parent
if reports.terminal:
verbose = (
reports.terminal.get("verbose", False)
Expand Down Expand Up @@ -308,9 +325,9 @@ def show_session_coverage(self) -> bool:
click.echo()

if self.config_wrapper.xml_coverage:
self.data.report.write_xml(out_folder)
self.data.report.write_xml(self._output_path)
if value := self.config_wrapper.html_coverage:
verbose = value.get("verbose", False) if isinstance(value, dict) else False
self.data.report.write_html(out_folder, verbose=verbose)
self.data.report.write_html(self._output_path, verbose=verbose)

return True
8 changes: 7 additions & 1 deletion tests/functional/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,8 @@ def mock_compiler(mocker):
mock.name = "mock"
mock.ext = ".__mock__"
mock.tracked_settings = []
mock.ast = None
mock.pcmap = None

def mock_compile(paths, project=None, settings=None):
settings = settings or {}
Expand All @@ -691,10 +693,14 @@ def mock_compile(paths, project=None, settings=None):
code = HexBytes(123).hex()
data = {
"contractName": name,
"abi": [],
"abi": mock.abi,
"deploymentBytecode": code,
"sourceId": f"{project.contracts_folder.name}/{path.name}",
}
if ast := mock.ast:
data["ast"] = ast
if pcmap := mock.pcmap:
data["pcmap"] = pcmap

# Check for mocked overrides
overrides = mock.overrides
Expand Down
89 changes: 83 additions & 6 deletions tests/functional/test_coverage.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from pathlib import Path

import pytest
from ethpm_types import MethodABI
from ethpm_types.source import ContractSource, Source

import ape
from ape.pytest.config import ConfigWrapper
from ape.pytest.coverage import CoverageData, CoverageTracker
from ape.types import SourceTraceback
from ape.types.coverage import (
ContractCoverage,
ContractSourceCoverage,
Expand Down Expand Up @@ -172,7 +175,7 @@ def contract_source(self, vyper_contract_type, src):

@pytest.fixture(scope="class")
def coverage_data(self, project, contract_source):
return CoverageData(project.path, (contract_source,))
return CoverageData(project, (contract_source,))

def test_report(self, coverage_data):
actual = coverage_data.report
Expand All @@ -184,13 +187,87 @@ class TestCoverageTracker:
def pytest_config(self, mocker):
return mocker.MagicMock()

@pytest.fixture(scope="class")
@pytest.fixture
def config_wrapper(self, pytest_config):
return ConfigWrapper(pytest_config)

def test_data(self, pytest_config):
tracker = CoverageTracker(pytest_config)
@pytest.fixture
def tracker(self, pytest_config):
return CoverageTracker(pytest_config)

def test_data(self, tracker):
assert tracker.data is not None
actual = tracker.data.base_path
expected = tracker.local_project.path
actual = tracker.data.project
expected = tracker.local_project
assert actual == expected

def test_cover(self, mocker, pytest_config, compilers, mock_compiler):
"""
Ensure coverage of a call works.
"""
filestem = "atest"
filename = f"{filestem}.__mock__"
fn_name = "_a_method"

# Set up the mock compiler.
mock_compiler.abi = [MethodABI(name=fn_name)]
mock_compiler.ast = {
"src": "0:112:0",
"name": filename,
"end_lineno": 7,
"lineno": 1,
"ast_type": "Module",
}
mock_compiler.pcmap = {"0": {"location": (1, 7, 1, 7)}}
mock_contract = mocker.MagicMock()
mock_contract.name = filename
mock_statement = mocker.MagicMock()
mock_statement.pcs = {20}
mock_statement.hit_count = 0
mock_function = mocker.MagicMock()
mock_function.name = fn_name
mock_function.statements = [mock_statement]
mock_contract.functions = [mock_function]
mock_contract.statements = [mock_statement]

def init_profile(source_cov, src):
source_cov.contracts = [mock_contract]

mock_compiler.init_coverage_profile.side_effect = init_profile

stmt = {"type": "dev: Cannot send ether to non-payable function", "pcs": [20]}
fn_name = "_a_method"
tb_data = {
"statements": [stmt],
"closure": {"name": fn_name, "full_name": f"{fn_name}()"},
"depth": 0,
}

with ape.Project.create_temporary_project() as tmp:
# Create a source file.
file = tmp.path / "contracts" / filename
file.parent.mkdir(exist_ok=True, parents=True)
file.write_text("testing")

# Ensure the TB refers to this source.
tb_data["source_path"] = f"{tmp.path}/contracts/{filename}"
call_tb = SourceTraceback.model_validate([tb_data])

try:
# Hack in our mock compiler.
_ = compilers.registered_compilers # Ensure cache is exists.
compilers.__dict__["registered_compilers"][mock_compiler.ext] = mock_compiler

# Ensure our coverage tracker is using our new tmp project w/ the new src
# as well is set _after_ our new compiler plugin is added.
tracker = CoverageTracker(pytest_config, project=tmp)

tracker.cover(call_tb, contract=filestem, function=f"{fn_name}()")
assert mock_statement.hit_count > 0

finally:
if (
"registered_compilers" in compilers.__dict__
and mock_compiler.ext in compilers.__dict__["registered_compilers"]
):
del compilers.__dict__["registered_compilers"][mock_compiler.ext]

0 comments on commit 5665aba

Please sign in to comment.