Skip to content

Commit

Permalink
fix: issue where @network_option with ConnectedProviderCommand us…
Browse files Browse the repository at this point in the history
…ed different networks. (#1796)
  • Loading branch information
antazoey authored Dec 21, 2023
1 parent c255427 commit d80dcc5
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 62 deletions.
86 changes: 37 additions & 49 deletions src/ape/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@ def get_param_from_ctx(ctx: Context, param: str) -> Optional[Any]:

def parse_network(ctx: Context) -> ProviderContextManager:
interactive = get_param_from_ctx(ctx, "interactive")

# Handle if already parsed (as when using network-option)
if ctx.obj and "provider" in ctx.obj:
provider = ctx.obj["provider"]
return provider.network.use_provider(provider, disconnect_on_exit=not interactive)

provider = get_param_from_ctx(ctx, "network")
if provider is not None and isinstance(provider, ProviderAPI):
return provider.network.use_provider(provider, disconnect_on_exit=not interactive)
Expand Down Expand Up @@ -76,55 +82,37 @@ def parse_args(self, ctx: Context, args: List[str]) -> List[str]:
return super().parse_args(ctx, args)

def invoke(self, ctx: Context) -> Any:
with parse_network(ctx):
if self.callback is not None:
opt_name = "network"
param = ctx.params.pop(opt_name, None)
if param is None:
ecosystem = networks.default_ecosystem
network = ecosystem.default_network
# Use default
if default_provider := network.default_provider:
provider = default_provider
else:
# Unlikely to get here.
raise ValueError(
f"Missing default provider for network '{network.choice}'. "
f"Using 'ethereum:local:test'."
)

elif isinstance(param, ProviderAPI):
provider = param

elif isinstance(param, str):
# Is a choice str
provider = networks.parse_network_choice(param)._provider
else:
raise TypeError(f"Can't handle type of parameter '{param}'.")

valid_fields = ("ecosystem", "network", "provider")
requested_fields = [
x for x in inspect.signature(self.callback).parameters if x in valid_fields
]
if self._use_cls_types and requested_fields:
options = {
"ecosystem": provider.network.ecosystem,
"network": provider.network,
"provider": provider,
}
for name in requested_fields:
if (
name not in ctx.params
or ctx.params[name] is None
or isinstance(ctx.params[name], str)
):
ctx.params[name] = options[name]

elif not self._use_cls_types:
# Legacy behavior, but may have a purpose.
ctx.params[opt_name] = provider.network_choice

return ctx.invoke(self.callback, **ctx.params)
with parse_network(ctx) as provider:
if self.callback is None:
return

# Will be put back with correct value if needed.
# Else, causes issues.
ctx.params.pop("network", None)

valid_fields = ("ecosystem", "network", "provider")
requested_fields = [
x for x in inspect.signature(self.callback).parameters if x in valid_fields
]
if self._use_cls_types and requested_fields:
options = {
"ecosystem": provider.network.ecosystem,
"network": provider.network,
"provider": provider,
}
for name in requested_fields:
if (
name not in ctx.params
or ctx.params[name] is None
or isinstance(ctx.params[name], str)
):
ctx.params[name] = options[name]

elif not self._use_cls_types:
# Legacy behavior, but may have a purpose.
ctx.params["network"] = provider.network_choice

return ctx.invoke(self.callback, **ctx.params)


# TODO: 0.8 delete
Expand Down
24 changes: 22 additions & 2 deletions src/ape/cli/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from ape import networks, project
from ape.api import ProviderAPI
from ape.cli import ConnectedProviderCommand
from ape.cli.choices import (
_ACCOUNT_TYPE_FILTER,
AccountAliasPromptChoice,
Expand All @@ -22,7 +23,7 @@
_VERBOSITY_VALUES = ("--verbosity", "-v")


class ApeCliContextObject(ManagerAccessMixin):
class ApeCliContextObject(ManagerAccessMixin, dict):
"""
A ``click`` context object class. Use via :meth:`~ape.cli.options.ape_cli_context()`.
It provides common CLI utilities for ape, such as logging or
Expand All @@ -31,6 +32,7 @@ class ApeCliContextObject(ManagerAccessMixin):

def __init__(self):
self.logger = logger
super().__init__({})

@staticmethod
def abort(msg: str, base_error: Optional[Exception] = None) -> NoReturn:
Expand Down Expand Up @@ -232,11 +234,29 @@ def callback(ctx, param, value):
"provider": provider_obj,
}

# Set the actual values.
# Set the actual values in the callback.
for item in requested_network_objects:
instance = choice_classes[item]
ctx.params[item] = instance

if isinstance(ctx.command, ConnectedProviderCommand):
# Place all values, regardless of request in
# the context. This helps the Ape CLI backend.
if ctx.obj is None:
# Happens when using commands that don't use the
# Ape context or any context.
ctx.obj = {}

for choice, obj in choice_classes.items():
try:
ctx.obj[choice] = obj
except Exception:
# This would only happen if using an unusual context object.
raise Abort(
"Cannot use connected-provider command type(s) "
"with non key-settable context object."
)

# else: provider is None, meaning not connected intentionally.

return value if user_callback is None else user_callback(ctx, param, value)
Expand Down

This file was deleted.

27 changes: 17 additions & 10 deletions tests/functional/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ape.cli.commands import get_param_from_ctx, parse_network
from ape.exceptions import AccountsError
from ape.logging import logger
from tests.conftest import geth_process_test

OUTPUT_FORMAT = "__TEST__{0}:{1}:{2}_"
OTHER_OPTION_VALUE = "TEST_OTHER_OPTION"
Expand Down Expand Up @@ -513,15 +514,7 @@ def use_net(network, other):
def solo_other(other):
click.echo(other)

# Scenario: Option explicit (shouldn't matter)
@click.command(cls=ConnectedProviderCommand)
@network_option()
@other_option
def explicit_option(other):
click.echo(other)

@click.command(cls=ConnectedProviderCommand)
@network_option()
@click.argument("other_arg")
@other_option
def with_arg(other_arg, other, provider):
Expand All @@ -544,14 +537,28 @@ def run(cmd, extra_args=None):
result = run(solo_other)
assert "local" not in result.output, result.output

run(explicit_option)

argument = "_extra_"
result = run(with_arg, extra_args=[argument])
assert "test" in result.output
assert argument in result.output


@geth_process_test
def test_network_option_with_connected_provider_command(runner, geth_provider):
_ = geth_provider # Ensure already running, to avoid clashing later on.

@click.command(cls=ConnectedProviderCommand)
@network_option()
def cmd(provider):
click.echo(provider.name)

# NOTE: Must use a network that is not the default.
spec = ("--network", "ethereum:local:geth")
res = runner.invoke(cmd, spec, catch_exceptions=False)
assert res.exit_code == 0, res.output
assert "geth" in res.output


# TODO: Delete for 0.8.
def test_deprecated_network_bound_command(runner):
with pytest.warns(
Expand Down

0 comments on commit d80dcc5

Please sign in to comment.