Skip to content

Commit

Permalink
refactor: use internal variable for fork detection to prevent scope bug
Browse files Browse the repository at this point in the history
  • Loading branch information
fubuloubu committed Oct 27, 2024
1 parent a4efe87 commit b61ba37
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions silverback/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import defaultdict
from datetime import timedelta
from functools import wraps
from typing import Any, Callable
from typing import Any, Awaitable, Callable

from ape.api.networks import LOCAL_NETWORK_NAME
from ape.contracts import ContractEvent, ContractInstance
Expand Down Expand Up @@ -110,17 +110,18 @@ def __init__(self, settings: Settings | None = None):

provider_context = settings.get_provider_context()
# NOTE: This allows using connected ape methods e.g. `Contract`
provider = provider_context.__enter__()
self.provider = provider_context.__enter__()

self.identifier = SilverbackID(
name=settings.BOT_NAME,
network=provider.network.name,
ecosystem=provider.network.ecosystem.name,
network=self.provider.network.name,
ecosystem=self.provider.network.ecosystem.name,
)

# Adjust defaults from connection
if settings.NEW_BLOCK_TIMEOUT is None and (
provider.network.name.endswith("-fork") or provider.network.name == LOCAL_NETWORK_NAME
self.provider.network.name.endswith("-fork")
or self.provider.network.name == LOCAL_NETWORK_NAME
):
settings.NEW_BLOCK_TIMEOUT = int(timedelta(days=1).total_seconds())

Expand All @@ -140,7 +141,7 @@ def __init__(self, settings: Settings | None = None):

self.signer = settings.get_signer()
self.new_block_timeout = settings.NEW_BLOCK_TIMEOUT
self.use_fork = settings.FORK_MODE
self.use_fork = settings.FORK_MODE and not self.provider.network.name.endswith("-fork")

signer_str = f"\n SIGNER={repr(self.signer)}"
new_block_timeout_str = (
Expand All @@ -149,7 +150,9 @@ def __init__(self, settings: Settings | None = None):

network_choice = f"{self.identifier.ecosystem}:{self.identifier.network}"
logger.success(
f'Loaded Silverback Bot:\n NETWORK="{network_choice}"\n FORK_MODE={self.use_fork}'
"Loaded Silverback Bot:\n"
f' NETWORK="{network_choice}"\n'
f" FORK_MODE={self.use_fork}"
f"{signer_str}{new_block_timeout_str}"
)

Expand Down Expand Up @@ -228,6 +231,22 @@ async def __create_snapshot_handler(
last_block_processed=self.state.get("system:last_block_processed", -1),
)

def _with_fork_decorator(self, handler: Callable) -> Callable:
# Trigger worker-side handling using fork network by wrapping handler
fork_context = self.provider.network_manager.fork

@wraps(handler)
async def fork_handler(*args, **kwargs):
with fork_context():
result = handler(*args, **kwargs)

if inspect.isawaitable(result):
return await result

return result

return fork_handler

def broker_task_decorator(
self,
task_type: TaskType,
Expand Down Expand Up @@ -269,7 +288,9 @@ def broker_task_decorator(
raise ContainerTypeMismatchError(task_type, container)

# Register user function as task handler with our broker
def add_taskiq_task(handler: Callable) -> AsyncTaskiqDecoratedTask:
def add_taskiq_task(
handler: Callable[..., Any | Awaitable[Any]]
) -> AsyncTaskiqDecoratedTask:
labels = {"task_type": str(task_type)}

# NOTE: Do *not* do `if container` because that does a `len(container)` call,
Expand All @@ -279,7 +300,7 @@ def add_taskiq_task(handler: Callable) -> AsyncTaskiqDecoratedTask:
# Address is almost a certainty if the container is being used as a filter here.
if not (contract_address := getattr(container.contract, "address", None)):
raise InvalidContainerTypeError(
"Please provider a contract event from a valid contract instance."
"Please provide a contract event from a valid contract instance."
)

labels["contract_address"] = contract_address
Expand All @@ -288,20 +309,7 @@ def add_taskiq_task(handler: Callable) -> AsyncTaskiqDecoratedTask:
self.tasks[task_type].append(TaskData(name=handler.__name__, labels=labels))

if self.use_fork:
from ape import networks # NOTE: Defer import for load speed

# Trigger worker-side handling using fork network by wrapping handler
is_awaitable = inspect.isawaitable(handler)

@wraps(handler)
async def fork_handler(*args, **kwargs):
with networks.fork():
if is_awaitable:
return await handler(*args, **kwargs)
else:
return handler(*args, **kwargs)

handler = fork_handler
handler = self._with_fork_decorator(handler)

return self.broker.register_task(
handler,
Expand Down

0 comments on commit b61ba37

Please sign in to comment.