Skip to content

Commit

Permalink
Only create checkpoints if a position reaches maturity on this checkp…
Browse files Browse the repository at this point in the history
…oint time (#1696)
  • Loading branch information
slundqui authored Oct 2, 2024
1 parent 3ab7d29 commit bdd1cad
Show file tree
Hide file tree
Showing 6 changed files with 157 additions and 134 deletions.
92 changes: 32 additions & 60 deletions scripts/checkpoint_bots.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,15 @@

from eth_account.account import Account
from eth_account.signers.local import LocalAccount
from eth_typing import ChecksumAddress
from fixedpointmath import FixedPoint
from hyperdrivetypes import IHyperdriveContract
from web3 import Web3
from web3.types import Nonce

from agent0 import Chain, Hyperdrive
from agent0.core.base.make_key import make_private_key
from agent0.ethpy.base import get_account_balance, smart_contract_preview_transaction, smart_contract_transact
from agent0.ethpy.hyperdrive import get_hyperdrive_pool_config, get_hyperdrive_registry_from_artifacts
from agent0.ethpy.hyperdrive.interface._event_logs import EARLIEST_BLOCK_LOOKUP
from agent0.ethpy.base import get_account_balance
from agent0.ethpy.hyperdrive import get_hyperdrive_registry_from_artifacts
from agent0.hyperlogs.rollbar_utilities import initialize_rollbar, log_rollbar_exception, log_rollbar_message

# Checkpoint bot has a lot going on
Expand Down Expand Up @@ -119,9 +117,8 @@ def does_checkpoint_exist(hyperdrive_contract: IHyperdriveContract, checkpoint_t

async def run_checkpoint_bot(
chain: Chain,
pool_address: ChecksumAddress,
pool: Hyperdrive,
sender: LocalAccount,
pool_name: str,
block_time: int = 1,
block_timestamp_interval: int = 1,
check_checkpoint: bool = False,
Expand All @@ -134,12 +131,10 @@ async def run_checkpoint_bot(
---------
chain: Chain
The chain object.
pool_address: ChecksumAddress
The pool address.
pool: Hyperdrive
The Hyperdrive pool object.
sender: LocalAccount
The sender of the transaction.
pool_name: str
The name of the pool from `get_hyperdrive_addresses_from_registry`. Only used in logging.
block_time: int
The block time in seconds.
block_timestamp_interval: int
Expand All @@ -158,13 +153,12 @@ async def run_checkpoint_bot(
# TODO pull this function out and put into agent0
web3 = chain._web3 # pylint: disable=protected-access

hyperdrive_contract: IHyperdriveContract = IHyperdriveContract.factory(w3=web3)(pool_address)

# Run the checkpoint bot. This bot will attempt to mint a new checkpoint
# every checkpoint after a waiting period. It will poll very infrequently
# to reduce the probability of needing to mint a checkpoint.
config = get_hyperdrive_pool_config(hyperdrive_contract)
checkpoint_duration = config.checkpoint_duration
pool_state = pool.interface.get_hyperdrive_state()
checkpoint_duration = pool_state.pool_config.checkpoint_duration
hyperdrive_contract = pool.interface.hyperdrive_contract

# Rollbar assumes any number longer than 2 integers is "data" and groups them together.
# We want to ensure that the pool name is always in different groups, so we add
Expand All @@ -174,7 +168,7 @@ async def run_checkpoint_bot(
# `RETHHyperdrive_30day` -> `RETHHyperdrive_3_0_day`
# TODO this might be done better on the rollbar side with creating a grouping fingerprint
# TODO ERC4626 gets split up here, may want to only do this for the position duration string.
pool_name = "".join([c + "_" if c.isdigit() else c for c in pool_name])
pool_name = "".join([c + "_" if c.isdigit() else c for c in pool.name])

fail_count = 0

Expand All @@ -185,7 +179,7 @@ async def run_checkpoint_bot(
break

# We check for low funds in checkpoint bot
chain_id = chain._web3.eth.chain_id # pylint: disable=protected-access
chain_id = web3.eth.chain_id
checkpoint_bot_eth_balance = FixedPoint(scaled_value=get_account_balance(web3, sender.address))
if checkpoint_bot_eth_balance <= CHECKPOINT_BOT_LOW_ETH_THRESHOLD.get(
chain_id, DEFAULT_CHECKPOINT_BOT_LOW_ETH_THRESHOLD
Expand Down Expand Up @@ -217,25 +211,20 @@ async def run_checkpoint_bot(
logging.info(logging_str)

# Check to see if the pool is paused. We don't run checkpoint bots on this pool if it's paused.
pause_events = hyperdrive_contract.events.PauseStatusUpdated.get_logs(
from_block=EARLIEST_BLOCK_LOOKUP.get(chain_id, "earliest")
)
is_paused = False
if len(list(pause_events)) > 0:
# Get the latest pause event
# TODO get_logs likely returns events in an ordered
# fashion, but we iterate and find the latest one
# just in case
latest_pause_event = None
max_block_number = 0
for event in pause_events:
if event["blockNumber"] > max_block_number:
max_block_number = event["blockNumber"]
latest_pause_event = event
assert latest_pause_event is not None
is_paused = latest_pause_event["args"]["isPaused"]

if enough_time_has_elapsed and checkpoint_doesnt_exist and not is_paused:
is_paused = pool.interface.get_pool_is_paused()

# We look at the total supply of longs/shorts with a maturity time at this checkpoint
positions_matured_on_this_checkpoint = (
pool.interface.get_long_total_supply(checkpoint_time)
+ pool.interface.get_short_total_supply(checkpoint_time)
) > FixedPoint(0)

if (
enough_time_has_elapsed
and checkpoint_doesnt_exist
and positions_matured_on_this_checkpoint
and not is_paused
):
logging_str = f"Pool {pool_name} for {checkpoint_time=}: submitting checkpoint"
logging.info(logging_str)

Expand All @@ -249,24 +238,10 @@ async def run_checkpoint_bot(
# will need to make this more robust so that we retry this
# transaction if the transaction gets stuck.
try:
# 0 is the max iterations for distribute excess idle, where it will default to
# the default max iterations
fn_args = (checkpoint_time, 0)

# Try preview call
_ = smart_contract_preview_transaction(
hyperdrive_contract,
sender.address,
"checkpoint",
*fn_args,
)

receipt = smart_contract_transact(
web3,
hyperdrive_contract,
receipt = pool.interface.create_checkpoint(
sender,
"checkpoint",
*fn_args,
checkpoint_time,
preview=True,
nonce_func=partial(async_get_nonce, web3, sender),
)
# Reset fail count on successful transaction
Expand Down Expand Up @@ -298,7 +273,7 @@ async def run_checkpoint_bot(
continue
logging_str = (
f"{chain.name}: Pool {pool_name} for {checkpoint_time=}: "
f"Checkpoint successfully mined with transaction_hash={receipt['transactionHash'].hex()}"
f"Checkpoint successfully mined with transaction_hash={receipt.transaction_hash}"
)
logging.info(logging_str)
if log_to_rollbar:
Expand All @@ -308,8 +283,6 @@ async def run_checkpoint_bot(
)

if check_checkpoint:
# TODO: Add crash report
assert receipt["status"] == 1, "Checkpoint failed."
latest_block = chain.block_data()
timestamp = latest_block.get("timestamp", None)
if timestamp is None:
Expand Down Expand Up @@ -414,12 +387,12 @@ async def main(argv: Sequence[str] | None = None) -> None:
while True:
logging.info("Checking for new pools...")
# Reset hyperdrive objs
deployed_pools = Hyperdrive.get_hyperdrive_addresses_from_registry(chain, registry_address)
deployed_pools = Hyperdrive.get_hyperdrive_pools_from_registry(chain, registry_address)

# pylint: disable=protected-access
checkpoint_bot_eth_balance = FixedPoint(scaled_value=get_account_balance(chain._web3, sender.address))
log_message = (
f"{chain.name}: Running checkpoint bots for pools {list(deployed_pools.keys())}. "
f"{chain.name}: Running checkpoint bots for pools {[p.name for p in deployed_pools]}. "
f"{checkpoint_bot_eth_balance=}"
)
logging.info(log_message)
Expand All @@ -433,15 +406,14 @@ async def main(argv: Sequence[str] | None = None) -> None:
*[
run_checkpoint_bot(
chain=chain,
pool_address=pool_addr,
pool=pool,
sender=sender,
pool_name=pool_name,
block_time=block_time,
block_timestamp_interval=block_timestamp_interval,
block_to_exit=block_to_exit,
log_to_rollbar=log_to_rollbar,
)
for pool_name, pool_addr in deployed_pools.items()
for pool in deployed_pools
],
return_exceptions=False,
)
Expand Down
38 changes: 38 additions & 0 deletions src/agent0/ethpy/hyperdrive/interface/_contract_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,40 @@ def _get_gov_fees_accrued(
return FixedPoint(scaled_value=gov_fees_accrued)


def _get_long_total_supply(
hyperdrive_contract: IHyperdriveContract,
maturity_time: int,
block_identifier: BlockIdentifier | None,
) -> FixedPoint:
"""See API for documentation."""
if block_identifier is None:
block_identifier = "latest"
asset_id = encode_asset_id(AssetIdPrefix.LONG, maturity_time)
total_supply = hyperdrive_contract.functions.totalSupply(asset_id).call(block_identifier=block_identifier)
return FixedPoint(scaled_value=total_supply)


def _get_short_total_supply(
hyperdrive_contract: IHyperdriveContract,
maturity_time: int,
block_identifier: BlockIdentifier | None,
) -> FixedPoint:
"""See API for documentation."""
if block_identifier is None:
block_identifier = "latest"
asset_id = encode_asset_id(AssetIdPrefix.SHORT, maturity_time)
total_supply = hyperdrive_contract.functions.totalSupply(asset_id).call(block_identifier=block_identifier)
return FixedPoint(scaled_value=total_supply)


def _create_checkpoint(
interface: HyperdriveReadWriteInterface,
sender: LocalAccount,
checkpoint_time: int | None = None,
preview: bool = False,
gas_limit: int | None = None,
write_retry_count: int | None = None,
nonce_func: Callable[[], Nonce] | None = None,
) -> CreateCheckpoint:
"""See API for documentation."""

Expand All @@ -173,6 +201,15 @@ def _create_checkpoint(
# 0 is the max iterations for distribute excess idle, where it will default to
# the default max iterations
fn_args = (checkpoint_time, 0)

if preview:
_ = smart_contract_preview_transaction(
interface.hyperdrive_contract,
sender.address,
"checkpoint",
*fn_args,
)

tx_receipt = smart_contract_transact(
interface.web3,
interface.hyperdrive_contract,
Expand All @@ -183,6 +220,7 @@ def _create_checkpoint(
write_retry_count=write_retry_count,
timeout=interface.txn_receipt_timeout,
txn_options_gas=gas_limit,
nonce_func=nonce_func,
)
trade_result = parse_logs_to_event(tx_receipt, interface, "createCheckpoint")
return trade_result
Expand Down
38 changes: 38 additions & 0 deletions src/agent0/ethpy/hyperdrive/interface/read_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
_get_gov_fees_accrued,
_get_hyperdrive_base_balance,
_get_hyperdrive_eth_balance,
_get_long_total_supply,
_get_short_total_supply,
_get_total_supply_withdrawal_shares,
_get_variable_rate,
_get_vault_shares,
Expand Down Expand Up @@ -650,6 +652,42 @@ def get_gov_fees_accrued(self, block_identifier: BlockIdentifier | None = None)
"""
return _get_gov_fees_accrued(self.hyperdrive_contract, block_identifier)

def get_long_total_supply(self, maturity_time: int, block_identifier: BlockIdentifier | None = None) -> FixedPoint:
"""Get the total supply of long tokens with the given maturity time.
Arguments
---------
maturity_time: int
The maturity time in seconds.
block_identifier: BlockIdentifier, optional
The identifier for a block.
Defaults to the current block number.
Returns
-------
FixedPoint
The result of the total supply of long tokens with the given maturity time.
"""
return _get_long_total_supply(self.hyperdrive_contract, maturity_time, block_identifier)

def get_short_total_supply(self, maturity_time: int, block_identifier: BlockIdentifier | None = None) -> FixedPoint:
"""Get the total supply of short tokens with the given maturity time.
Arguments
---------
maturity_time: int
The maturity time in seconds.
block_identifier: BlockIdentifier, optional
The identifier for a block.
Defaults to the current block number.
Returns
-------
FixedPoint
The result of the total supply of short tokens with the given maturity time.
"""
return _get_short_total_supply(self.hyperdrive_contract, maturity_time, block_identifier)

def get_pause_events(
self,
from_block: BlockIdentifier | None = None,
Expand Down
39 changes: 39 additions & 0 deletions src/agent0/ethpy/hyperdrive/interface/read_interface_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from hyperdrivetypes.fixedpoint_types import FeesFP
from web3.constants import ADDRESS_ZERO

from agent0 import LocalChain, LocalHyperdrive
from agent0.utils.conversions import pool_config_to_fixedpoint, pool_info_to_fixedpoint

if TYPE_CHECKING:
Expand Down Expand Up @@ -304,3 +305,41 @@ def test_deployed_values(self, hyperdrive_read_interface_fixture: HyperdriveRead
# TODO there are rounding errors between api spot price and fixed rates
assert abs(api_spot_price - expected_spot_price) <= FixedPoint(1e-16)
assert abs(api_fixed_rate - expected_fixed_rate) <= FixedPoint(1e-16)

def test_long_short_total_supply(self, fast_chain_fixture: LocalChain):
position_duration = 3600
initial_pool_config = LocalHyperdrive.Config(
checkpoint_duration=position_duration, # 1 hour
)
pool = LocalHyperdrive(fast_chain_fixture, initial_pool_config)
agent_0 = fast_chain_fixture.init_agent(base=FixedPoint(100_000), eth=FixedPoint(100), pool=pool)
agent_1 = fast_chain_fixture.init_agent(base=FixedPoint(100_000), eth=FixedPoint(100), pool=pool)

# Advance time beyond the checkpoint
fast_chain_fixture.advance_time(position_duration, create_checkpoints=False)

# Open longs and shorts
agent_0.open_long(base=FixedPoint(111))
agent_1.open_long(base=FixedPoint(222))
agent_0.open_short(bonds=FixedPoint(111))
agent_1.open_short(bonds=FixedPoint(222))

# Advance time beyond the checkpoint
fast_chain_fixture.advance_time(position_duration, create_checkpoints=False)

# Check total supply
checkpoint_id = pool.interface.calc_checkpoint_id()
# Using checkpoint_id as an non-existant token
assert pool.interface.get_long_total_supply(checkpoint_id) == FixedPoint(0)
assert pool.interface.get_short_total_supply(checkpoint_id) == FixedPoint(0)

maturity_time = agent_0.get_longs()[0].maturity_time
assert maturity_time == agent_1.get_longs()[0].maturity_time
assert maturity_time == agent_0.get_shorts()[0].maturity_time
assert maturity_time == agent_1.get_shorts()[0].maturity_time

expected_long_amount = agent_0.get_longs()[0].balance + agent_1.get_longs()[0].balance
expected_short_amount = agent_0.get_shorts()[0].balance + agent_1.get_shorts()[0].balance

assert pool.interface.get_long_total_supply(maturity_time) == expected_long_amount
assert pool.interface.get_short_total_supply(maturity_time) == expected_short_amount
Loading

0 comments on commit bdd1cad

Please sign in to comment.