From 5ef0d8bd2a4cb26a36432cf550e972d6f419f3dd Mon Sep 17 00:00:00 2001 From: Sheng Lundquist Date: Fri, 10 May 2024 16:18:42 -0500 Subject: [PATCH] Adding event queries to db (#1464) This PR is the first of a series of PRs to support multi-pool trading in agent0. In this PR, we move away from exposing the `agent.wallet` object (as we're deprecating doing bookkeeping on the wallet itself in python) in favor of using the `agent.get_positions()` function. This function (1) does a query of the chain to gather events and adds them to a `TradeEvent` db table, and (2) queries from the `TradeEvent` table to get the current positions a wallet has. The `TradeEvent` table handles all events on any hyperdrive tokens (i.e., long/short/lp). There's a bit of overlap with the `WalletDelta` table, with the main exception that the `TradeEvent` table is lazy - the table only gets updated when `agent.get_positions()` gets called, and only with the events from `agent`. In addition, the table handles both trade events (e.g., `OpenLong`) and single transfer trades (e.g., wallet to wallet transfers of tokens). We likely can deprecate the `WalletDelta` table with a special call to gather all trade events from a Hyperdrive pool, which fills the `TradeEvent` table with every wallet that has made a trade on the pool. There are a couple of places that can be optimized. Currently, we query the chain for events for every `get_positions` call (from the latest entry in the db to latest block). Some bookkeeping is needed to e.g., don't get events from the logs if a user calls `get_positions` on the same block. As a temporary fix, we also remove `agent.wallet` from remote chains, and `get_positions` gathers all events from the remote chain each time it's called. This will get fixed once the database is exposed in the underlying chain object, with the remote wallet also using the `TradeEvent` table to gather wallet positions. Final note: the failing test here is fixed in https://github.com/delvtech/agent0/pull/1462. ## Changes - Removing interactive wallet in favor of using `get_positions()`. - Added `TradeEvent` table to database, with supporting ingestion (`trade_events_to_db`) and query (`get_positions_from_db` and `get_trade_events`) interface functions. - Changed all interactive `agent.wallet` calls to `agent.get_positions()` - Renamed `contract_address` to `hyperdrive_address` in `PoolConfig` db table. --- .../interactive_local_hyperdrive_example.py | 6 +- pyproject.toml | 1 + .../chainsync/analysis/data_to_analysis.py | 40 +- .../chainsync/db/hyperdrive/__init__.py | 5 +- .../chainsync/db/hyperdrive/chain_to_db.py | 344 +++++++++++++++++- .../db/hyperdrive/import_export_data.py | 61 +--- .../db/hyperdrive/import_export_data_test.py | 2 +- .../chainsync/db/hyperdrive/interface.py | 286 ++++++++++----- .../chainsync/db/hyperdrive/interface_test.py | 37 +- src/agent0/chainsync/db/hyperdrive/schema.py | 63 +++- .../chainsync/db/hyperdrive/schema_test.py | 8 +- src/agent0/chainsync/df_to_db.py | 47 +++ .../core/hyperdrive/interactive/hyperdrive.py | 21 +- .../interactive/hyperdrive_agent.py | 37 +- .../hyperdrive/interactive/hyperdrive_test.py | 96 +---- .../interactive/local_hyperdrive.py | 66 +++- .../interactive/local_hyperdrive_test.py | 62 ++-- .../hyperfuzz/system_fuzz/run_fuzz_bots.py | 2 +- .../hyperfuzz/unit_fuzz/fuzz_present_value.py | 20 +- .../hyperfuzz/unit_fuzz/fuzz_profit_check.py | 8 +- .../gym_environments/full_hyperdrive_env.py | 8 +- .../gym_environments/simple_hyperdrive_env.py | 2 +- tests/bot_to_db_test.py | 2 +- 23 files changed, 880 insertions(+), 344 deletions(-) create mode 100644 src/agent0/chainsync/df_to_db.py diff --git a/examples/interactive_local_hyperdrive_example.py b/examples/interactive_local_hyperdrive_example.py index 9c96e55b9e..cc63584d9b 100644 --- a/examples/interactive_local_hyperdrive_example.py +++ b/examples/interactive_local_hyperdrive_example.py @@ -58,7 +58,7 @@ open_long_event_2 = hyperdrive_agent0.open_long(FixedPoint(22222)) # View current wallet -print(hyperdrive_agent0.wallet) +print(hyperdrive_agent0.get_positions()) # NOTE these calls are chainwide calls, so all pools connected to this chain gets affected. # Advance time, accepts timedelta or seconds @@ -72,7 +72,7 @@ maturity_time=open_long_event_1.maturity_time, bonds=open_long_event_1.bond_amount ) -agent0_longs = list(hyperdrive_agent0.wallet.longs.values()) +agent0_longs = list(hyperdrive_agent0.get_positions().longs.values()) close_long_event_2 = hyperdrive_agent0.close_long( maturity_time=agent0_longs[0].maturity_time, bonds=agent0_longs[0].balance ) @@ -85,7 +85,7 @@ # LP add_lp_event = hyperdrive_agent2.add_liquidity(base=FixedPoint(44444)) -remove_lp_event = hyperdrive_agent2.remove_liquidity(shares=hyperdrive_agent2.wallet.lp_tokens) +remove_lp_event = hyperdrive_agent2.remove_liquidity(shares=hyperdrive_agent2.get_positions().lp_tokens) # The above trades doesn't result in withdraw shares, but the function below allows you # to withdrawal shares from the pool. diff --git a/pyproject.toml b/pyproject.toml index 524a16fe4b..1a29a5720c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,7 @@ exclude = [".venv", ".vscode", "docs"] [tool.isort] line_length = 120 +multi_line_output=3 [tool.ruff] line-length = 120 diff --git a/src/agent0/chainsync/analysis/data_to_analysis.py b/src/agent0/chainsync/analysis/data_to_analysis.py index f4794adb22..104246a609 100644 --- a/src/agent0/chainsync/analysis/data_to_analysis.py +++ b/src/agent0/chainsync/analysis/data_to_analysis.py @@ -1,15 +1,11 @@ """Functions to gather data from postgres, do analysis, and add back into postgres""" -import logging from decimal import Decimal -from typing import Type import numpy as np import pandas as pd -from sqlalchemy import exc from sqlalchemy.orm import Session -from agent0.chainsync.db.base import Base from agent0.chainsync.db.hyperdrive import ( CurrentWallet, PoolAnalysis, @@ -21,6 +17,7 @@ get_transactions, get_wallet_deltas, ) +from agent0.chainsync.df_to_db import df_to_db from agent0.ethpy.hyperdrive import HyperdriveReadInterface from .calc_base_buffer import calc_base_buffer @@ -31,33 +28,6 @@ pd.set_option("display.max_columns", None) -MAX_BATCH_SIZE = 10000 - - -def _df_to_db(insert_df: pd.DataFrame, schema_obj: Type[Base], session: Session): - """Helper function to add a dataframe to a database""" - table_name = schema_obj.__tablename__ - - # dataframe to_sql needs data types from the schema object - dtype = {c.name: c.type for c in schema_obj.__table__.columns} - # Pandas doesn't play nice with types - insert_df.to_sql( - table_name, - con=session.connection(), - if_exists="append", - method="multi", - index=False, - dtype=dtype, # type: ignore - chunksize=MAX_BATCH_SIZE, - ) - # commit the transaction - try: - session.commit() - except exc.DataError as err: - session.rollback() - logging.error("Error on adding %s: %s", table_name, err) - raise err - def calc_total_wallet_delta(wallet_deltas: pd.DataFrame) -> pd.DataFrame: """Calculates total wallet deltas from wallet_delta for every wallet type and position. @@ -188,7 +158,7 @@ def data_to_analysis( # If it doesn't exist, should be an empty dataframe latest_wallet = get_current_wallet(db_session, end_block=start_block, coerce_float=False) current_wallet_df = calc_current_wallet(wallet_deltas_df, latest_wallet) - _df_to_db(current_wallet_df, CurrentWallet, db_session) + df_to_db(current_wallet_df, CurrentWallet, db_session) # calculate pnl through closeout pnl # TODO this function might be slow due to contract call on chain @@ -217,13 +187,13 @@ def data_to_analysis( # TODO do scaling tests to see the limit of this wallet_pnl["pnl"] = pnl_df # Add wallet_pnl to the database - _df_to_db(wallet_pnl, WalletPNL, db_session) + df_to_db(wallet_pnl, WalletPNL, db_session) # Build ticker from wallet delta transactions = get_transactions(db_session, start_block, end_block, coerce_float=False) ticker_df = calc_ticker(wallet_deltas_df, transactions, pool_info) # TODO add ticker to database - _df_to_db(ticker_df, Ticker, db_session) + df_to_db(ticker_df, Ticker, db_session) # We add pool analysis last since this table is what's being used to determine how far the data pipeline is. # Calculate spot price @@ -248,4 +218,4 @@ def data_to_analysis( pool_analysis_df = pd.concat([pool_info["block_number"], spot_price, fixed_rate, base_buffer], axis=1) pool_analysis_df.columns = ["block_number", "spot_price", "fixed_rate", "base_buffer"] - _df_to_db(pool_analysis_df, PoolAnalysis, db_session) + df_to_db(pool_analysis_df, PoolAnalysis, db_session) diff --git a/src/agent0/chainsync/db/hyperdrive/__init__.py b/src/agent0/chainsync/db/hyperdrive/__init__.py index d09561e464..21daced433 100644 --- a/src/agent0/chainsync/db/hyperdrive/__init__.py +++ b/src/agent0/chainsync/db/hyperdrive/__init__.py @@ -1,6 +1,6 @@ """Hyperdrive database utilities.""" -from .chain_to_db import data_chain_to_db, init_data_chain_to_db +from .chain_to_db import data_chain_to_db, init_data_chain_to_db, trade_events_to_db from .convert_data import ( convert_checkpoint_info, convert_hyperdrive_transactions_for_block, @@ -13,6 +13,7 @@ add_pool_config, add_pool_infos, add_transactions, + add_transfer_events, add_wallet_deltas, get_all_traders, get_checkpoint_info, @@ -23,8 +24,10 @@ get_pool_analysis, get_pool_config, get_pool_info, + get_positions_from_db, get_ticker, get_total_wallet_pnl_over_time, + get_trade_events, get_transactions, get_wallet_deltas, get_wallet_pnl, diff --git a/src/agent0/chainsync/db/hyperdrive/chain_to_db.py b/src/agent0/chainsync/db/hyperdrive/chain_to_db.py index b24c00a2b5..fb836dc62c 100644 --- a/src/agent0/chainsync/db/hyperdrive/chain_to_db.py +++ b/src/agent0/chainsync/db/hyperdrive/chain_to_db.py @@ -1,14 +1,19 @@ """Functions for gathering data from the chain and adding it to the db""" from dataclasses import asdict -from datetime import datetime +from datetime import datetime, timezone +from decimal import Decimal +from typing import Any +import numpy as np +import pandas as pd from fixedpointmath import FixedPoint from sqlalchemy.orm import Session -from web3.types import BlockData +from web3.types import BlockData, EventData +from agent0.chainsync.df_to_db import df_to_db from agent0.ethpy.base import fetch_contract_transactions_for_block -from agent0.ethpy.hyperdrive import HyperdriveReadInterface +from agent0.ethpy.hyperdrive import AssetIdPrefix, HyperdriveReadInterface, decode_asset_id from .convert_data import ( convert_checkpoint_info, @@ -16,7 +21,15 @@ convert_pool_config, convert_pool_info, ) -from .interface import add_checkpoint_info, add_pool_config, add_pool_infos, add_transactions, add_wallet_deltas +from .interface import ( + add_checkpoint_info, + add_pool_config, + add_pool_infos, + add_transactions, + add_wallet_deltas, + get_latest_block_number_from_trade_event, +) +from .schema import TradeEvent def init_data_chain_to_db( @@ -33,7 +46,7 @@ def init_data_chain_to_db( The database session """ pool_config_dict = asdict(interface.current_pool_state.pool_config) - pool_config_dict["contract_address"] = interface.hyperdrive_address + pool_config_dict["hyperdrive_address"] = interface.hyperdrive_address fees = pool_config_dict["fees"] pool_config_dict["curve_fee"] = fees["curve"] pool_config_dict["flat_fee"] = fees["flat"] @@ -87,7 +100,7 @@ def data_chain_to_db(interface: HyperdriveReadInterface, block: BlockData, sessi # Adding this last as pool info is what we use to determine if this block is in the db for analysis pool_info_dict = asdict(pool_state.pool_info) pool_info_dict["block_number"] = int(pool_state.block_number) - pool_info_dict["timestamp"] = datetime.utcfromtimestamp(pool_state.block_time) + pool_info_dict["timestamp"] = datetime.fromtimestamp(pool_state.block_time, timezone.utc) # Adding additional fields pool_info_dict["epoch_timestamp"] = pool_state.block_time @@ -100,3 +113,322 @@ def data_chain_to_db(interface: HyperdriveReadInterface, block: BlockData, sessi block_pool_info = convert_pool_info(pool_info_dict) add_pool_infos([block_pool_info], session) + + +def _event_data_to_dict(in_val: EventData) -> dict[str, Any]: + out = dict(in_val) + # The args field is also an attribute dict, change to dict + out["args"] = dict(in_val["args"]) + + # Convert transaction hash to string + out["transactionHash"] = in_val["transactionHash"].to_0x_hex() + # Get token id field from args. + # This field is `assetId` for open/close long/short + return out + + +# TODO cleanup +# pylint: disable=too-many-branches +# pylint: disable=too-many-statements +# pylint: disable=too-many-locals +def trade_events_to_db( + interfaces: list[HyperdriveReadInterface], + wallet_addr: str, + db_session: Session, +) -> None: + """Function to query trade events from all pools and add them to the db. + + Arguments + --------- + interfaces: list[HyperdriveReadInterface] + A collection of Hyperdrive interface objects, each connected to a pool. + wallet_addr: str + The wallet address to query. + db_session: Session + The database session. + """ + assert len(interfaces) > 0 + + # Get the earliest block to get events from + # TODO can narrow this down to the last block we checked + # For now, keep this as the latest entry of this wallet. + # + 1 since the queries are inclusive + from_block = get_latest_block_number_from_trade_event(db_session, wallet_addr) + 1 + + # Gather all events we care about here + all_events = [] + + for interface in interfaces: + events = interface.hyperdrive_contract.events.TransferSingle.get_logs( + fromBlock=from_block, + argument_filters={"to": wallet_addr}, + ) + # Change events from attribute dict to dictionary + all_events.extend([_event_data_to_dict(event) for event in events]) + + events = interface.hyperdrive_contract.events.TransferSingle.get_logs( + fromBlock=from_block, + argument_filters={"from": wallet_addr}, + ) + all_events.extend([_event_data_to_dict(event) for event in events]) + + # Hyperdrive events + events = interface.hyperdrive_contract.events.OpenLong.get_logs( + fromBlock=from_block, + argument_filters={"trader": wallet_addr}, + ) + all_events.extend([_event_data_to_dict(event) for event in events]) + + events = interface.hyperdrive_contract.events.CloseLong.get_logs( + fromBlock=from_block, + argument_filters={"trader": wallet_addr}, + ) + all_events.extend([_event_data_to_dict(event) for event in events]) + + events = interface.hyperdrive_contract.events.OpenShort.get_logs( + fromBlock=from_block, + argument_filters={"trader": wallet_addr}, + ) + all_events.extend([_event_data_to_dict(event) for event in events]) + + events = interface.hyperdrive_contract.events.CloseShort.get_logs( + fromBlock=from_block, + argument_filters={"trader": wallet_addr}, + ) + all_events.extend([_event_data_to_dict(event) for event in events]) + + events = interface.hyperdrive_contract.events.AddLiquidity.get_logs( + fromBlock=from_block, + argument_filters={"provider": wallet_addr}, + ) + all_events.extend([_event_data_to_dict(event) for event in events]) + + events = interface.hyperdrive_contract.events.RemoveLiquidity.get_logs( + fromBlock=from_block, + argument_filters={"provider": wallet_addr}, + ) + all_events.extend([_event_data_to_dict(event) for event in events]) + + events = interface.hyperdrive_contract.events.RedeemWithdrawalShares.get_logs( + fromBlock=from_block, + argument_filters={"provider": wallet_addr}, + ) + all_events.extend([_event_data_to_dict(event) for event in events]) + + # Convert to dataframe + events_df = pd.DataFrame(all_events) + # If no events, we just return + if len(events_df) == 0: + return + + # Each transaction made through hyperdrive has two rows, + # one TransferSingle and one for the trade. + # Any transactions without a corresponding trade is a wallet to wallet transfer. + + # Look for any transfer events not associated with a trade + unique_events_per_transaction = events_df.groupby("transactionHash")["event"].agg(["unique", "nunique"]) + # Sanity check + if (unique_events_per_transaction["nunique"] > 2).any(): + raise ValueError( + "Found more than 2 unique events for transaction." + f"{unique_events_per_transaction[unique_events_per_transaction['nunique'] > 2]['unique']}" + ) + + # Find any transfer events that are not associated with a trade. + # This happens when e.g., a wallet to wallet transfer happens, or + # if this wallet is the initializer of the pool. + # TODO we have a test for initializer of the pool, but we need to implement + # wallet to wallet transfers of tokens in the interactive interface for a full test + transfer_events_trx_hash = unique_events_per_transaction[ + unique_events_per_transaction["nunique"] < 2 + ].reset_index()["transactionHash"] + transfer_events_df = events_df[events_df["transactionHash"].isin(transfer_events_trx_hash)].copy() + if len(transfer_events_df) > 0: + # Expand the args dict without losing the args dict field + # json_normalize works on series, but typing doesn't support it. + args_columns = pd.json_normalize(transfer_events_df["args"]) # type: ignore + transfer_events_df = pd.concat([transfer_events_df, args_columns], axis=1) + # We apply the decode function to each element, then expand the resulting + # tuple to multiple columns + transfer_events_df["token_type"], transfer_events_df["maturityTime"] = zip( + *transfer_events_df["id"].astype(int).apply(decode_asset_id) + ) + # Convert token_type enum to name + transfer_events_df["token_type"] = transfer_events_df["token_type"].apply(lambda x: AssetIdPrefix(x).name) + # Convert maturity times of 0 to nan to match other events + transfer_events_df.loc[transfer_events_df["maturityTime"] == 0, "maturityTime"] = np.nan + # Set token id, default is to set it to the token type + transfer_events_df["token_id"] = transfer_events_df["token_type"] + # Append the maturity time for longs and shorts + long_or_short_idx = transfer_events_df["token_type"].isin(["LONG", "SHORT"]) + transfer_events_df.loc[long_or_short_idx, "token_id"] = ( + transfer_events_df.loc[long_or_short_idx, "token_type"] + + "-" + + transfer_events_df.loc[long_or_short_idx, "maturityTime"].astype(str) + ) + + # Set the trader of this transfer + transfer_events_df["trader"] = wallet_addr + + # See if it's a receive or send of tokens + send_idx = transfer_events_df["from"] == wallet_addr + receive_idx = transfer_events_df["to"] == wallet_addr + # Set the token delta based on send or receive + transfer_events_df.loc[send_idx, "token_delta"] = -transfer_events_df.loc[send_idx, "value"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + transfer_events_df.loc[receive_idx, "token_delta"] = transfer_events_df.loc[receive_idx, "value"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + # Base delta is always 0 + transfer_events_df["base_delta"] = Decimal(0) + + # Drop all transfer single events + events_df = events_df[events_df["event"] != "TransferSingle"].reset_index(drop=True) + + # Sanity check, one hyperdrive event per transaction hash + if events_df.groupby("transactionHash")["event"].nunique().all() != 1: + raise ValueError("Found more than one event per transaction hash.") + + # Expand the args dict without losing the args dict field + # json_normalize works on series, but typing doesn't support it. + args_columns = pd.json_normalize(events_df["args"]) # type: ignore + events_df = pd.concat([events_df, args_columns], axis=1) + + # Convert fields to db schema + # Longs + events_idx = events_df["event"].isin(["OpenLong", "CloseLong"]) + if events_idx.any(): + events_df.loc[events_idx, "token_type"] = "LONG" + events_df.loc[events_idx, "token_id"] = "LONG-" + events_df.loc[events_idx, "maturityTime"].astype(int).astype( + str + ) + + events_idx = events_df["event"] == "OpenLong" + if events_idx.any(): + # Pandas apply doesn't play nice with types + events_df.loc[events_idx, "token_delta"] = events_df.loc[events_idx, "bondAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + events_df.loc[events_idx, "base_delta"] = -events_df.loc[events_idx, "baseAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + + events_idx = events_df["event"] == "CloseLong" + if events_idx.any(): + # Pandas apply doesn't play nice with types + events_df.loc[events_idx, "token_delta"] = -events_df.loc[events_idx, "bondAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + events_df.loc[events_idx, "base_delta"] = events_df.loc[events_idx, "baseAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + + # Shorts + events_idx = events_df["event"].isin(["OpenShort", "CloseShort"]) + if events_idx.any(): + events_df.loc[events_idx, "token_type"] = "SHORT" + events_df.loc[events_idx, "token_id"] = "SHORT-" + events_df.loc[events_idx, "maturityTime"].astype(int).astype( + str + ) + + events_idx = events_df["event"] == "OpenShort" + if events_idx.any(): + # Pandas apply doesn't play nice with types + events_df.loc[events_idx, "token_delta"] = events_df.loc[events_idx, "bondAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + events_df.loc[events_idx, "base_delta"] = -events_df.loc[events_idx, "baseAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + + events_idx = events_df["event"] == "CloseShort" + if events_idx.any(): + # Pandas apply doesn't play nice with types + events_df.loc[events_idx, "token_delta"] = -events_df.loc[events_idx, "bondAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + events_df.loc[events_idx, "base_delta"] = events_df.loc[events_idx, "baseAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + + # LP + events_idx = events_df["event"].isin(["AddLiquidity", "RemoveLiquidity"]) + if events_idx.any(): + events_df.loc[events_idx, "token_type"] = "LP" + events_df.loc[events_idx, "token_id"] = "LP" + # The wallet here is the "provider" column, we remap it to "trader" + events_df.loc[events_idx, "trader"] = events_df.loc[events_idx, "provider"] + # We explicitly add a maturity time here to ensure this column exists + # if there were no longs in this event set. + events_df.loc[events_idx, "maturityTime"] = np.nan + + events_idx = events_df["event"] == "AddLiquidity" + if events_idx.any(): + # Pandas apply doesn't play nice with types + events_df.loc[events_idx, "token_delta"] = events_df.loc[events_idx, "lpAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + events_df.loc[events_idx, "base_delta"] = -events_df.loc[events_idx, "baseAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + + events_idx = events_df["event"] == "RemoveLiquidity" + if events_idx.any(): + # Pandas apply doesn't play nice with types + events_df.loc[events_idx, "token_delta"] = -events_df.loc[events_idx, "lpAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + events_df.loc[events_idx, "base_delta"] = events_df.loc[events_idx, "baseAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + # We need to also add any withdrawal shares as additional rows + withdrawal_shares_idx = events_idx & (events_df["withdrawalShareAmount"] > 0) + if withdrawal_shares_idx.any(): + withdrawal_rows = events_df[withdrawal_shares_idx].copy() + withdrawal_rows["token_type"] = "WITHDRAWAL_SHARE" + withdrawal_rows["token_id"] = "WITHDRAWAL_SHARE" + withdrawal_rows["token_delta"] = withdrawal_rows["withdrawalShareAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + withdrawal_rows["base_delta"] = Decimal(0) + events_df = pd.concat([events_df, withdrawal_rows], axis=0) + + events_idx = events_df["event"] == "RedeemWithdrawalShares" + if events_idx.any(): + events_df.loc[events_idx, "token_type"] = "WITHDRAWAL_SHARE" + events_df.loc[events_idx, "token_id"] = "WITHDRAWAL_SHARE" + # The wallet here is the "provider" column, we remap it to "trader" + events_df.loc[events_idx, "trader"] = events_df.loc[events_idx, "provider"] + # We explicitly add a maturity time here to ensure this column exists + # if there were no longs in this event set. + events_df.loc[events_idx, "maturityTime"] = np.nan + # Pandas apply doesn't play nice with types + events_df.loc[events_idx, "token_delta"] = -events_df.loc[events_idx, "withdrawalShareAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + events_df.loc[events_idx, "base_delta"] = events_df.loc[events_idx, "baseAmount"].apply( + lambda x: Decimal(x) / Decimal(1e18) # type: ignore + ) + + # Add solo transfer events to events_df + events_df = pd.concat([events_df, transfer_events_df], axis=0) + + # We select the subset of columns we need and rename to match db schema + rename_dict = { + "address": "hyperdrive_address", + "transactionHash": "transaction_hash", + "blockNumber": "block_number", + "trader": "wallet_address", + "event": "event_type", + "token_type": "token_type", + "maturityTime": "maturity_time", + "token_id": "token_id", + "token_delta": "token_delta", + "base_delta": "base_delta", + } + + events_df = events_df[list(rename_dict.keys())].rename(columns=rename_dict) + + # Add to db + df_to_db(events_df, TradeEvent, db_session) diff --git a/src/agent0/chainsync/db/hyperdrive/import_export_data.py b/src/agent0/chainsync/db/hyperdrive/import_export_data.py index 1e8e9b97c1..4d4f8cfb8f 100644 --- a/src/agent0/chainsync/db/hyperdrive/import_export_data.py +++ b/src/agent0/chainsync/db/hyperdrive/import_export_data.py @@ -4,7 +4,6 @@ import logging import os -from typing import Type import pandas as pd from sqlalchemy import exc @@ -12,12 +11,12 @@ from agent0.chainsync.db.base import ( AddrToUsername, - Base, UsernameToUser, get_addr_to_username, get_username_to_user, initialize_session, ) +from agent0.chainsync.df_to_db import df_to_db from .interface import ( get_checkpoint_info, @@ -26,6 +25,7 @@ get_pool_config, get_pool_info, get_ticker, + get_trade_events, get_transactions, get_wallet_deltas, get_wallet_pnl, @@ -38,12 +38,11 @@ PoolConfig, PoolInfo, Ticker, + TradeEvent, WalletDelta, WalletPNL, ) -MAX_BATCH_SIZE = 10000 - def export_db_to_file(out_dir: str, db_session: Session | None = None, raw: bool = False) -> None: """Export all tables from the database and write as parquet files, one per table. @@ -77,6 +76,9 @@ def export_db_to_file(out_dir: str, db_session: Session | None = None, raw: bool os.path.join(out_dir, "username_to_user.parquet"), index=False, engine="pyarrow" ) + # Agent event tables + get_trade_events(db_session).to_parquet(os.path.join(out_dir, "trade_event.parquet"), index=False, engine="pyarrow") + # Hyperdrive tables get_pool_config(db_session, coerce_float=False).to_parquet( os.path.join(out_dir, "pool_config.parquet"), index=False, engine="pyarrow" @@ -128,6 +130,7 @@ def import_to_pandas(in_dir: str) -> dict[str, pd.DataFrame]: out["addr_to_username"] = pd.read_parquet(os.path.join(in_dir, "addr_to_username.parquet"), engine="pyarrow") out["username_to_user"] = pd.read_parquet(os.path.join(in_dir, "username_to_user.parquet"), engine="pyarrow") + out["trade_event"] = pd.read_parquet(os.path.join(in_dir, "trade_event.parquet"), engine="pyarrow") out["pool_config"] = pd.read_parquet(os.path.join(in_dir, "pool_config.parquet"), engine="pyarrow") out["checkpoint_info"] = pd.read_parquet(os.path.join(in_dir, "checkpoint_info.parquet"), engine="pyarrow") out["pool_info"] = pd.read_parquet(os.path.join(in_dir, "pool_info.parquet"), engine="pyarrow") @@ -157,6 +160,7 @@ def import_to_db(db_session: Session, in_dir: str, drop=True) -> None: if drop: db_session.query(AddrToUsername).delete() db_session.query(UsernameToUser).delete() + db_session.query(TradeEvent).delete() db_session.query(PoolConfig).delete() db_session.query(CheckpointInfo).delete() db_session.query(PoolInfo).delete() @@ -174,40 +178,15 @@ def import_to_db(db_session: Session, in_dir: str, drop=True) -> None: raise err out = import_to_pandas(in_dir) - _df_to_db(out["addr_to_username"], AddrToUsername, db_session) - _df_to_db(out["username_to_user"], UsernameToUser, db_session) - _df_to_db(out["pool_config"], PoolConfig, db_session) - _df_to_db(out["checkpoint_info"], CheckpointInfo, db_session) - _df_to_db(out["pool_info"], PoolInfo, db_session) - _df_to_db(out["wallet_delta"], WalletDelta, db_session) - _df_to_db(out["transactions"], HyperdriveTransaction, db_session) - _df_to_db(out["pool_analysis"], PoolAnalysis, db_session) - _df_to_db(out["current_wallet"], CurrentWallet, db_session) - _df_to_db(out["ticker"], Ticker, db_session) - _df_to_db(out["wallet_pnl"], WalletPNL, db_session) - - -def _df_to_db(insert_df: pd.DataFrame, schema_obj: Type[Base], session: Session): - """Helper function to add a dataframe to a database""" - table_name = schema_obj.__tablename__ - - # dataframe to_sql needs data types from the schema object - dtype = {c.name: c.type for c in schema_obj.__table__.columns} - # Pandas doesn't play nice with types - insert_df.to_sql( - table_name, - con=session.connection(), - method="multi", - # if_exists=if_exists_method, - if_exists="append", - index=False, - dtype=dtype, # type: ignore - chunksize=MAX_BATCH_SIZE, - ) - # commit the transaction - try: - session.commit() - except exc.DataError as err: - session.rollback() - logging.error("Error on adding %s: %s", table_name, err) - raise err + df_to_db(out["addr_to_username"], AddrToUsername, db_session) + df_to_db(out["username_to_user"], UsernameToUser, db_session) + df_to_db(out["trade_event"], TradeEvent, db_session) + df_to_db(out["pool_config"], PoolConfig, db_session) + df_to_db(out["checkpoint_info"], CheckpointInfo, db_session) + df_to_db(out["pool_info"], PoolInfo, db_session) + df_to_db(out["wallet_delta"], WalletDelta, db_session) + df_to_db(out["transactions"], HyperdriveTransaction, db_session) + df_to_db(out["pool_analysis"], PoolAnalysis, db_session) + df_to_db(out["current_wallet"], CurrentWallet, db_session) + df_to_db(out["ticker"], Ticker, db_session) + df_to_db(out["wallet_pnl"], WalletPNL, db_session) diff --git a/src/agent0/chainsync/db/hyperdrive/import_export_data_test.py b/src/agent0/chainsync/db/hyperdrive/import_export_data_test.py index 595341ee90..8ad75fb7c1 100644 --- a/src/agent0/chainsync/db/hyperdrive/import_export_data_test.py +++ b/src/agent0/chainsync/db/hyperdrive/import_export_data_test.py @@ -19,7 +19,7 @@ def test_export_import(self, db_session): """Testing retrieval of transaction via interface""" # Write data to database # Ensuring decimal format gets preserved - pool_config = PoolConfig(contract_address="0", initial_vault_share_price=Decimal("3.22222222222222")) + pool_config = PoolConfig(hyperdrive_address="0", initial_vault_share_price=Decimal("3.22222222222222")) add_pool_config(pool_config, db_session) # We need pool config as a dataframe, so we read it from the db here pool_config_in = get_pool_config(db_session, coerce_float=False) diff --git a/src/agent0/chainsync/db/hyperdrive/interface.py b/src/agent0/chainsync/db/hyperdrive/interface.py index 88947e63ba..dd13ef107a 100644 --- a/src/agent0/chainsync/db/hyperdrive/interface.py +++ b/src/agent0/chainsync/db/hyperdrive/interface.py @@ -19,10 +19,116 @@ PoolConfig, PoolInfo, Ticker, + TradeEvent, WalletDelta, WalletPNL, ) +# Event Data Ingestion Interface + + +def add_transfer_events(transfer_events: list[TradeEvent], session: Session) -> None: + """Add transfer events to the transfer events table. + + Arguments + --------- + transfer_events: list[HyperdriveTransferEvent] + A list of HyperdriveTransferEvent objects to insert into postgres. + session: Session + The initialized session object. + """ + for transfer_event in transfer_events: + session.add(transfer_event) + try: + session.commit() + except exc.DataError as err: + session.rollback() + logging.error("Error adding transaction: %s", err) + raise err + + +def get_latest_block_number_from_trade_event(session: Session, wallet_addr: str) -> int: + """Get the latest block number based on the hyperdrive events table in the db. + + Arguments + --------- + session: Session + The initialized session object. + wallet_addr: str + The wallet address to filter the results on. + + Returns + ------- + int + The latest block number in the hyperdrive_events table. + """ + + query = session.query(func.max(TradeEvent.block_number)).filter(TradeEvent.wallet_address == wallet_addr).scalar() + if query is None: + return 0 + return int(query) + + +def get_trade_events(session: Session, wallet_addr: str | None = None) -> pd.DataFrame: + """Get all trade events and returns a pandas dataframe. + + Arguments + --------- + session: Session + The initialized db session object. + wallet_addr: str | None, optional + The wallet address to filter the results on. Return all if None. + + Returns + ------- + DataFrame + A DataFrame that consists of the queried trade events data. + """ + query = session.query(TradeEvent) + if wallet_addr is not None: + query = query.filter(TradeEvent.wallet_address == wallet_addr) + return pd.read_sql(query.statement, con=session.connection(), coerce_float=False) + + +def get_positions_from_db(session: Session, wallet_addr: str, hyperdrive_address: str | None = None) -> pd.DataFrame: + """Gets all positions for a given wallet address. + + Arguments + --------- + session: Session + The initialized db session object. + wallet_addr: str + The wallet address to filter the results on. + hyperdrive_address: str | None, optional + The hyperdrive address to filter the results on. Returns all if None. + + Returns + ------- + DataFrame + A DataFrame that consists of the queried pool info data. + """ + # TODO also accept and filter by hyperdrive address here + # when we move to multi-pool db support + query = session.query( + TradeEvent.hyperdrive_address, + TradeEvent.wallet_address, + TradeEvent.token_id, + func.sum(TradeEvent.token_delta).label("balance"), + ) + if hyperdrive_address is not None: + query = query.filter( + TradeEvent.wallet_address == wallet_addr, TradeEvent.hyperdrive_address == hyperdrive_address + ) + else: + query = query.filter(TradeEvent.wallet_address == wallet_addr) + query = query.group_by(TradeEvent.hyperdrive_address, TradeEvent.wallet_address, TradeEvent.token_id) + out_df = pd.read_sql(query.statement, con=session.connection(), coerce_float=False) + # Filter out zero balances + return out_df[out_df["balance"] != 0] + + +# Chain To Data Ingestion Interface + def add_transactions(transactions: list[HyperdriveTransaction], session: Session) -> None: """Add transactions to the poolinfo table. @@ -30,9 +136,9 @@ def add_transactions(transactions: list[HyperdriveTransaction], session: Session Arguments --------- transactions: list[HyperdriveTransaction] - A list of HyperdriveTransaction objects to insert into postgres + A list of HyperdriveTransaction objects to insert into postgres. session: Session - The initialized session object + The initialized session object. """ for transaction in transactions: session.add(transaction) @@ -45,25 +151,25 @@ def add_transactions(transactions: list[HyperdriveTransaction], session: Session def get_pool_config(session: Session, contract_address: str | None = None, coerce_float=True) -> pd.DataFrame: - """Get all pool config and returns as a pandas dataframe. + """Get all pool config and returns a pandas dataframe. Arguments --------- session: Session - The initialized session object + The initialized session object. contract_address: str | None, optional - The contract_address to filter the results on. Return all if None + The contract_address to filter the results on. Return all if None. coerce_float: bool - If True, will coerce all numeric columns to float + If True, will coerce all numeric columns to float. Returns ------- DataFrame - A DataFrame that consists of the queried pool config data + A DataFrame that consists of the queried pool config data. """ query = session.query(PoolConfig) if contract_address is not None: - query = query.filter(PoolConfig.contract_address == contract_address) + query = query.filter(PoolConfig.hyperdrive_address == contract_address) return pd.read_sql(query.statement, con=session.connection(), coerce_float=coerce_float) @@ -75,16 +181,16 @@ def add_pool_config(pool_config: PoolConfig, session: Session) -> None: Arguments --------- pool_config: PoolConfig - A PoolConfig object to insert into postgres + A PoolConfig object to insert into postgres. session: Session - The initialized session object + The initialized session object. """ # NOTE the logic below is not thread safe, i.e., a race condition can exists # if multiple threads try to add pool config at the same time # This function is being called by acquire_data.py, which should only have one # instance per db, so no need to worry about it here # Since we're doing a direct equality comparison, we don't want to coerce into floats here - existing_pool_config = get_pool_config(session, contract_address=pool_config.contract_address, coerce_float=False) + existing_pool_config = get_pool_config(session, contract_address=pool_config.hyperdrive_address, coerce_float=False) if len(existing_pool_config) == 0: session.add(pool_config) try: @@ -113,9 +219,9 @@ def add_pool_infos(pool_infos: list[PoolInfo], session: Session) -> None: Arguments --------- pool_infos: list[PoolInfo] - A list of PoolInfo objects to insert into postgres + A list of PoolInfo objects to insert into postgres. session: Session - The initialized session object + The initialized session object. """ for pool_info in pool_infos: session.add(pool_info) @@ -133,9 +239,9 @@ def add_checkpoint_info(checkpoint_info: CheckpointInfo, session: Session) -> No Arguments --------- checkpoint_info: CheckpointInfo - A CheckpointInfo object to insert into postgres + A CheckpointInfo object to insert into postgres. session: Session - The initialized session object + The initialized session object. """ # NOTE the logic below is not thread safe, i.e., a race condition can exists # if multiple threads try to add checkpoint info at the same time @@ -169,9 +275,9 @@ def add_wallet_deltas(wallet_deltas: list[WalletDelta], session: Session) -> Non Arguments --------- wallet_deltas: list[WalletDelta] - A list of WalletDelta objects to insert into postgres + A list of WalletDelta objects to insert into postgres. session: Session - The initialized session object + The initialized session object. """ for wallet_delta in wallet_deltas: session.add(wallet_delta) @@ -189,12 +295,12 @@ def get_latest_block_number_from_pool_info_table(session: Session) -> int: Arguments --------- session: Session - The initialized session object + The initialized session object. Returns ------- int - The latest block number in the poolinfo table + The latest block number in the poolinfo table. """ return get_latest_block_number_from_table(PoolInfo, session) @@ -205,12 +311,12 @@ def get_latest_block_number_from_analysis_table(session: Session) -> int: Arguments --------- session: Session - The initialized session object + The initialized session object. Returns ------- int - The latest block number in the poolinfo table + The latest block number in the poolinfo table. """ return get_latest_block_number_from_table(PoolAnalysis, session) @@ -218,25 +324,25 @@ def get_latest_block_number_from_analysis_table(session: Session) -> int: def get_pool_info( session: Session, start_block: int | None = None, end_block: int | None = None, coerce_float=True ) -> pd.DataFrame: - """Get all pool info and returns as a pandas dataframe. + """Get all pool info and returns a pandas dataframe. Arguments --------- session: Session - The initialized session object + The initialized session object. start_block: int | None, optional The starting block to filter the query on. start_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. end_block: int | None, optional The ending block to filter the query on. end_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. coerce_float: bool, optional - If true, will return floats in dataframe. Otherwise, will return fixed point Decimal + If true, will return floats in dataframe. Otherwise, will return fixed point Decimal. Returns ------- DataFrame - A DataFrame that consists of the queried pool info data + A DataFrame that consists of the queried pool info data. """ query = session.query(PoolInfo) @@ -263,25 +369,25 @@ def get_transactions( end_block: int | None = None, coerce_float=True, ) -> pd.DataFrame: - """Get all transactions and returns as a pandas dataframe. + """Get all transactions and returns a pandas dataframe. Arguments --------- session: Session - The initialized session object + The initialized session object. start_block: int | None The starting block to filter the query on. start_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. end_block: int | None The ending block to filter the query on. end_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. coerce_float: bool - If true, will return floats in dataframe. Otherwise, will return fixed point Decimal + If true, will return floats in dataframe. Otherwise, will return fixed point Decimal. Returns ------- DataFrame - A DataFrame that consists of the queried transactions data + A DataFrame that consists of the queried transactions data. """ query = session.query(HyperdriveTransaction) @@ -309,16 +415,16 @@ def get_checkpoint_info(session: Session, checkpoint_time: int | None = None, co Arguments --------- session: Session - The initialized session object + The initialized session object. checkpoint_time: int | None, optional The checkpoint time to filter the query on. Defaults to returning all checkpoint infos. coerce_float: bool - If true, will return floats in dataframe. Otherwise, will return fixed point Decimal + If true, will return floats in dataframe. Otherwise, will return fixed point Decimal. Returns ------- DataFrame - A DataFrame that consists of the queried checkpoint info + A DataFrame that consists of the queried checkpoint info. """ query = session.query(CheckpointInfo) @@ -338,27 +444,27 @@ def get_wallet_deltas( return_timestamp: bool = True, coerce_float=True, ) -> pd.DataFrame: - """Get all wallet_delta data in history and returns as a pandas dataframe. + """Get all wallet_delta data in history and returns a pandas dataframe. Arguments --------- session: Session - The initialized session object + The initialized session object. start_block: int | None, optional The starting block to filter the query on. start_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. end_block: int | None, optional The ending block to filter the query on. end_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. return_timestamp: bool, optional - Gets timestamps when looking at pool analysis. Defaults to True + Gets timestamps when looking at pool analysis. Defaults to True. coerce_float: bool - If true, will return floats in dataframe. Otherwise, will return fixed point Decimal + If true, will return floats in dataframe. Otherwise, will return fixed point Decimal. Returns ------- DataFrame - A DataFrame that consists of the queried wallet info data + A DataFrame that consists of the queried wallet info data. """ if return_timestamp: query = session.query(PoolInfo.timestamp, WalletDelta) @@ -391,20 +497,20 @@ def get_all_traders( Arguments --------- session: Session - The initialized session object + The initialized session object. start_block: int | None, optional The starting block to filter the query on. start_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. end_block: int | None, optional The ending block to filter the query on. end_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. coerce_float: bool - If true, will return floats in dataframe. Otherwise, will return fixed point Decimal + If true, will return floats in dataframe. Otherwise, will return fixed point Decimal. Returns ------- list[str] - A list of addresses that have made a trade + A list of addresses that have made a trade. """ query = session.query(WalletDelta.wallet_address) # Support for negative indices @@ -436,9 +542,9 @@ def add_current_wallet(current_wallet: list[CurrentWallet], session: Session) -> Arguments --------- current_wallet: list[CurrentWallet] - A list of CurrentWallet objects to insert into postgres + A list of CurrentWallet objects to insert into postgres. session: Session - The initialized session object + The initialized session object. """ for wallet in current_wallet: session.add(wallet) @@ -457,26 +563,26 @@ def get_current_wallet( coerce_float=True, raw: bool = False, ) -> pd.DataFrame: - """Get all current wallet data in history and returns as a pandas dataframe. + """Get all current wallet data in history and returns a pandas dataframe. Arguments --------- session: Session - The initialized session object + The initialized session object. end_block: int | None, optional The ending block to filter the query on. end_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. wallet_address: list[str] | None, optional - The wallet addresses to filter the query on + The wallet addresses to filter the query on. coerce_float: bool - If true, will return floats in dataframe. Otherwise, will return fixed point Decimal + If true, will return floats in dataframe. Otherwise, will return fixed point Decimal. raw: bool If true, will return the raw data without any adjustments. Returns ------- DataFrame - A DataFrame that consists of the queried wallet info data + A DataFrame that consists of the queried wallet info data. """ # TODO this function might not scale, as it's looking across all blocks from the beginning of time # Ways to improve: add indexes on wallet_address, token_type, block_number @@ -530,27 +636,27 @@ def get_pool_analysis( return_timestamp: bool = True, coerce_float=True, ) -> pd.DataFrame: - """Get all pool analysis and returns as a pandas dataframe. + """Get all pool analysis and returns a pandas dataframe. Arguments --------- session: Session - The initialized session object + The initialized session object. start_block: int | None, optional The starting block to filter the query on. start_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. end_block: int | None, optional The ending block to filter the query on. end_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. return_timestamp: bool, optional - Gets timestamps when looking at pool analysis. Defaults to True + Gets timestamps when looking at pool analysis. Defaults to True. coerce_float: bool - If true, will return floats in dataframe. Otherwise, will return fixed point Decimal + If true, will return floats in dataframe. Otherwise, will return fixed point Decimal. Returns ------- DataFrame - A DataFrame that consists of the queried pool info data + A DataFrame that consists of the queried pool info data. """ if return_timestamp: query = session.query(PoolInfo.timestamp, PoolAnalysis) @@ -587,31 +693,31 @@ def get_ticker( sort_desc: bool | None = False, coerce_float=True, ) -> pd.DataFrame: - """Get all pool analysis and returns as a pandas dataframe. + """Get all pool analysis and returns a pandas dataframe. Arguments --------- session: Session - The initialized session object + The initialized session object. start_block: int | None, optional The starting block to filter the query on. start_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. end_block: int | None, optional The ending block to filter the query on. end_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. wallet_address: list[str] | None, optional - The wallet addresses to filter the query on + The wallet addresses to filter the query on. max_rows: int | None - The number of rows to return. If None, will return all rows + The number of rows to return. If None, will return all rows. sort_desc: bool, optional - If true, will sort in descending order + If true, will sort in descending order. coerce_float: bool - If true, will return floats in dataframe. Otherwise, will return fixed point Decimal + If true, will return floats in dataframe. Otherwise, will return fixed point Decimal. Returns ------- DataFrame - A DataFrame that consists of the queried pool info data + A DataFrame that consists of the queried pool info data. """ # pylint: disable=too-many-arguments query = session.query(Ticker) @@ -653,29 +759,29 @@ def get_wallet_pnl( return_timestamp: bool = True, coerce_float=True, ) -> pd.DataFrame: - """Get all wallet pnl and returns as a pandas dataframe. + """Get all wallet pnl and returns a pandas dataframe. Arguments --------- session: Session - The initialized session object + The initialized session object. start_block: int | None, optional The starting block to filter the query on. start_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. end_block: int | None, optional The ending block to filter the query on. end_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. wallet_address: list[str] | None, optional The wallet addresses to filter the query on. Returns all if None. return_timestamp: bool, optional Returns the timestamp from the pool info table if True. Defaults to True. coerce_float: bool - If true, will return floats in dataframe. Otherwise, will return fixed point Decimal + If true, will return floats in dataframe. Otherwise, will return fixed point Decimal. Returns ------- DataFrame - A DataFrame that consists of the queried pool info data + A DataFrame that consists of the queried pool info data. """ if return_timestamp: query = session.query(PoolInfo.timestamp, WalletPNL) @@ -716,27 +822,27 @@ def get_total_wallet_pnl_over_time( wallet_address: list[str] | None = None, coerce_float=True, ) -> pd.DataFrame: - """Get total pnl across wallets over time and returns as a pandas dataframe. + """Get total pnl across wallets over time and returns a pandas dataframe. Arguments --------- session: Session - The initialized session object + The initialized session object. start_block: int | None, optional The starting block to filter the query on. start_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. end_block: int | None, optional The ending block to filter the query on. end_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. wallet_address: list[str] | None, optional The wallet addresses to filter the query on. Returns all if None. coerce_float: bool - If true, will return floats in dataframe. Otherwise, will return fixed point Decimal + If true, will return floats in dataframe. Otherwise, will return fixed point Decimal. Returns ------- DataFrame - A DataFrame that consists of the queried pool info data + A DataFrame that consists of the queried pool info data. """ # Do a subquery that groups wallet pnl by address and block # Not sure why func.sum is not callable, but it is @@ -778,27 +884,27 @@ def get_wallet_positions_over_time( wallet_address: list[str] | None = None, coerce_float=True, ) -> pd.DataFrame: - """Get wallet positions over time and returns as a pandas dataframe. + """Get wallet positions over time and returns a pandas dataframe. Arguments --------- session: Session - The initialized session object + The initialized session object. start_block: int | None, optional The starting block to filter the query on. start_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. end_block: int | None, optional The ending block to filter the query on. end_block integers - matches python slicing notation, e.g., list[:3], list[:-3] + matches python slicing notation, e.g., list[:3], list[:-3]. wallet_address: list[str] | None, optional The wallet addresses to filter the query on. Returns all if None. coerce_float: bool - If true, will return floats in dataframe. Otherwise, will return fixed point Decimal + If true, will return floats in dataframe. Otherwise, will return fixed point Decimal. Returns ------- DataFrame - A DataFrame that consists of the queried pool info data + A DataFrame that consists of the queried pool info data. """ # Not sure why func.sum is not callable, but it is subquery = session.query( diff --git a/src/agent0/chainsync/db/hyperdrive/interface_test.py b/src/agent0/chainsync/db/hyperdrive/interface_test.py index 679b97d6d0..31d3c6379d 100644 --- a/src/agent0/chainsync/db/hyperdrive/interface_test.py +++ b/src/agent0/chainsync/db/hyperdrive/interface_test.py @@ -15,18 +15,20 @@ add_pool_config, add_pool_infos, add_transactions, + add_transfer_events, add_wallet_deltas, get_all_traders, get_checkpoint_info, get_current_wallet, get_latest_block_number_from_pool_info_table, get_latest_block_number_from_table, + get_latest_block_number_from_trade_event, get_pool_config, get_pool_info, get_transactions, get_wallet_deltas, ) -from .schema import CheckpointInfo, CurrentWallet, HyperdriveTransaction, PoolConfig, PoolInfo, WalletDelta +from .schema import CheckpointInfo, CurrentWallet, HyperdriveTransaction, PoolConfig, PoolInfo, TradeEvent, WalletDelta # These tests are using fixtures defined in conftest.py @@ -157,14 +159,14 @@ class TestPoolConfigInterface: @pytest.mark.docker def test_get_pool_config(self, db_session): """Testing retrieval of pool config via interface""" - pool_config_1 = PoolConfig(contract_address="0", initial_vault_share_price=Decimal("3.2")) + pool_config_1 = PoolConfig(hyperdrive_address="0", initial_vault_share_price=Decimal("3.2")) add_pool_config(pool_config_1, db_session) pool_config_df_1 = get_pool_config(db_session) assert len(pool_config_df_1) == 1 np.testing.assert_array_equal(pool_config_df_1["initial_vault_share_price"], np.array([3.2])) - pool_config_2 = PoolConfig(contract_address="1", initial_vault_share_price=Decimal("3.4")) + pool_config_2 = PoolConfig(hyperdrive_address="1", initial_vault_share_price=Decimal("3.4")) add_pool_config(pool_config_2, db_session) pool_config_df_2 = get_pool_config(db_session) @@ -174,7 +176,7 @@ def test_get_pool_config(self, db_session): @pytest.mark.docker def test_primary_id_query_pool_config(self, db_session): """Testing retrieval of pool config via interface""" - pool_config = PoolConfig(contract_address="0", initial_vault_share_price=Decimal("3.2")) + pool_config = PoolConfig(hyperdrive_address="0", initial_vault_share_price=Decimal("3.2")) add_pool_config(pool_config, db_session) pool_config_df_1 = get_pool_config(db_session, contract_address="0") @@ -187,21 +189,21 @@ def test_primary_id_query_pool_config(self, db_session): @pytest.mark.docker def test_pool_config_verify(self, db_session): """Testing retrieval of pool config via interface""" - pool_config_1 = PoolConfig(contract_address="0", initial_vault_share_price=Decimal("3.2")) + pool_config_1 = PoolConfig(hyperdrive_address="0", initial_vault_share_price=Decimal("3.2")) add_pool_config(pool_config_1, db_session) pool_config_df_1 = get_pool_config(db_session) assert len(pool_config_df_1) == 1 assert pool_config_df_1.loc[0, "initial_vault_share_price"] == 3.2 # Nothing should happen if we give the same pool_config - pool_config_2 = PoolConfig(contract_address="0", initial_vault_share_price=Decimal("3.2")) + pool_config_2 = PoolConfig(hyperdrive_address="0", initial_vault_share_price=Decimal("3.2")) add_pool_config(pool_config_2, db_session) pool_config_df_2 = get_pool_config(db_session) assert len(pool_config_df_2) == 1 assert pool_config_df_2.loc[0, "initial_vault_share_price"] == 3.2 # If we try to add another pool config with a different value, should throw a ValueError - pool_config_3 = PoolConfig(contract_address="0", initial_vault_share_price=Decimal("3.4")) + pool_config_3 = PoolConfig(hyperdrive_address="0", initial_vault_share_price=Decimal("3.4")) with pytest.raises(ValueError): add_pool_config(pool_config_3, db_session) @@ -401,3 +403,24 @@ def test_current_wallet_info(self, db_session): wallet_info_df = wallet_info_df.sort_values(by=["value"]) np.testing.assert_array_equal(wallet_info_df["token_type"], ["LP", BASE_TOKEN_SYMBOL]) np.testing.assert_array_equal(wallet_info_df["value"], [5.1, 6.1]) + + +class TestHyperdriveEventsInterface: + """Testing postgres interface for walletinfo table""" + + @pytest.mark.docker + def test_latest_block_number(self, db_session): + """Testing retrieval of wallet info via interface""" + transfer_event = TradeEvent(block_number=1, hyperdrive_address="a", transaction_hash="a", wallet_address="a") + add_transfer_events([transfer_event], db_session) + + latest_block_number = get_latest_block_number_from_trade_event(db_session, "a") + assert latest_block_number == 1 + + transfer_event_1 = TradeEvent(block_number=2, hyperdrive_address="a", transaction_hash="a", wallet_address="a") + transfer_event_2 = TradeEvent(block_number=3, hyperdrive_address="a", transaction_hash="a", wallet_address="b") + add_transfer_events([transfer_event_1, transfer_event_2], db_session) + latest_block_number = get_latest_block_number_from_trade_event(db_session, "a") + assert latest_block_number == 2 + latest_block_number = get_latest_block_number_from_trade_event(db_session, "b") + assert latest_block_number == 3 diff --git a/src/agent0/chainsync/db/hyperdrive/schema.py b/src/agent0/chainsync/db/hyperdrive/schema.py index 0c6061cc05..0d90bbf5cb 100644 --- a/src/agent0/chainsync/db/hyperdrive/schema.py +++ b/src/agent0/chainsync/db/hyperdrive/schema.py @@ -28,7 +28,10 @@ class PoolConfig(Base): __tablename__ = "pool_config" - contract_address: Mapped[str] = mapped_column(String, primary_key=True) + # Indices + hyperdrive_address: Mapped[str] = mapped_column(String, primary_key=True) + + # Pool config parameters base_token: Mapped[Union[str, None]] = mapped_column(String, default=None) vault_shares_token: Mapped[Union[str, None]] = mapped_column(String, default=None) linker_factory: Mapped[Union[str, None]] = mapped_column(String, default=None) @@ -94,7 +97,64 @@ class PoolInfo(Base): vault_shares: Mapped[Union[Decimal, None]] = mapped_column(FIXED_NUMERIC, default=None) +class TradeEvent(Base): + """Table for storing any transfer events emitted by the Hyperdrive contract. + This table only contains events of "registered" wallet addresses, which are any agents + that are managed by agent0. This table does not store all wallet addresses that have + interacted with all Hyperdrive contracts. + TODO this table would take the place of the `WalletDelta` table with the following updates: + - We explicitly fill this table with all addresses that have interacted with all hyperdrive pools. + - This is very slow on existing pools, which makes it useful for simulations and + any managed chains to run a dashboard on, but not so much for connections to remote chains + to execute trades. + """ + + __tablename__ = "trade_event" + # Indices + id: Mapped[int] = mapped_column(BigInteger, primary_key=True, init=False, autoincrement=True) + """The unique identifier for the entry to the table.""" + hyperdrive_address: Mapped[str] = mapped_column(String, index=True) + """The hyperdrive address for the entry.""" + transaction_hash: Mapped[str] = mapped_column(String, index=True) + """The transaction hash for the entry.""" + block_number: Mapped[int] = mapped_column(BigInteger, index=True) + """The block number for the entry.""" + wallet_address: Mapped[str] = mapped_column(String, index=True) + """The wallet address for the entry.""" + + # Fields + event_type: Mapped[Union[str, None]] = mapped_column(String, index=True, default=None) + """ + The underlying event type for the entry. Can be one of the following: + `OpenLong`, `OpenShort`, `CloseLong`, `CloseShort`, `AddLiquidity`, + `RemoveLiquidity`, `RedeemWithdrawalShares`, or `TransferSingle`. + """ + token_type: Mapped[Union[str, None]] = mapped_column(String, index=True, default=None) + """ + The underlying token type for the entry. Can be one of the following: + `LONG`, `SHORT, `LP`, or `WITHDRAWAL_SHARE`. + """ + # While time here is in epoch seconds, we use Numeric to allow for + # (1) lossless storage and (2) allow for NaNs + maturity_time: Mapped[Union[int, None]] = mapped_column(Numeric, default=None) + """The maturity time of the token""" + token_id: Mapped[Union[str, None]] = mapped_column(String, default=None) + """ + The id for the token itself, which consists of the `token_type`, appended + with `maturity_time` for LONG and SHORT. For example, `LONG-1715126400`. + """ + token_delta: Mapped[Union[Decimal, None]] = mapped_column(FIXED_NUMERIC, default=None) + """ + The change in tokens with respect to the wallet address. + """ + base_delta: Mapped[Union[Decimal, None]] = mapped_column(FIXED_NUMERIC, default=None) + """ + The change in base tokens for the event with respect to the wallet address. + """ + + # TODO: either make a more general TokenDelta, or rename this to HyperdriveDelta +# TODO: this table might be able to be deprecated in favor of hyperdrive events. class WalletDelta(Base): """Table/dataclass schema for wallet deltas.""" @@ -114,6 +174,7 @@ class WalletDelta(Base): maturity_time: Mapped[Union[int, None]] = mapped_column(Numeric, default=None) +# TODO this table maybe isn't needed, use events table instead class HyperdriveTransaction(Base): """Table/dataclass schema for Transactions. diff --git a/src/agent0/chainsync/db/hyperdrive/schema_test.py b/src/agent0/chainsync/db/hyperdrive/schema_test.py index 70c5f7cedc..ede369bf8a 100644 --- a/src/agent0/chainsync/db/hyperdrive/schema_test.py +++ b/src/agent0/chainsync/db/hyperdrive/schema_test.py @@ -103,25 +103,25 @@ class TestPoolConfigTable: @pytest.mark.docker def test_create_pool_config(self, db_session): """Create and entry""" - pool_config = PoolConfig(contract_address="0", initial_vault_share_price=Decimal("3.2")) + pool_config = PoolConfig(hyperdrive_address="0", initial_vault_share_price=Decimal("3.2")) db_session.add(pool_config) db_session.commit() - retrieved_pool_config = db_session.query(PoolConfig).filter_by(contract_address="0").first() + retrieved_pool_config = db_session.query(PoolConfig).filter_by(hyperdrive_address="0").first() assert retrieved_pool_config is not None assert float(retrieved_pool_config.initial_vault_share_price) == 3.2 @pytest.mark.docker def test_delete_pool_config(self, db_session): """Delete an entry""" - pool_config = PoolConfig(contract_address="0", initial_vault_share_price=Decimal("3.2")) + pool_config = PoolConfig(hyperdrive_address="0", initial_vault_share_price=Decimal("3.2")) db_session.add(pool_config) db_session.commit() db_session.delete(pool_config) db_session.commit() - deleted_pool_config = db_session.query(PoolConfig).filter_by(contract_address="0").first() + deleted_pool_config = db_session.query(PoolConfig).filter_by(hyperdrive_address="0").first() assert deleted_pool_config is None diff --git a/src/agent0/chainsync/df_to_db.py b/src/agent0/chainsync/df_to_db.py new file mode 100644 index 0000000000..f7c83bbfa4 --- /dev/null +++ b/src/agent0/chainsync/df_to_db.py @@ -0,0 +1,47 @@ +"""Helper function to add a dataframe to a database.""" + +import logging +from typing import Type + +import pandas as pd +from sqlalchemy import exc +from sqlalchemy.orm import Session + +from agent0.chainsync.db.base import Base + +MAX_BATCH_SIZE = 10000 + + +def df_to_db(insert_df: pd.DataFrame, schema_obj: Type[Base], session: Session): + """Helper function to add a dataframe to a database. + + Arguments + --------- + insert_df: pd.DataFrame + The dataframe to insert. + schema_obj: Type[Base] + The schema object to use. + session: Session + The initialized session object. + """ + table_name = schema_obj.__tablename__ + + # dataframe to_sql needs data types from the schema object + dtype = {c.name: c.type for c in schema_obj.__table__.columns} + # Pandas doesn't play nice with types + insert_df.to_sql( + table_name, + con=session.connection(), + if_exists="append", + method="multi", + index=False, + dtype=dtype, # type: ignore + chunksize=MAX_BATCH_SIZE, + ) + # commit the transaction + try: + session.commit() + except exc.DataError as err: + session.rollback() + logging.error("Error on adding %s: %s", table_name, err) + raise err diff --git a/src/agent0/core/hyperdrive/interactive/hyperdrive.py b/src/agent0/core/hyperdrive/interactive/hyperdrive.py index f79b7dd10e..42557bb3ff 100644 --- a/src/agent0/core/hyperdrive/interactive/hyperdrive.py +++ b/src/agent0/core/hyperdrive/interactive/hyperdrive.py @@ -9,6 +9,7 @@ import nest_asyncio import numpy as np +import pandas as pd from eth_account.account import Account from eth_account.signers.local import LocalAccount from eth_typing import ChecksumAddress @@ -16,7 +17,13 @@ from numpy.random._generator import Generator from web3 import Web3 -from agent0.core.hyperdrive import HyperdriveActionType, HyperdrivePolicyAgent, TradeResult, TradeStatus +from agent0.core.hyperdrive import ( + HyperdriveActionType, + HyperdrivePolicyAgent, + HyperdriveWallet, + TradeResult, + TradeStatus, +) from agent0.core.hyperdrive.agent import ( add_liquidity_trade, build_wallet_positions_from_chain, @@ -234,7 +241,6 @@ def _init_agent( agent = HyperdrivePolicyAgent(Account().from_key(private_key), initial_budget=FixedPoint(0), policy=policy_obj) - self._sync_wallet(agent) return agent def _set_max_approval(self, agent: HyperdrivePolicyAgent) -> None: @@ -246,9 +252,14 @@ def _set_max_approval(self, agent: HyperdrivePolicyAgent) -> None: str(self.interface.hyperdrive_contract.address), ) - def _sync_wallet(self, agent: HyperdrivePolicyAgent) -> None: - # TODO add sync from db - agent.wallet = build_wallet_positions_from_chain(agent, self.interface) + def _get_positions(self, agent: HyperdrivePolicyAgent) -> HyperdriveWallet: + # TODO move the db to the chain class and use it here + # to get wallets + return build_wallet_positions_from_chain(agent, self.interface) + + def _get_trade_events(self, agent: HyperdrivePolicyAgent) -> pd.DataFrame: + # TODO move the db to the chain class and use it here to get events + raise NotImplementedError def _add_funds( self, diff --git a/src/agent0/core/hyperdrive/interactive/hyperdrive_agent.py b/src/agent0/core/hyperdrive/interactive/hyperdrive_agent.py index d0fa28a2c3..6f420bab99 100644 --- a/src/agent0/core/hyperdrive/interactive/hyperdrive_agent.py +++ b/src/agent0/core/hyperdrive/interactive/hyperdrive_agent.py @@ -4,6 +4,7 @@ from typing import TYPE_CHECKING +import pandas as pd from fixedpointmath import FixedPoint if TYPE_CHECKING: @@ -62,17 +63,6 @@ def __init__( self._pool = pool self.agent = self._pool._init_agent(policy, policy_config, private_key) - @property - def wallet(self) -> HyperdriveWallet: - """Returns the agent's current wallet. - - Returns - ------- - HyperdriveWallet - The agent's current wallet. - """ - return self.agent.wallet - @property def checksum_address(self) -> ChecksumAddress: """Return the checksum address of the account.""" @@ -250,11 +240,26 @@ def set_max_approval(self) -> None: """ self._pool._set_max_approval(self.agent) - def sync_wallet_from_chain(self) -> None: - """Explicitly syncs the wallet to the current state of the chain. + def get_positions(self) -> HyperdriveWallet: + """Returns the agent's current wallet. - Uses on chain events to generate current wallet positions. + Returns + ------- + HyperdriveWallet + The agent's current wallet. + """ + + # Update the db with this wallet + return self._pool._get_positions(self.agent) + + def get_trade_events(self) -> pd.DataFrame: + """Returns the agent's current wallet. - .. note:: This function can be slow, use it sparingly. + Returns + ------- + HyperdriveWallet + The agent's current wallet. """ - self._pool._sync_wallet(self.agent) + + # Update the db with this wallet + return self._pool._get_trade_events(self.agent) diff --git a/src/agent0/core/hyperdrive/interactive/hyperdrive_test.py b/src/agent0/core/hyperdrive/interactive/hyperdrive_test.py index 872d8e978b..87690bc7cd 100644 --- a/src/agent0/core/hyperdrive/interactive/hyperdrive_test.py +++ b/src/agent0/core/hyperdrive/interactive/hyperdrive_test.py @@ -106,8 +106,8 @@ def test_remote_funding_and_trades(fast_chain_fixture: LocalChain, check_remote_ hyperdrive_agent1.set_max_approval() # Ensure agent wallet have expected balances - assert (hyperdrive_agent0.wallet.balance.amount) == FixedPoint(1_111_111) - assert (hyperdrive_agent1.wallet.balance.amount) == FixedPoint(222_222) + assert (hyperdrive_agent0.get_positions().balance.amount) == FixedPoint(1_111_111) + assert (hyperdrive_agent1.get_positions().balance.amount) == FixedPoint(222_222) # Ensure chain balances are as expected ( @@ -130,17 +130,17 @@ def test_remote_funding_and_trades(fast_chain_fixture: LocalChain, check_remote_ # Test trades add_liquidity_event = hyperdrive_agent0.add_liquidity(base=FixedPoint(111_111)) assert add_liquidity_event.base_amount == FixedPoint(111_111) - assert hyperdrive_agent0.wallet.lp_tokens == add_liquidity_event.lp_amount - _ensure_agent_wallet_is_correct(hyperdrive_agent0.wallet, interactive_remote_hyperdrive.interface) + assert hyperdrive_agent0.get_positions().lp_tokens == add_liquidity_event.lp_amount + _ensure_agent_wallet_is_correct(hyperdrive_agent0.get_positions(), interactive_remote_hyperdrive.interface) # Open long open_long_event = hyperdrive_agent0.open_long(base=FixedPoint(22_222)) assert open_long_event.base_amount == FixedPoint(22_222) - agent0_longs = list(hyperdrive_agent0.wallet.longs.values()) + agent0_longs = list(hyperdrive_agent0.get_positions().longs.values()) assert len(agent0_longs) == 1 assert agent0_longs[0].balance == open_long_event.bond_amount assert agent0_longs[0].maturity_time == open_long_event.maturity_time - _ensure_agent_wallet_is_correct(hyperdrive_agent0.wallet, interactive_remote_hyperdrive.interface) + _ensure_agent_wallet_is_correct(hyperdrive_agent0.get_positions(), interactive_remote_hyperdrive.interface) # Testing adding another agent to the pool after trades have been made, making a trade, # then checking wallet @@ -150,18 +150,18 @@ def test_remote_funding_and_trades(fast_chain_fixture: LocalChain, check_remote_ open_long_event_2 = hyperdrive_agent2.open_long(base=FixedPoint(333)) assert open_long_event_2.base_amount == FixedPoint(333) - agent2_longs = list(hyperdrive_agent2.wallet.longs.values()) + agent2_longs = list(hyperdrive_agent2.get_positions().longs.values()) assert len(agent2_longs) == 1 assert agent2_longs[0].balance == open_long_event_2.bond_amount assert agent2_longs[0].maturity_time == open_long_event_2.maturity_time - _ensure_agent_wallet_is_correct(hyperdrive_agent2.wallet, interactive_remote_hyperdrive.interface) + _ensure_agent_wallet_is_correct(hyperdrive_agent2.get_positions(), interactive_remote_hyperdrive.interface) # Remove liquidity remove_liquidity_event = hyperdrive_agent0.remove_liquidity(shares=add_liquidity_event.lp_amount) assert add_liquidity_event.lp_amount == remove_liquidity_event.lp_amount - assert hyperdrive_agent0.wallet.lp_tokens == FixedPoint(0) - assert hyperdrive_agent0.wallet.withdraw_shares == remove_liquidity_event.withdrawal_share_amount - _ensure_agent_wallet_is_correct(hyperdrive_agent0.wallet, interactive_remote_hyperdrive.interface) + assert hyperdrive_agent0.get_positions().lp_tokens == FixedPoint(0) + assert hyperdrive_agent0.get_positions().withdraw_shares == remove_liquidity_event.withdrawal_share_amount + _ensure_agent_wallet_is_correct(hyperdrive_agent0.get_positions(), interactive_remote_hyperdrive.interface) # We ensure there exists some withdrawal shares that were given from the previous trade for testing purposes assert remove_liquidity_event.withdrawal_share_amount > 0 @@ -169,11 +169,11 @@ def test_remote_funding_and_trades(fast_chain_fixture: LocalChain, check_remote_ # Open short open_short_event = hyperdrive_agent0.open_short(bonds=FixedPoint(333)) assert open_short_event.bond_amount == FixedPoint(333) - agent0_shorts = list(hyperdrive_agent0.wallet.shorts.values()) + agent0_shorts = list(hyperdrive_agent0.get_positions().shorts.values()) assert len(agent0_shorts) == 1 assert agent0_shorts[0].balance == open_short_event.bond_amount assert agent0_shorts[0].maturity_time == open_short_event.maturity_time - _ensure_agent_wallet_is_correct(hyperdrive_agent0.wallet, interactive_remote_hyperdrive.interface) + _ensure_agent_wallet_is_correct(hyperdrive_agent0.get_positions(), interactive_remote_hyperdrive.interface) # Close long close_long_event = hyperdrive_agent0.close_long( @@ -181,8 +181,8 @@ def test_remote_funding_and_trades(fast_chain_fixture: LocalChain, check_remote_ ) assert open_long_event.bond_amount == close_long_event.bond_amount assert open_long_event.maturity_time == close_long_event.maturity_time - assert len(hyperdrive_agent0.wallet.longs) == 0 - _ensure_agent_wallet_is_correct(hyperdrive_agent0.wallet, interactive_remote_hyperdrive.interface) + assert len(hyperdrive_agent0.get_positions().longs) == 0 + _ensure_agent_wallet_is_correct(hyperdrive_agent0.get_positions(), interactive_remote_hyperdrive.interface) # Close short close_short_event = hyperdrive_agent0.close_short( @@ -190,18 +190,18 @@ def test_remote_funding_and_trades(fast_chain_fixture: LocalChain, check_remote_ ) assert open_short_event.bond_amount == close_short_event.bond_amount assert open_short_event.maturity_time == close_short_event.maturity_time - assert len(hyperdrive_agent0.wallet.shorts) == 0 - _ensure_agent_wallet_is_correct(hyperdrive_agent0.wallet, interactive_remote_hyperdrive.interface) + assert len(hyperdrive_agent0.get_positions().shorts) == 0 + _ensure_agent_wallet_is_correct(hyperdrive_agent0.get_positions(), interactive_remote_hyperdrive.interface) # Redeem withdrawal shares # Note that redeeming withdrawal shares for more than available in the pool # will pull out as much withdrawal shares as possible redeem_event = hyperdrive_agent0.redeem_withdraw_share(shares=remove_liquidity_event.withdrawal_share_amount) assert ( - hyperdrive_agent0.wallet.withdraw_shares + hyperdrive_agent0.get_positions().withdraw_shares == remove_liquidity_event.withdrawal_share_amount - redeem_event.withdrawal_share_amount ) - _ensure_agent_wallet_is_correct(hyperdrive_agent0.wallet, interactive_remote_hyperdrive.interface) + _ensure_agent_wallet_is_correct(hyperdrive_agent0.get_positions(), interactive_remote_hyperdrive.interface) @pytest.mark.anvil @@ -223,61 +223,3 @@ def test_no_policy_call(fast_chain_fixture: LocalChain, check_remote_chain: bool # Attempt to execute agent policy, should throw value error with pytest.raises(ValueError): hyperdrive_agent.execute_policy_action() - - -@pytest.mark.anvil -@pytest.mark.parametrize("check_remote_chain", [True, False]) -def test_sync_wallet_from_chain(fast_chain_fixture: LocalChain, check_remote_chain: bool): - """Deploy a local chain and point the remote interface to the local chain.""" - # Parameters for pool initialization. If empty, defaults to default values, allows for custom values if needed - # We explicitly set initial liquidity here to ensure we have withdrawal shares when trading - initial_pool_config = LocalHyperdrive.Config( - initial_liquidity=FixedPoint(1_000), - initial_fixed_apr=FixedPoint("0.05"), - position_duration=60 * 60 * 24 * 365, # 1 year - ) - # Launches a local hyperdrive pool - # This deploys the pool - interactive_local_hyperdrive = LocalHyperdrive(fast_chain_fixture, initial_pool_config) - - # Gather relevant objects from the local hyperdrive - hyperdrive_addresses = interactive_local_hyperdrive.get_hyperdrive_address() - - # Connect to the local chain using the remote hyperdrive interface - if check_remote_chain: - remote_chain = Chain(fast_chain_fixture.rpc_uri) - interactive_remote_hyperdrive = Hyperdrive(remote_chain, hyperdrive_addresses) - else: - interactive_remote_hyperdrive = Hyperdrive(fast_chain_fixture, hyperdrive_addresses) - - # Generate trading agents from the interactive object using the same underlying wallet - private_key = make_private_key() - hyperdrive_local_agent = interactive_local_hyperdrive.init_agent(private_key=private_key, eth=FixedPoint(10)) - hyperdrive_remote_agent = interactive_remote_hyperdrive.init_agent(private_key=private_key) - - # TODO check balance of calls in this test - - # Add funds to the local agent, remote agent wallet doesn't update - hyperdrive_local_agent.add_funds(base=FixedPoint(1_111_111), eth=FixedPoint(111)) - hyperdrive_local_agent.add_liquidity(base=FixedPoint(111_111)) - hyperdrive_local_agent.open_long(base=FixedPoint(222)) - hyperdrive_local_agent.open_short(bonds=FixedPoint(333)) - - # Sync the remote agent wallet from the chain and ensure it's correct - hyperdrive_remote_agent.sync_wallet_from_chain() - assert len(hyperdrive_remote_agent.wallet.longs) == 1 - assert len(hyperdrive_remote_agent.wallet.shorts) == 1 - # We only check balances of shorts - assert list(hyperdrive_remote_agent.wallet.shorts.values())[0].balance == FixedPoint(333) - - # Add funds to the remote agent and see local agent wallet doesn't update - hyperdrive_remote_agent.add_funds(base=FixedPoint(1_111_111), eth=FixedPoint(111)) - hyperdrive_remote_agent.add_liquidity(base=FixedPoint(111_111)) - hyperdrive_remote_agent.open_long(base=FixedPoint(222)) - hyperdrive_remote_agent.open_short(bonds=FixedPoint(333)) - - # Sync local wallet from chain and ensure it's correct - hyperdrive_local_agent.sync_wallet_from_chain() - assert len(hyperdrive_local_agent.wallet.longs) == 1 - assert len(hyperdrive_local_agent.wallet.shorts) == 1 - assert list(hyperdrive_local_agent.wallet.shorts.values())[0].balance == FixedPoint(666) diff --git a/src/agent0/core/hyperdrive/interactive/local_hyperdrive.py b/src/agent0/core/hyperdrive/interactive/local_hyperdrive.py index e6ff710faa..d61d87d514 100644 --- a/src/agent0/core/hyperdrive/interactive/local_hyperdrive.py +++ b/src/agent0/core/hyperdrive/interactive/local_hyperdrive.py @@ -18,6 +18,7 @@ from eth_account.signers.local import LocalAccount from eth_typing import BlockNumber, ChecksumAddress from fixedpointmath import FixedPoint +from hexbytes import HexBytes from IPython.display import IFrame from sqlalchemy_utils.functions import drop_database from web3._utils.threads import Timeout @@ -38,15 +39,20 @@ get_pool_analysis, get_pool_config, get_pool_info, + get_positions_from_db, get_ticker, get_total_wallet_pnl_over_time, + get_trade_events, get_wallet_deltas, get_wallet_pnl, + trade_events_to_db, ) from agent0.chainsync.exec import acquire_data, data_analysis +from agent0.core.base import Quantity, TokenType from agent0.core.base.make_key import make_private_key -from agent0.core.hyperdrive import HyperdrivePolicyAgent, TradeResult, TradeStatus +from agent0.core.hyperdrive import HyperdrivePolicyAgent, HyperdriveWallet, TradeResult, TradeStatus from agent0.core.hyperdrive.agent import build_wallet_positions_from_db +from agent0.core.hyperdrive.agent.hyperdrive_wallet import Long, Short from agent0.core.hyperdrive.crash_report import get_anvil_state_dump from agent0.core.hyperdrive.policies import HyperdriveBasePolicy from agent0.ethpy.hyperdrive import ( @@ -1081,12 +1087,58 @@ def _init_local_agent( add_addr_to_username(name, [agent.address], self.db_session) return agent - def _sync_wallet(self, agent: HyperdrivePolicyAgent) -> None: - # TODO add sync from db - super()._sync_wallet(agent) - # Ensure db is up to date - if not self.chain.experimental_data_threading: - self._run_blocking_data_pipeline() + def _sync_events(self, agent: HyperdrivePolicyAgent) -> None: + # Update the db with this wallet + # TODO this function can be optimized to cache. + trade_events_to_db([self.interface], agent.checksum_address, self.db_session) + + def _get_positions(self, agent: HyperdrivePolicyAgent) -> HyperdriveWallet: + self._sync_events(agent) + + # Query for the wallet object from the db + wallet_positions = get_positions_from_db( + self.db_session, agent.checksum_address, hyperdrive_address=self.interface.hyperdrive_address + ) + # Convert to hyperdrive wallet object + long_obj: dict[int, Long] = {} + short_obj: dict[int, Short] = {} + lp_balance: FixedPoint = FixedPoint(0) + withdrawal_shares_balance: FixedPoint = FixedPoint(0) + for _, row in wallet_positions.iterrows(): + # Sanity checks + assert row["hyperdrive_address"] == self.interface.hyperdrive_address + assert row["wallet_address"] == agent.checksum_address + if row["token_id"] == "LP": + lp_balance = FixedPoint(row["balance"]) + elif row["token_id"] == "WITHDRAWAL_SHARE": + withdrawal_shares_balance = FixedPoint(row["balance"]) + elif "LONG" in row["token_id"]: + maturity_time = int(row["token_id"].split("-")[1]) + long_obj[maturity_time] = Long(balance=FixedPoint(row["balance"]), maturity_time=maturity_time) + elif "SHORT" in row["token_id"]: + maturity_time = int(row["token_id"].split("-")[1]) + short_obj[maturity_time] = Short(balance=FixedPoint(row["balance"]), maturity_time=maturity_time) + + # We do a balance of call to get base balance. + base_balance = FixedPoint( + scaled_value=self.interface.base_token_contract.functions.balanceOf(agent.checksum_address).call() + ) + + return HyperdriveWallet( + address=HexBytes(agent.checksum_address), + balance=Quantity( + amount=base_balance, + unit=TokenType.BASE, + ), + lp_tokens=lp_balance, + withdraw_shares=withdrawal_shares_balance, + longs=long_obj, + shorts=short_obj, + ) + + def _get_trade_events(self, agent: HyperdrivePolicyAgent) -> pd.DataFrame: + self._sync_events(agent) + return get_trade_events(self.db_session, agent.checksum_address) def _add_funds( self, diff --git a/src/agent0/core/hyperdrive/interactive/local_hyperdrive_test.py b/src/agent0/core/hyperdrive/interactive/local_hyperdrive_test.py index 6542eb65bf..9f7845a90b 100644 --- a/src/agent0/core/hyperdrive/interactive/local_hyperdrive_test.py +++ b/src/agent0/core/hyperdrive/interactive/local_hyperdrive_test.py @@ -100,9 +100,9 @@ def test_funding_and_trades(fast_chain_fixture: LocalChain): hyperdrive_agent2.add_funds(base=FixedPoint(333_333), eth=FixedPoint(333)) # Ensure agent wallet have expected balances - assert (hyperdrive_agent0.wallet.balance.amount) == FixedPoint(1_111_111) - assert (hyperdrive_agent1.wallet.balance.amount) == FixedPoint(222_222) - assert (hyperdrive_agent2.wallet.balance.amount) == FixedPoint(333_333) + assert (hyperdrive_agent0.get_positions().balance.amount) == FixedPoint(1_111_111) + assert (hyperdrive_agent1.get_positions().balance.amount) == FixedPoint(222_222) + assert (hyperdrive_agent2.get_positions().balance.amount) == FixedPoint(333_333) # Ensure chain balances are as expected ( chain_eth_balance, @@ -131,24 +131,24 @@ def test_funding_and_trades(fast_chain_fixture: LocalChain): # Add liquidity to 112_111 total add_liquidity_event = hyperdrive_agent0.add_liquidity(base=FixedPoint(111_111)) assert add_liquidity_event.base_amount == FixedPoint(111_111) - assert hyperdrive_agent0.wallet.lp_tokens == add_liquidity_event.lp_amount - _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.wallet) + assert hyperdrive_agent0.get_positions().lp_tokens == add_liquidity_event.lp_amount + _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.get_positions()) # Open long open_long_event = hyperdrive_agent0.open_long(base=FixedPoint(22_222)) assert open_long_event.base_amount == FixedPoint(22_222) - agent0_longs = list(hyperdrive_agent0.wallet.longs.values()) + agent0_longs = list(hyperdrive_agent0.get_positions().longs.values()) assert len(agent0_longs) == 1 assert agent0_longs[0].balance == open_long_event.bond_amount assert agent0_longs[0].maturity_time == open_long_event.maturity_time - _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.wallet) + _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.get_positions()) # Remove liquidity remove_liquidity_event = hyperdrive_agent0.remove_liquidity(shares=add_liquidity_event.lp_amount) assert add_liquidity_event.lp_amount == remove_liquidity_event.lp_amount - assert hyperdrive_agent0.wallet.lp_tokens == FixedPoint(0) - assert hyperdrive_agent0.wallet.withdraw_shares == remove_liquidity_event.withdrawal_share_amount - _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.wallet) + assert hyperdrive_agent0.get_positions().lp_tokens == FixedPoint(0) + assert hyperdrive_agent0.get_positions().withdraw_shares == remove_liquidity_event.withdrawal_share_amount + _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.get_positions()) # We ensure there exists some withdrawal shares that were given from the previous trade for testing purposes assert remove_liquidity_event.withdrawal_share_amount > 0 @@ -156,11 +156,11 @@ def test_funding_and_trades(fast_chain_fixture: LocalChain): # Open short open_short_event = hyperdrive_agent0.open_short(bonds=FixedPoint(333)) assert open_short_event.bond_amount == FixedPoint(333) - agent0_shorts = list(hyperdrive_agent0.wallet.shorts.values()) + agent0_shorts = list(hyperdrive_agent0.get_positions().shorts.values()) assert len(agent0_shorts) == 1 assert agent0_shorts[0].balance == open_short_event.bond_amount assert agent0_shorts[0].maturity_time == open_short_event.maturity_time - _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.wallet) + _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.get_positions()) # Close long close_long_event = hyperdrive_agent0.close_long( @@ -168,8 +168,8 @@ def test_funding_and_trades(fast_chain_fixture: LocalChain): ) assert open_long_event.bond_amount == close_long_event.bond_amount assert open_long_event.maturity_time == close_long_event.maturity_time - assert len(hyperdrive_agent0.wallet.longs) == 0 - _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.wallet) + assert len(hyperdrive_agent0.get_positions().longs) == 0 + _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.get_positions()) # Close short close_short_event = hyperdrive_agent0.close_short( @@ -177,14 +177,14 @@ def test_funding_and_trades(fast_chain_fixture: LocalChain): ) assert open_short_event.bond_amount == close_short_event.bond_amount assert open_short_event.maturity_time == close_short_event.maturity_time - assert len(hyperdrive_agent0.wallet.shorts) == 0 - _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.wallet) + assert len(hyperdrive_agent0.get_positions().shorts) == 0 + _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.get_positions()) # Redeem withdrawal shares redeem_event = hyperdrive_agent0.redeem_withdraw_share(shares=remove_liquidity_event.withdrawal_share_amount) assert redeem_event.withdrawal_share_amount == remove_liquidity_event.withdrawal_share_amount - assert hyperdrive_agent0.wallet.withdraw_shares == FixedPoint(0) - _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.wallet) + assert hyperdrive_agent0.get_positions().withdraw_shares == FixedPoint(0) + _ensure_db_wallet_matches_agent_wallet(interactive_hyperdrive, hyperdrive_agent0.get_positions()) @pytest.mark.anvil @@ -309,7 +309,7 @@ def test_save_load_snapshot(chain_fixture: LocalChain): # and in the db # Check base balance on the chain init_eth_on_chain, init_base_on_chain = hyperdrive_interface.get_eth_base_balances(hyperdrive_agent.agent) - init_agent_wallet = hyperdrive_agent.wallet.copy() + init_agent_wallet = hyperdrive_agent.get_positions().copy() init_db_wallet = interactive_hyperdrive.get_current_wallet(coerce_float=False).copy() init_pool_info_on_chain = interactive_hyperdrive.interface.get_hyperdrive_state().pool_info init_pool_state_on_db = interactive_hyperdrive.get_pool_state(coerce_float=False) @@ -324,7 +324,7 @@ def test_save_load_snapshot(chain_fixture: LocalChain): check_eth_on_chain, check_base_on_chain, ) = hyperdrive_interface.get_eth_base_balances(hyperdrive_agent.agent) - check_agent_wallet = hyperdrive_agent.wallet + check_agent_wallet = hyperdrive_agent.get_positions() check_db_wallet = interactive_hyperdrive.get_current_wallet(coerce_float=False) check_pool_info_on_chain = interactive_hyperdrive.interface.get_hyperdrive_state().pool_info check_pool_state_on_db = interactive_hyperdrive.get_pool_state(coerce_float=False) @@ -343,7 +343,7 @@ def test_save_load_snapshot(chain_fixture: LocalChain): check_eth_on_chain, check_base_on_chain, ) = hyperdrive_interface.get_eth_base_balances(hyperdrive_agent.agent) - check_agent_wallet = hyperdrive_agent.wallet + check_agent_wallet = hyperdrive_agent.get_positions() check_db_wallet = interactive_hyperdrive.get_current_wallet(coerce_float=False) check_pool_info_on_chain = interactive_hyperdrive.interface.get_hyperdrive_state().pool_info check_pool_state_on_db = interactive_hyperdrive.get_pool_state(coerce_float=False) @@ -367,7 +367,7 @@ def test_save_load_snapshot(chain_fixture: LocalChain): check_eth_on_chain, check_base_on_chain, ) = hyperdrive_interface.get_eth_base_balances(hyperdrive_agent.agent) - check_agent_wallet = hyperdrive_agent.wallet + check_agent_wallet = hyperdrive_agent.get_positions() check_db_wallet = interactive_hyperdrive.get_current_wallet(coerce_float=False) check_pool_info_on_chain = interactive_hyperdrive.interface.get_hyperdrive_state().pool_info check_pool_state_on_db = interactive_hyperdrive.get_pool_state(coerce_float=False) @@ -386,7 +386,7 @@ def test_save_load_snapshot(chain_fixture: LocalChain): check_eth_on_chain, check_base_on_chain, ) = hyperdrive_interface.get_eth_base_balances(hyperdrive_agent.agent) - check_agent_wallet = hyperdrive_agent.wallet + check_agent_wallet = hyperdrive_agent.get_positions() check_db_wallet = interactive_hyperdrive.get_current_wallet(coerce_float=False) check_pool_info_on_chain = interactive_hyperdrive.interface.get_hyperdrive_state().pool_info check_pool_state_on_db = interactive_hyperdrive.get_pool_state(coerce_float=False) @@ -410,7 +410,7 @@ def test_save_load_snapshot(chain_fixture: LocalChain): check_eth_on_chain, check_base_on_chain, ) = hyperdrive_interface.get_eth_base_balances(hyperdrive_agent.agent) - check_agent_wallet = hyperdrive_agent.wallet + check_agent_wallet = hyperdrive_agent.get_positions() check_db_wallet = interactive_hyperdrive.get_current_wallet(coerce_float=False) check_pool_info_on_chain = interactive_hyperdrive.interface.get_hyperdrive_state().pool_info check_pool_state_on_db = interactive_hyperdrive.get_pool_state(coerce_float=False) @@ -429,7 +429,7 @@ def test_save_load_snapshot(chain_fixture: LocalChain): check_eth_on_chain, check_base_on_chain, ) = hyperdrive_interface.get_eth_base_balances(hyperdrive_agent.agent) - check_agent_wallet = hyperdrive_agent.wallet + check_agent_wallet = hyperdrive_agent.get_positions() check_db_wallet = interactive_hyperdrive.get_current_wallet(coerce_float=False) check_pool_info_on_chain = interactive_hyperdrive.interface.get_hyperdrive_state().pool_info check_pool_state_on_db = interactive_hyperdrive.get_pool_state(coerce_float=False) @@ -475,8 +475,8 @@ def test_access_deployer_account(fast_chain_fixture: LocalChain): larry = interactive_hyperdrive.init_agent( base=FixedPoint(100_000), eth=FixedPoint(10), name="larry", private_key=privkey ) - larry_pubkey = larry.wallet.address.hex().strip("0x").lower() - assert larry_pubkey == pubkey.lower().strip("0x") # deployer public key + larry_pubkey = larry.checksum_address + assert larry_pubkey == pubkey # deployer public key @pytest.mark.anvil @@ -497,10 +497,10 @@ def test_access_deployer_liquidity(fast_chain_fixture: LocalChain): larry.checksum_address, ).call() ) - == larry.wallet.lp_tokens + == larry.get_positions().lp_tokens ) # Hyperdrive pool steals 2 * minimumShareReserves from the initial deployer's liquidity - assert larry.wallet.lp_tokens == config.initial_liquidity - 2 * config.minimum_share_reserves + assert larry.get_positions().lp_tokens == config.initial_liquidity - 2 * config.minimum_share_reserves @pytest.mark.anvil @@ -514,8 +514,8 @@ def test_remove_deployer_liquidity(fast_chain_fixture: LocalChain): larry = interactive_hyperdrive.init_agent( base=FixedPoint(100_000), eth=FixedPoint(10), name="larry", private_key=privkey ) - larry.remove_liquidity(shares=larry.wallet.lp_tokens) - assert larry.wallet.lp_tokens == 0 + larry.remove_liquidity(shares=larry.get_positions().lp_tokens) + assert larry.get_positions().lp_tokens == 0 assert ( FixedPoint( scaled_value=interactive_hyperdrive.interface.hyperdrive_contract.functions.balanceOf( diff --git a/src/agent0/hyperfuzz/system_fuzz/run_fuzz_bots.py b/src/agent0/hyperfuzz/system_fuzz/run_fuzz_bots.py index af52ce55fe..d317dc05be 100644 --- a/src/agent0/hyperfuzz/system_fuzz/run_fuzz_bots.py +++ b/src/agent0/hyperfuzz/system_fuzz/run_fuzz_bots.py @@ -336,7 +336,7 @@ def run_fuzz_bots( # Check agent funds and refund if necessary assert len(agents) > 0 - average_agent_base = sum(agent.wallet.balance.amount for agent in agents) / FixedPoint(len(agents)) + average_agent_base = sum(agent.get_positions().balance.amount for agent in agents) / FixedPoint(len(agents)) # Update agent funds if average_agent_base < minimum_avg_agent_base: logging.info("Refunding agents...") diff --git a/src/agent0/hyperfuzz/unit_fuzz/fuzz_present_value.py b/src/agent0/hyperfuzz/unit_fuzz/fuzz_present_value.py index de23dc5d1b..310d018e92 100644 --- a/src/agent0/hyperfuzz/unit_fuzz/fuzz_present_value.py +++ b/src/agent0/hyperfuzz/unit_fuzz/fuzz_present_value.py @@ -91,12 +91,12 @@ def fuzz_present_value( HyperdriveActionType.REMOVE_LIQUIDITY, ]: # Keep the agent flush - if agent.wallet.balance.amount < FixedPoint("1e10"): - agent.add_funds(base=FixedPoint("1e10") - agent.wallet.balance.amount) + if agent.get_positions().balance.amount < FixedPoint("1e10"): + agent.add_funds(base=FixedPoint("1e10") - agent.get_positions().balance.amount) # Set up trade amount bounds min_trade = interactive_hyperdrive.interface.pool_config.minimum_transaction_amount - max_budget = agent.wallet.balance.amount + max_budget = agent.get_positions().balance.amount trade_amount = None # Execute the trade @@ -110,7 +110,7 @@ def fuzz_present_value( ) trade_event = agent.open_long(base=trade_amount) case HyperdriveActionType.CLOSE_LONG: - maturity_time, open_trade = next(iter(agent.wallet.longs.items())) + maturity_time, open_trade = next(iter(agent.get_positions().longs.items())) trade_event = agent.close_long(maturity_time=maturity_time, bonds=open_trade.balance) case HyperdriveActionType.OPEN_SHORT: max_trade = interactive_hyperdrive.interface.calc_max_short( @@ -121,7 +121,7 @@ def fuzz_present_value( ) trade_event = agent.open_short(trade_amount) case HyperdriveActionType.CLOSE_SHORT: - maturity_time, open_trade = next(iter(agent.wallet.shorts.items())) + maturity_time, open_trade = next(iter(agent.get_positions().shorts.items())) trade_event = agent.close_short(maturity_time=maturity_time, bonds=open_trade.balance) case HyperdriveActionType.ADD_LIQUIDITY: # recompute initial present value for liquidity actions @@ -130,7 +130,11 @@ def fuzz_present_value( ) trade_amount = FixedPoint( scaled_value=int( - np.floor(rng.uniform(low=min_trade.scaled_value, high=agent.wallet.balance.amount.scaled_value)) + np.floor( + rng.uniform( + low=min_trade.scaled_value, high=agent.get_positions().balance.amount.scaled_value + ) + ) ) ) trade_event = agent.add_liquidity(trade_amount) @@ -139,8 +143,8 @@ def fuzz_present_value( check_data["initial_present_value"] = interactive_hyperdrive.interface.calc_present_value( interactive_hyperdrive.interface.current_pool_state ) - trade_amount = agent.wallet.lp_tokens - trade_event = agent.remove_liquidity(agent.wallet.lp_tokens) + trade_amount = agent.get_positions().lp_tokens + trade_event = agent.remove_liquidity(agent.get_positions().lp_tokens) case _: raise ValueError(f"Invalid {trade_type=}") diff --git a/src/agent0/hyperfuzz/unit_fuzz/fuzz_profit_check.py b/src/agent0/hyperfuzz/unit_fuzz/fuzz_profit_check.py index e5b6b3cbe1..4dfee2eed1 100644 --- a/src/agent0/hyperfuzz/unit_fuzz/fuzz_profit_check.py +++ b/src/agent0/hyperfuzz/unit_fuzz/fuzz_profit_check.py @@ -90,7 +90,7 @@ def fuzz_profit_check(chain_config: LocalChain.Config | None = None): # Generate funded trading agent long_agent = interactive_hyperdrive.init_agent(base=long_trade_amount, eth=FixedPoint(100), name="alice") - long_agent_initial_balance = long_agent.wallet.balance.amount + long_agent_initial_balance = long_agent.get_positions().balance.amount # Advance time to be right after a checkpoint boundary logging.info("Advance time...") @@ -134,7 +134,7 @@ def fuzz_profit_check(chain_config: LocalChain.Config | None = None): # the short trade amount is in bonds, but we know we will need much less base # we can play it safe by initializing with that much base short_agent = interactive_hyperdrive.init_agent(base=short_trade_amount, eth=FixedPoint(100), name="bob") - short_agent_initial_balance = short_agent.wallet.balance.amount + short_agent_initial_balance = short_agent.get_positions().balance.amount # Advance time to be right after a checkpoint boundary logging.info("Advance time...") @@ -165,10 +165,10 @@ def fuzz_profit_check(chain_config: LocalChain.Config | None = None): check_data = { "long_trade_amount": long_trade_amount, "long_agent_initial_balance": long_agent_initial_balance, - "long_agent_final_balance": long_agent.wallet.balance.amount, + "long_agent_final_balance": long_agent.get_positions().balance.amount, "long_events": {"open": open_long_event, "close": close_long_event}, "short_trade_amount": short_trade_amount, - "short_agent_final_balance": short_agent.wallet.balance.amount, + "short_agent_final_balance": short_agent.get_positions().balance.amount, "short_agent_initial_balance": short_agent_initial_balance, "short_events": {"open": open_short_event, "close": close_short_event}, } diff --git a/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py b/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py index 51b42d9fe2..99b02d4854 100644 --- a/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py +++ b/src/agent0/traiderdaive/gym_environments/full_hyperdrive_env.py @@ -345,7 +345,7 @@ def _apply_action(self, action: np.ndarray) -> bool: if open_order: # If the wallet has enough money - if volume_adjusted <= self.rl_bot.wallet.balance.amount: + if volume_adjusted <= self.rl_bot.get_positions().balance.amount: try: if trade_type == TradeTypes.LONG: self.rl_bot.open_long(base=volume_adjusted) @@ -382,12 +382,12 @@ def _apply_action(self, action: np.ndarray) -> bool: try: if add_lp: self.rl_bot.add_liquidity(add_lp_volume) - if remove_lp and remove_lp_volume <= self.rl_bot.wallet.lp_tokens: + if remove_lp and remove_lp_volume <= self.rl_bot.get_positions().lp_tokens: self.rl_bot.remove_liquidity(remove_lp_volume) # Always try and remove withdrawal shares - if self.rl_bot.wallet.withdraw_shares > 0: + if self.rl_bot.get_positions().withdraw_shares > 0: # TODO error handling or check when withdrawal shares are not withdrawable - self.rl_bot.redeem_withdraw_share(self.rl_bot.wallet.withdraw_shares) + self.rl_bot.redeem_withdraw_share(self.rl_bot.get_positions().withdraw_shares) except Exception as err: # pylint: disable=broad-except # TODO use logging here print(f"Warning: Failed to LP: {err=}") diff --git a/src/agent0/traiderdaive/gym_environments/simple_hyperdrive_env.py b/src/agent0/traiderdaive/gym_environments/simple_hyperdrive_env.py index ea8e2e9249..36fa631ba5 100644 --- a/src/agent0/traiderdaive/gym_environments/simple_hyperdrive_env.py +++ b/src/agent0/traiderdaive/gym_environments/simple_hyperdrive_env.py @@ -250,7 +250,7 @@ def do_trade(self) -> bool: assert self._current_position is not None terminated = False - agent_wallet = self.rl_bot.wallet + agent_wallet = self.rl_bot.get_positions() self._current_position = self._current_position.opposite() if self._current_position == CurrentPosition.LONG: diff --git a/tests/bot_to_db_test.py b/tests/bot_to_db_test.py index 07ae1547bd..0e85980cb8 100644 --- a/tests/bot_to_db_test.py +++ b/tests/bot_to_db_test.py @@ -109,7 +109,7 @@ def test_bot_to_db( base_token_addr = fast_hyperdrive_fixture._deployed_hyperdrive.base_token_contract.address vault_shares_token_addr = fast_hyperdrive_fixture._deployed_hyperdrive.vault_shares_token_contract.address expected_pool_config = { - "contract_address": fast_hyperdrive_fixture.get_hyperdrive_address(), + "hyperdrive_address": fast_hyperdrive_fixture.get_hyperdrive_address(), "base_token": base_token_addr, "vault_shares_token": vault_shares_token_addr, "initial_vault_share_price": _to_unscaled_decimal(FixedPoint("1")),