diff --git a/src/ape/api/accounts.py b/src/ape/api/accounts.py index 331c11c6bb..a3510fdef2 100644 --- a/src/ape/api/accounts.py +++ b/src/ape/api/accounts.py @@ -1,5 +1,6 @@ import os from collections.abc import Iterator +from functools import cached_property from pathlib import Path from typing import TYPE_CHECKING, Any, Optional, Union @@ -23,7 +24,7 @@ ) from ape.logging import logger from ape.types import AddressType, MessageSignature, SignableMessage -from ape.utils import BaseInterfaceModel, abstractmethod +from ape.utils import BaseInterfaceModel, abstractmethod, raises_not_implemented if TYPE_CHECKING: from ape.contracts import ContractContainer, ContractInstance @@ -443,11 +444,11 @@ def accounts(self) -> Iterator[AccountAPI]: Iterator[:class:`~ape.api.accounts.AccountAPI`] """ - @property + @cached_property def data_folder(self) -> Path: """ The path to the account data files. - Defaults to ``$HOME/.ape/`` unless overriden. + Defaults to ``$HOME/.ape/`` unless overridden. """ path = self.config_manager.DATA_FOLDER / self.name path.mkdir(parents=True, exist_ok=True) @@ -573,25 +574,39 @@ class TestAccountContainerAPI(AccountContainerAPI): ``AccountContainerAPI`` directly. Then, they show up in the ``accounts`` test fixture. """ - @property + @cached_property def data_folder(self) -> Path: """ **NOTE**: Test account containers do not touch - persistant data. By default and unless overriden, + persistent data. By default and unless overridden, this property returns the path to ``/dev/null`` and it is not used for anything. """ - if os.name == "posix": - return Path("/dev/null") + return Path("/dev/null" if os.name == "posix" else "NUL") + + @raises_not_implemented + def get_test_account(self, index: int) -> "TestAccountAPI": # type: ignore[empty-body] + """ + Get the test account at the given index. - return Path("NUL") + Args: + index (int): The index of the test account. + + Returns: + :class:`~ape.api.accounts.TestAccountAPI` + """ @abstractmethod - def generate_account(self) -> "TestAccountAPI": + def generate_account(self, index: Optional[int] = None) -> "TestAccountAPI": """ Generate a new test account. """ + def reset(self): + """ + Reset the account container to an original state. + """ + class TestAccountAPI(AccountAPI): """ diff --git a/src/ape/api/providers.py b/src/ape/api/providers.py index de2e5a9cf2..cc2473e753 100644 --- a/src/ape/api/providers.py +++ b/src/ape/api/providers.py @@ -12,7 +12,7 @@ from pathlib import Path from signal import SIGINT, SIGTERM, signal from subprocess import DEVNULL, PIPE, Popen -from typing import Any, Optional, Union, cast +from typing import TYPE_CHECKING, Any, Optional, Union, cast from eth_pydantic_types import HexBytes from ethpm_types.abi import EventABI @@ -42,6 +42,9 @@ raises_not_implemented, ) +if TYPE_CHECKING: + from ape.api.accounts import TestAccountAPI + class BlockAPI(BaseInterfaceModel): """ @@ -659,6 +662,18 @@ def set_balance(self, address: AddressType, amount: int): amount (int): The balance to set in the address. """ + @raises_not_implemented + def get_test_account(self, index: int) -> "TestAccountAPI": # type: ignore[empty-body] + """ + Retrieve one of the provider-generated test accounts. + + Args: + index (int): The index of the test account in the HD-Path. + + Returns: + :class:`~ape.api.accounts.TestAccountAPI` + """ + @log_instead_of_fail(default="") def __repr__(self) -> str: return f"<{self.name.capitalize()} chain_id={self.chain_id}>" diff --git a/src/ape/managers/accounts.py b/src/ape/managers/accounts.py index f01a668333..94e1c78afc 100644 --- a/src/ape/managers/accounts.py +++ b/src/ape/managers/accounts.py @@ -35,6 +35,7 @@ class TestAccountManager(list, ManagerAccessMixin): __test__ = False _impersonated_accounts: dict[AddressType, ImpersonatedAccount] = {} + _accounts_by_index: dict[int, AccountAPI] = {} @log_instead_of_fail(default="") def __repr__(self) -> str: @@ -43,14 +44,13 @@ def __repr__(self) -> str: @cached_property def containers(self) -> dict[str, TestAccountContainerAPI]: - containers = {} - account_types = [ - t for t in self.plugin_manager.account_types if issubclass(t[1][1], TestAccountAPI) - ] - for plugin_name, (container_type, account_type) in account_types: - containers[plugin_name] = container_type(name=plugin_name, account_type=account_type) - - return containers + account_types = filter( + lambda t: issubclass(t[1][1], TestAccountAPI), self.plugin_manager.account_types + ) + return { + plugin_name: container_type(name=plugin_name, account_type=account_type) + for plugin_name, (container_type, account_type) in account_types + } @property def accounts(self) -> Iterator[AccountAPI]: @@ -63,7 +63,7 @@ def aliases(self) -> Iterator[str]: yield account.alias def __len__(self) -> int: - return len(list(self.accounts)) + return sum(len(c) for c in self.containers.values()) def __iter__(self) -> Iterator[AccountAPI]: yield from self.accounts @@ -74,13 +74,16 @@ def __getitem__(self, account_id): @__getitem__.register def __getitem_int(self, account_id: int): + if account_id in self._accounts_by_index: + return self._accounts_by_index[account_id] + + original_account_id = account_id if account_id < 0: account_id = len(self) + account_id - for idx, account in enumerate(self.accounts): - if account_id == idx: - return account - raise IndexError(f"No account at index '{account_id}'.") + account = self.containers["test"].get_test_account(account_id) + self._accounts_by_index[original_account_id] = account + return account @__getitem__.register def __getitem_slice(self, account_id: slice): @@ -136,6 +139,19 @@ def use_sender(self, account_id: Union[TestAccountAPI, AddressType, int]) -> Con account = account_id if isinstance(account_id, TestAccountAPI) else self[account_id] return _use_sender(account) + def init_test_account( + self, index: int, address: AddressType, private_key: str + ) -> "TestAccountAPI": + container = self.containers["test"] + return container.init_test_account( # type: ignore[attr-defined] + index, address, private_key + ) + + def reset(self): + self._accounts_by_index = {} + for container in self.containers.values(): + container.reset() + class AccountManager(BaseManager): """ @@ -168,7 +184,6 @@ def containers(self) -> dict[str, AccountContainerAPI]: Returns: dict[str, :class:`~ape.api.accounts.AccountContainerAPI`] """ - containers = {} data_folder = self.config_manager.DATA_FOLDER data_folder.mkdir(exist_ok=True) @@ -217,7 +232,6 @@ def __len__(self) -> int: Returns: int """ - return sum(len(container) for container in self.containers.values()) def __iter__(self) -> Iterator[AccountAPI]: @@ -291,7 +305,6 @@ def __getitem_int(self, account_id: int) -> AccountAPI: Returns: :class:`~ape.api.accounts.AccountAPI` """ - if account_id < 0: account_id = len(self) + account_id for idx, account in enumerate(self): @@ -366,7 +379,6 @@ def __contains__(self, address: AddressType) -> bool: Returns: bool: ``True`` when the given address is found. """ - return ( any(address in container for container in self.containers.values()) or address in self.test_accounts @@ -381,6 +393,14 @@ def use_sender( account = self[account_id] elif isinstance(account_id, str): # alias account = self.load(account_id) + else: + raise TypeError(account_id) else: account = account_id + return _use_sender(account) + + def init_test_account( + self, index: int, address: AddressType, private_key: str + ) -> "TestAccountAPI": + return self.test_accounts.init_test_account(index, address, private_key) diff --git a/src/ape/managers/project.py b/src/ape/managers/project.py index e2c87ecba0..0982978b7c 100644 --- a/src/ape/managers/project.py +++ b/src/ape/managers/project.py @@ -1887,6 +1887,8 @@ def reconfigure(self, **overrides): self._config_override = overrides _ = self.config + self.account_manager.test_accounts.reset() + def extract_manifest(self) -> PackageManifest: return self.manifest diff --git a/src/ape/utils/testing.py b/src/ape/utils/testing.py index 371b6e97ac..63f9099d00 100644 --- a/src/ape/utils/testing.py +++ b/src/ape/utils/testing.py @@ -48,17 +48,19 @@ def generate_dev_accounts( list[:class:`~ape.utils.GeneratedDevAccount`]: List of development accounts. """ seed = Mnemonic.to_seed(mnemonic) - accounts = [] + hd_path_format = ( + hd_path if "{}" in hd_path or "{0}" in hd_path else f"{hd_path.rstrip('/')}/{{}}" + ) + return [ + _generate_dev_account(hd_path_format, i, seed) + for i in range(start_index, start_index + number_of_accounts) + ] - if "{}" in hd_path or "{0}" in hd_path: - hd_path_format = hd_path - else: - hd_path_format = f"{hd_path.rstrip('/')}/{{}}" - for i in range(start_index, start_index + number_of_accounts): - hd_path_obj = HDPath(hd_path_format.format(i)) - private_key = HexBytes(hd_path_obj.derive(seed)).hex() - address = Account.from_key(private_key).address - accounts.append(GeneratedDevAccount(address, private_key)) - - return accounts +def _generate_dev_account(hd_path, index: int, seed: bytes) -> GeneratedDevAccount: + return GeneratedDevAccount( + address=Account.from_key( + private_key := HexBytes(HDPath(hd_path.format(index)).derive(seed)).hex() + ).address, + private_key=private_key, + ) diff --git a/src/ape_node/provider.py b/src/ape_node/provider.py index 00a36db6a9..2dc19775dc 100644 --- a/src/ape_node/provider.py +++ b/src/ape_node/provider.py @@ -18,7 +18,7 @@ from web3.middleware import geth_poa_middleware from yarl import URL -from ape.api import PluginConfig, SubprocessProvider, TestProviderAPI +from ape.api import PluginConfig, SubprocessProvider, TestAccountAPI, TestProviderAPI from ape.logging import LogLevel, logger from ape.types import SnapshotID from ape.utils.misc import ZERO_ADDRESS, log_instead_of_fail, raises_not_implemented @@ -130,10 +130,10 @@ def __init__( geth_kwargs["dev_mode"] = True hd_path = hd_path or DEFAULT_TEST_HD_PATH - accounts = generate_dev_accounts( + self._dev_accounts = generate_dev_accounts( mnemonic, number_of_accounts=number_of_accounts, hd_path=hd_path ) - addresses = [a.address for a in accounts] + addresses = [a.address for a in self._dev_accounts] addresses.extend(extra_funded_accounts or []) bal_dict = {"balance": str(initial_balance)} alloc = {a: bal_dict for a in addresses} @@ -418,6 +418,17 @@ def mine(self, num_blocks: int = 1): def build_command(self) -> list[str]: return self._process.command if self._process else [] + def get_test_account(self, index: int) -> "TestAccountAPI": + if self._process is None: + # Not managing the process. Use default approach. + test_container = self.account_manager.test_accounts.containers["test"] + return test_container.generate_account(index) + + # perf: we avoid having to generate account keys twice by utilizing + # the accounts generated for geth's genesis.json. + account = self._process._dev_accounts[index] + return self.account_manager.init_test_account(index, account.address, account.private_key) + # NOTE: The default behavior of EthereumNodeBehavior assumes geth. class Node(EthereumNodeProvider): diff --git a/src/ape_test/accounts.py b/src/ape_test/accounts.py index 6ea8f84da0..307b258026 100644 --- a/src/ape_test/accounts.py +++ b/src/ape_test/accounts.py @@ -1,6 +1,6 @@ import warnings from collections.abc import Iterator -from typing import Any, Optional +from typing import Any, Optional, cast from eip712.messages import EIP712Message from eth_account import Account as EthAccount @@ -9,85 +9,84 @@ from eth_utils import to_bytes from ape.api import TestAccountAPI, TestAccountContainerAPI, TransactionAPI -from ape.exceptions import SignatureError +from ape.exceptions import ProviderNotConnectedError, SignatureError from ape.types import AddressType, MessageSignature, TransactionSignature -from ape.utils import GeneratedDevAccount, generate_dev_accounts +from ape.utils import ( + DEFAULT_NUMBER_OF_TEST_ACCOUNTS, + DEFAULT_TEST_HD_PATH, + DEFAULT_TEST_MNEMONIC, + generate_dev_accounts, +) class TestAccountContainer(TestAccountContainerAPI): - num_generated: int = 0 - mnemonic: str = "" - num_of_accounts: int = 0 - hd_path: str = "" - _accounts: list["TestAccount"] = [] + generated_accounts: list["TestAccount"] = [] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.init() - - def init(self): - self.mnemonic = self.config["mnemonic"] - self.num_of_accounts = self.config["number_of_accounts"] - self.hd_path = self.config["hd_path"] - self._accounts = [] - - for index, account in enumerate(self._dev_accounts): - self._accounts.append( - TestAccount( - index=index, address_str=account.address, private_key=account.private_key - ) - ) def __len__(self) -> int: - return self.num_of_accounts + return self.number_of_accounts + len(self.generated_accounts) @property def config(self): return self.config_manager.get_config("test") @property - def _dev_accounts(self) -> list[GeneratedDevAccount]: - return generate_dev_accounts( - self.mnemonic, - number_of_accounts=self.num_of_accounts, - hd_path=self.hd_path, - ) + def mnemonic(self) -> str: + return self.config.get("mnemonic", DEFAULT_TEST_MNEMONIC) @property - def aliases(self) -> Iterator[str]: - for index in range(self.num_of_accounts): - yield f"TEST::{index}" + def number_of_accounts(self) -> int: + return self.config.get("number_of_accounts", DEFAULT_NUMBER_OF_TEST_ACCOUNTS) @property - def _is_config_changed(self): - current_mnemonic = self.config["mnemonic"] - current_number = self.config["number_of_accounts"] - current_hd_path = self.config["hd_path"] - return ( - self.mnemonic != current_mnemonic - or self.num_of_accounts != current_number - or self.hd_path != current_hd_path - ) + def hd_path(self) -> str: + return self.config.get("hd_path", DEFAULT_TEST_HD_PATH) + + @property + def aliases(self) -> Iterator[str]: + for index in range(self.number_of_accounts): + yield f"TEST::{index}" @property def accounts(self) -> Iterator["TestAccount"]: - # As TestAccountManager only uses accounts property this works! - if self._is_config_changed: - self.init() - yield from self._accounts - - def generate_account(self) -> "TestAccountAPI": - new_index = self.num_of_accounts + self.num_generated - self.num_generated += 1 + for index in range(self.number_of_accounts): + yield cast(TestAccount, self.get_test_account(index)) + + def get_test_account(self, index: int) -> TestAccountAPI: + if index >= self.number_of_accounts: + new_index = index - self.number_of_accounts + return self.generated_accounts[new_index] + + try: + return self.provider.get_test_account(index) + except (NotImplementedError, ProviderNotConnectedError): + return self.generate_account(index=index) + + def generate_account(self, index: Optional[int] = None) -> "TestAccountAPI": + new_index = ( + self.number_of_accounts + len(self.generated_accounts) if index is None else index + ) generated_account = generate_dev_accounts( self.mnemonic, 1, hd_path=self.hd_path, start_index=new_index )[0] - acc = TestAccount( - index=new_index, - address_str=generated_account.address, - private_key=generated_account.private_key, + account = self.init_test_account( + new_index, generated_account.address, generated_account.private_key ) - return acc + self.generated_accounts.append(account) + return account + + @classmethod + def init_test_account(cls, index: int, address: AddressType, private_key: str) -> "TestAccount": + return TestAccount( + index=index, + address_str=address, + private_key=private_key, + ) + + def reset(self): + self.generated_accounts = [] class TestAccount(TestAccountAPI): diff --git a/src/ape_test/provider.py b/src/ape_test/provider.py index f69060f799..baf13597a5 100644 --- a/src/ape_test/provider.py +++ b/src/ape_test/provider.py @@ -3,7 +3,7 @@ from collections.abc import Iterator from functools import cached_property from re import Pattern -from typing import Any, Optional, cast +from typing import TYPE_CHECKING, Any, Optional, cast from eth.exceptions import HeaderNotFound from eth_pydantic_types import HexBytes @@ -27,10 +27,13 @@ VirtualMachineError, ) from ape.logging import logger -from ape.types import BlockID, ContractLog, LogFilter, SnapshotID +from ape.types import AddressType, BlockID, ContractLog, LogFilter, SnapshotID from ape.utils import DEFAULT_TEST_CHAIN_ID, DEFAULT_TEST_HD_PATH, gas_estimation_error_message from ape_ethereum.provider import Web3Provider +if TYPE_CHECKING: + from ape.api.accounts import TestAccountAPI + class EthTesterProviderConfig(PluginConfig): chain_id: int = DEFAULT_TEST_CHAIN_ID @@ -334,6 +337,23 @@ def get_contract_logs(self, log_filter: LogFilter) -> Iterator[ContractLog]: ) yield from self.network.ecosystem.decode_logs(log_gen, *log_filter.events) + def get_test_account(self, index: int) -> "TestAccountAPI": + # NOTE: No need to cache here because it happens at the TestAccountManager already. + try: + private_key = self.evm_backend.account_keys[index] + except IndexError as err: + raise IndexError(f"No account at index '{index}'") from err + + address = private_key.public_key.to_canonical_address() + return self.account_manager.init_test_account( + index, + cast(AddressType, f"0x{address.hex()}"), + private_key.to_hex(), + ) + + def add_account(self, private_key: str): + self.evm_backend.add_account(private_key) + def get_virtual_machine_error(self, exception: Exception, **kwargs) -> VirtualMachineError: if isinstance(exception, ValidationError): match = self._CANNOT_AFFORD_GAS_PATTERN.match(str(exception)) diff --git a/tests/functional/test_accounts.py b/tests/functional/test_accounts.py index 4072a73c1a..96728176ed 100644 --- a/tests/functional/test_accounts.py +++ b/tests/functional/test_accounts.py @@ -368,7 +368,10 @@ def test_accounts_splice_access(test_accounts): assert b == test_accounts[1] c = test_accounts[-1] assert c == test_accounts[len(test_accounts) - 1] - assert len(test_accounts[::2]) == len(test_accounts) / 2 + expected = ( + (len(test_accounts) // 2) if len(test_accounts) % 2 == 0 else (len(test_accounts) // 2 + 1) + ) + assert len(test_accounts[::2]) == expected def test_accounts_address_access(owner, accounts): @@ -576,15 +579,15 @@ def test_account_comparison_to_non_account(core_account): def test_create_account(test_accounts): length_at_start = len(test_accounts) - created_acc = test_accounts.generate_test_account() + created_account = test_accounts.generate_test_account() - assert isinstance(created_acc, TestAccount) - assert created_acc.index == length_at_start + assert isinstance(created_account, TestAccount) + assert created_account.index == length_at_start - second_created_acc = test_accounts.generate_test_account() + second_created_account = test_accounts.generate_test_account() - assert created_acc.address != second_created_acc.address - assert second_created_acc.index == created_acc.index + 1 + assert created_account.address != second_created_account.address + assert second_created_account.index == created_account.index + 1 def test_dir(core_account): @@ -610,35 +613,43 @@ def test_is_not_contract(owner, keyfile_account): assert not keyfile_account.is_contract -def test_using_different_hd_path(test_accounts, project): +def test_using_different_hd_path(test_accounts, project, eth_tester_provider): test_config = { "test": { - "hd_path": "m/44'/60'/0'/{}", + "hd_path": "m/44'/60'/0/0", } } - old_first_account = test_accounts[0] + old_address = test_accounts[0].address + original_settings = eth_tester_provider.settings.model_dump(by_alias=True) with project.temp_config(**test_config): - new_first_account = test_accounts[0] - assert old_first_account.address != new_first_account.address + eth_tester_provider.update_settings(test_config["test"]) + new_address = test_accounts[0].address + eth_tester_provider.update_settings(original_settings) + assert old_address != new_address -def test_using_random_mnemonic(test_accounts, project): - test_config = { - "test": { - "mnemonic": "test_mnemonic_for_ape", - } - } - old_first_account = test_accounts[0] +def test_using_random_mnemonic(test_accounts, project, eth_tester_provider): + mnemonic = "candy maple cake sugar pudding cream honey rich smooth crumble sweet treat" + test_config = {"test": {"mnemonic": mnemonic}} + + old_address = test_accounts[0].address + original_settings = eth_tester_provider.settings.model_dump(by_alias=True) with project.temp_config(**test_config): - new_first_account = test_accounts[0] - assert old_first_account.address != new_first_account.address + eth_tester_provider.update_settings(test_config["test"]) + new_address = test_accounts[0].address + + eth_tester_provider.update_settings(original_settings) + assert old_address != new_address def test_iter_test_accounts(test_accounts): - actual = list(iter(test_accounts)) - assert len(actual) == len(test_accounts) + test_accounts.reset() + accounts = list(iter(test_accounts)) + actual = len(accounts) + expected = len(test_accounts) + assert actual == expected def test_declare(contract_container, sender):