Skip to content

Commit

Permalink
Remote analysis (#1512)
Browse files Browse the repository at this point in the history
- Allows get remote trade events across an entire pool.
- `get_trade_events` now accepts a list of pools to filter on.
- Adding names to `get_trade_event` calls.
- Fixing bug with type checking.
  • Loading branch information
slundqui authored Jun 4, 2024
1 parent 197462b commit aa93345
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 30 deletions.
4 changes: 2 additions & 2 deletions examples/interactive_local_hyperdrive_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@
# Get the raw trade events across all pools.
agent_trade_events = agent0.get_trade_events()
# Alternatively, get events from a single pool.
agent_trade_events = agent0.get_trade_events(pool=hyperdrive0)
agent_trade_events = agent0.get_trade_events(pool=hyperdrive1)
agent_trade_events = agent0.get_trade_events(pool_filter=hyperdrive0)
agent_trade_events = agent0.get_trade_events(pool_filter=hyperdrive1)

# Gets all open positions and their corresponding PNL for an agent across all pools.
agent_positions = agent0.get_positions()
Expand Down
2 changes: 1 addition & 1 deletion examples/interactive_remote_hyperdrive_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@

# Get the raw trade events for the pool.
# Note the pool argument must be provided in remote settings.
agent_trade_events = agent0.get_trade_events(pool=hyperdrive_pool)
agent_trade_events = agent0.get_trade_events(pool_filter=hyperdrive_pool)
# Gets all open positions and their corresponding PNL for an agent for the pool.
agent_positions = agent0.get_positions(pool_filter=hyperdrive_pool)
# Gets all open and closed positions and their corresponding PNL for an agent for the pool.
Expand Down
1 change: 1 addition & 0 deletions src/agent0/chainsync/db/hyperdrive/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
get_latest_block_number_from_pool_info_table,
get_latest_block_number_from_positions_snapshot_table,
get_latest_block_number_from_table,
get_latest_block_number_from_trade_event,
get_pool_config,
get_pool_info,
get_position_snapshot,
Expand Down
15 changes: 9 additions & 6 deletions src/agent0/chainsync/db/hyperdrive/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def get_latest_block_number_from_positions_snapshot_table(
def get_trade_events(
session: Session,
wallet_address: str | list[str] | None = None,
hyperdrive_address: str | None = None,
hyperdrive_address: str | list[str] | None = None,
all_token_deltas: bool = True,
sort_ascending: bool = True,
query_limit: int | None = None,
Expand All @@ -199,8 +199,8 @@ def get_trade_events(
The initialized db session object.
wallet_address: str | list[str] | None, optional
The wallet address(es) to filter the results on. Return all if None.
hyperdrive_address: str | None, optional
The hyperdrive address to filter the results on. Returns all if None.
hyperdrive_address: str | list[str] | None, optional
The hyperdrive address(es) to filter the results on. Returns all if None.
all_token_deltas: bool, optional
When removing liquidity that results in withdrawal shares, the events table returns
two entries for this transaction to keep track of token deltas (one for lp tokens and
Expand Down Expand Up @@ -230,8 +230,11 @@ def get_trade_events(
elif wallet_address is not None:
query = query.filter(TradeEvent.wallet_address == wallet_address)

if hyperdrive_address is not None:
if isinstance(hyperdrive_address, list):
query = query.filter(TradeEvent.hyperdrive_address.in_(hyperdrive_address))
elif hyperdrive_address is not None:
query = query.filter(TradeEvent.hyperdrive_address == hyperdrive_address)

if not all_token_deltas:
# Drop the duplicate events
query = query.filter(
Expand Down Expand Up @@ -595,7 +598,7 @@ def get_all_traders(session: Session, hyperdrive_address: str | None = None) ->
# pylint: disable=too-many-arguments
def get_position_snapshot(
session: Session,
hyperdrive_address: list[str] | str | None = None,
hyperdrive_address: str | list[str] | None = None,
start_block: int | None = None,
end_block: int | None = None,
wallet_address: list[str] | str | None = None,
Expand All @@ -607,7 +610,7 @@ def get_position_snapshot(
---------
session: Session
The initialized session object.
hyperdrive_address: list[str] | str | None, optional
hyperdrive_address: str | list[str] | None, optional
The hyperdrive pool address(es) to filter the query on. Defaults to returning all position snapshots.
start_block: int | None, optional
The starting block to filter the query on. start_block integers
Expand Down
42 changes: 40 additions & 2 deletions src/agent0/core/hyperdrive/interactive/hyperdrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
import pandas as pd
from eth_typing import ChecksumAddress

from agent0.chainsync.db.hyperdrive import add_hyperdrive_addr_to_name
from agent0.chainsync.db.hyperdrive import (
add_hyperdrive_addr_to_name,
checkpoint_events_to_db,
get_latest_block_number_from_trade_event,
get_trade_events,
trade_events_to_db,
)
from agent0.ethpy.hyperdrive import (
HyperdriveReadWriteInterface,
generate_name_for_hyperdrive,
Expand Down Expand Up @@ -191,7 +197,33 @@ def get_trade_events(self, all_token_deltas: bool = False, coerce_float: bool =
pd.Dataframe
A dataframe of trade events.
"""
raise NotImplementedError
# pylint: disable=protected-access

# There's a case where a user calls `agent.get_trade_events()` followed by
# `pool.get_trade_events()`. This puts duplicate entries into the same underlying
# table.
# We prevent this by not allowing this call if the underlying table isn't empty
# TODO we can relax this by either dropping any entries from this pool, or by making
# a db update on a unique constraint.

if (
get_latest_block_number_from_trade_event(
self.chain.db_session, hyperdrive_address=self.hyperdrive_address, wallet_address=None
)
!= 0
):
raise NotImplementedError("Can't call `hyperdrive.get_trade_events` after `agent.get_trade_events()`.")

self._sync_events()
out = get_trade_events(
self.chain.db_session,
hyperdrive_address=self.interface.hyperdrive_address,
all_token_deltas=all_token_deltas,
coerce_float=coerce_float,
).drop("id", axis=1)
out = self.chain._add_username_to_dataframe(out, "wallet_address")
out = self.chain._add_hyperdrive_name_to_dataframe(out, "hyperdrive_address")
return out

def get_historical_positions(self, coerce_float: bool = False) -> pd.DataFrame:
"""Gets the history of all positions over time and their corresponding pnl
Expand Down Expand Up @@ -243,3 +275,9 @@ def hyperdrive_address(self) -> ChecksumAddress:
"""
# pylint: disable=protected-access
return self.interface.hyperdrive_address

def _sync_events(self) -> None:
# Remote hyperdrive stack syncs only the agent's wallet
trade_events_to_db([self.interface], wallet_addr=None, db_session=self.chain.db_session)
# We sync checkpoint events as well
checkpoint_events_to_db([self.interface], db_session=self.chain.db_session)
32 changes: 22 additions & 10 deletions src/agent0/core/hyperdrive/interactive/hyperdrive_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,14 +1011,17 @@ def _get_positions(
return position_snapshot

def get_trade_events(
self, pool: Hyperdrive | None = None, all_token_deltas: bool = False, coerce_float: bool = False
self,
pool_filter: Hyperdrive | list[Hyperdrive] | None = None,
all_token_deltas: bool = False,
coerce_float: bool = False,
) -> pd.DataFrame:
"""Returns the agent's current wallet.
Arguments
---------
pool : Hyperdrive | None, optional
The hyperdrive pool to get trade events from.
pool_filter : Hyperdrive | list[Hyperdrive] | None, optional
The hyperdrive pool(s) to get trade events from.
all_token_deltas: bool, optional
When removing liquidity that results in withdrawal shares, the events table returns
two entries for this transaction to keep track of token deltas (one for lp tokens and
Expand All @@ -1033,35 +1036,44 @@ def get_trade_events(
HyperdriveWallet
The agent's current wallet.
"""
if pool is None:
if pool_filter is None:
# TODO get positions on remote chains must pass in pool for now
# Eventually we get the list of pools from registry and track all pools in registry
raise NotImplementedError("Pool must be specified to get trade events.")
self._sync_events(pool)
return self._get_trade_events(all_token_deltas=all_token_deltas, pool=pool, coerce_float=coerce_float)
self._sync_events(pool_filter)
return self._get_trade_events(
all_token_deltas=all_token_deltas, pool_filter=pool_filter, coerce_float=coerce_float
)

def _get_trade_events(
self,
pool: Hyperdrive | None,
pool_filter: Hyperdrive | list[Hyperdrive] | None,
all_token_deltas: bool,
coerce_float: bool,
) -> pd.DataFrame:
"""We call this function in both remote and local agents, as the remote call needs to
do argument checking."""
# If pool is None, we don't filter on hyperdrive address
if pool is None:
if pool_filter is None:
hyperdrive_address = None
elif isinstance(pool_filter, list):
hyperdrive_address = [str(pool.hyperdrive_address) for pool in pool_filter]
else:
hyperdrive_address = pool.interface.hyperdrive_address
hyperdrive_address = pool_filter.interface.hyperdrive_address

return get_trade_events(
trade_events = get_trade_events(
self.chain.db_session,
hyperdrive_address=hyperdrive_address,
wallet_address=self.address,
all_token_deltas=all_token_deltas,
coerce_float=coerce_float,
).drop("id", axis=1)

# Add usernames
trade_events = self.chain._add_username_to_dataframe(trade_events, "wallet_address")
trade_events = self.chain._add_hyperdrive_name_to_dataframe(trade_events, "hyperdrive_address")
return trade_events

# Helper functions for analysis

def _sync_events(self, pool: Hyperdrive | list[Hyperdrive]) -> None:
Expand Down
4 changes: 4 additions & 0 deletions src/agent0/core/hyperdrive/interactive/local_hyperdrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,3 +556,7 @@ def _reinit_state_after_load_snapshot(self) -> None:
"""
# Set internal state block number to 0 to enusre it updates
self.interface.last_state_block_number = BlockNumber(0)

def _sync_events(self) -> None:
# Making sure this function isn't called in local_hyperdrive
raise NotImplementedError
34 changes: 25 additions & 9 deletions src/agent0/core/hyperdrive/interactive/local_hyperdrive_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,7 @@ def get_positions(
Arguments
---------
pool_filter: LocalHyperdrive | list[Hyperdrive], optional
pool_filter: LocalHyperdrive | list[LocalHyperdrive], optional
The hyperdrive pool(s) to query. Defaults to None, which will query all pools.
show_closed_positions: bool, optional
Whether to show positions closed positions (i.e., positions with zero balance). Defaults to False.
Expand All @@ -545,21 +545,29 @@ def get_positions(
if registry_address is not None:
raise ValueError("registry_address not used with local agents")
# Explicit type checking
if pool_filter is not None and not isinstance(pool_filter, LocalHyperdrive):
raise TypeError("Pool must be an instance of LocalHyperdrive for a LocalHyperdriveAgent")
if pool_filter is not None:
if isinstance(pool_filter, list):
for pool in pool_filter:
if not isinstance(pool, LocalHyperdrive):
raise TypeError("Pool must be an instance of LocalHyperdrive for a LocalHyperdriveAgent")
elif not isinstance(pool_filter, LocalHyperdrive):
raise TypeError("Pool must be an instance of LocalHyperdrive for a LocalHyperdriveAgent")
return self._get_positions(
pool_filter=pool_filter, show_closed_positions=show_closed_positions, coerce_float=coerce_float
)

def get_trade_events(
self, pool: Hyperdrive | None = None, all_token_deltas: bool = False, coerce_float: bool = False
self,
pool_filter: Hyperdrive | list[Hyperdrive] | None = None,
all_token_deltas: bool = False,
coerce_float: bool = False,
) -> pd.DataFrame:
"""Returns the agent's current wallet.
Arguments
---------
pool : LocalHyperdrive | None, optional
The hyperdrive pool to get trade events from. If None, will retrieve all events from
pool_filter : LocalHyperdrive | list[LocalHyperdrive] | None, optional
The hyperdrive pool(s) to get trade events from. If None, will retrieve all events from
all pools.
all_token_deltas: bool, optional
When removing liquidity that results in withdrawal shares, the events table returns
Expand All @@ -576,9 +584,17 @@ def get_trade_events(
The agent's current wallet.
"""
# Explicit type checking
if pool is not None and not isinstance(pool, LocalHyperdrive):
raise TypeError("Pool must be an instance of LocalHyperdrive for a LocalHyperdriveAgent")
return self._get_trade_events(pool=pool, all_token_deltas=all_token_deltas, coerce_float=coerce_float)
# Explicit type checking
if pool_filter is not None:
if isinstance(pool_filter, list):
for pool in pool_filter:
if not isinstance(pool, LocalHyperdrive):
raise TypeError("Pool must be an instance of LocalHyperdrive for a LocalHyperdriveAgent")
elif not isinstance(pool_filter, LocalHyperdrive):
raise TypeError("Pool must be an instance of LocalHyperdrive for a LocalHyperdriveAgent")
return self._get_trade_events(
pool_filter=pool_filter, all_token_deltas=all_token_deltas, coerce_float=coerce_float
)

def _sync_events(self, pool: Hyperdrive | list[Hyperdrive]) -> None:
# No need to sync in local hyperdrive, we sync when we run the data pipeline
Expand Down

0 comments on commit aa93345

Please sign in to comment.