Skip to content

Commit

Permalink
Rewrite for EVM chains
Browse files Browse the repository at this point in the history
  • Loading branch information
philogicae committed Aug 25, 2024
1 parent e6e1133 commit b6aa0e4
Show file tree
Hide file tree
Showing 3 changed files with 241 additions and 150 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = [
"python-magic",
"typing_extensions",
"aioresponses>=0.7.6",
"superfluid@git+https://github.com/aleph-im/superfluid.py.git@1yam-add-base",
"superfluid@git+https://github.com/1yam/superfluid.py.git@1yam-add-base",
"web3==6.3.0",
]

Expand Down
245 changes: 181 additions & 64 deletions src/aleph/sdk/chains/ethereum.py
Original file line number Diff line number Diff line change
@@ -1,103 +1,230 @@
import asyncio
from decimal import Decimal
from pathlib import Path
from typing import Awaitable, Dict, Optional, Set, Union
from typing import Awaitable, List, Optional, Union

from aleph_message.models import Chain
from eth_account import Account
from eth_account.messages import encode_defunct
from eth_account.signers.local import LocalAccount
from eth_keys.exceptions import BadSignature as EthBadSignatureError
from eth_utils import to_wei
from superfluid import Web3FlowInfo
from web3 import Web3
from web3.middleware import geth_poa_middleware
from web3.types import ChecksumAddress, TxParams, TxReceipt

from aleph.sdk.exceptions import InsufficientFundsError

from ..conf import settings
from ..connectors.superfluid import Superfluid
from ..exceptions import BadSignatureError
from ..utils import bytes_from_hex
from .common import BaseAccount, get_fallback_private_key, get_public_key

CHAINS_WITH_SUPERTOKEN: Set[Chain] = {Chain.AVAX}
CHAIN_IDS: Dict[Chain, int] = {
Chain.AVAX: settings.AVAX_CHAIN_ID,
}
MIN_ETH_BALANCE: float = 0.005
MIN_ETH_BALANCE_WEI = Decimal(to_wei(MIN_ETH_BALANCE, "ether"))
BALANCEOF_ABI = """[{
"name": "balanceOf",
"inputs": [{"name": "account", "type": "address"}],
"outputs": [{"name": "balance", "type": "uint256"}],
"constant": true,
"payable": false,
"stateMutability": "view",
"type": "function"
}]"""


def get_rpc_for_chain(chain: Chain):
"""Returns the RPC to use for a given Ethereum based blockchain"""
if not chain:
return None
def to_human_readable_token(amount: Decimal) -> float:
return float(amount / (Decimal(10) ** Decimal(settings.TOKEN_DECIMALS)))

if chain == Chain.AVAX:
return settings.AVAX_RPC
else:
raise ValueError(f"Unknown RPC for chain {chain}")

def to_wei_token(amount: Decimal) -> Decimal:
return amount * Decimal(10) ** Decimal(settings.TOKEN_DECIMALS)

def get_chain_id_for_chain(chain: Chain):
"""Returns the chain ID of a given Ethereum based blockchain"""
if not chain:
return None

if chain in CHAIN_IDS:
return CHAIN_IDS[chain]
else:
raise ValueError(f"Unknown RPC for chain {chain}")
def get_chain_id(chain: Union[Chain, str, None]) -> Optional[int]:
"""Returns the CHAIN_ID of a given EVM blockchain"""
if chain:
if chain in settings.CHAINS and settings.CHAINS[chain].chain_id:
return settings.CHAINS[chain].chain_id
else:
raise ValueError(f"Unknown RPC for chain {chain}")
return None


def get_rpc(chain: Union[Chain, str, None]) -> Optional[str]:
"""Returns the RPC to use for a given EVM blockchain"""
if chain:
if chain in settings.CHAINS and settings.CHAINS[chain].rpc:
return settings.CHAINS[chain].rpc
else:
raise ValueError(f"Unknown RPC for chain {chain}")
return None


def get_token_address(chain: Union[Chain, str, None]) -> Optional[ChecksumAddress]:
if chain:
if chain in settings.CHAINS:
address = settings.CHAINS[chain].super_token
if address:
try:
return Web3.to_checksum_address(address)
except ValueError:
raise ValueError(f"Invalid token address {address}")
else:
raise ValueError(f"Unknown token for chain {chain}")
return None


def get_super_token_address(
chain: Union[Chain, str, None]
) -> Optional[ChecksumAddress]:
if chain:
if chain in settings.CHAINS:
address = settings.CHAINS[chain].super_token
if address:
try:
return Web3.to_checksum_address(address)
except ValueError:
raise ValueError(f"Invalid token address {address}")
else:
raise ValueError(f"Unknown super_token for chain {chain}")
return None


def get_chains_with_super_token() -> List[Union[Chain, str]]:
return [chain for chain, info in settings.CHAINS.items() if info.super_token]


class ETHAccount(BaseAccount):
"""Interact with an Ethereum address or key pair"""
"""Interact with an Ethereum address or key pair on EVM blockchains"""

CHAIN = "ETH"
CURVE = "secp256k1"
_account: LocalAccount
_provider: Optional[Web3]
chain: Optional[Chain]
chain_id: Optional[int]
rpc: Optional[str]
superfluid_connector: Optional[Superfluid]

def __init__(
self,
private_key: bytes,
chain: Optional[Chain] = None,
rpc: Optional[str] = None,
chain_id: Optional[int] = None,
):
self.private_key = private_key
self._account = Account.from_key(self.private_key)
self.chain = chain
rpc = rpc or get_rpc_for_chain(chain)
chain_id = chain_id or get_chain_id_for_chain(chain)
self.superfluid_connector = (
Superfluid(
rpc=rpc,
chain_id=chain_id,
account=self._account,
)
if chain in CHAINS_WITH_SUPERTOKEN
else None
self._account: LocalAccount = Account.from_key(private_key)
self.connect_chain(chain=chain)

@staticmethod
def from_mnemonic(mnemonic: str, chain: Optional[Chain] = None) -> "ETHAccount":
Account.enable_unaudited_hdwallet_features()
return ETHAccount(
private_key=Account.from_mnemonic(mnemonic=mnemonic).key, chain=chain
)

def get_address(self) -> str:
return self._account.address

def get_public_key(self) -> str:
return "0x" + get_public_key(private_key=self._account.key).hex()

async def sign_raw(self, buffer: bytes) -> bytes:
"""Sign a raw buffer."""
msghash = encode_defunct(text=buffer.decode("utf-8"))
sig = self._account.sign_message(msghash)
return sig["signature"]

def get_address(self) -> str:
return self._account.address
def connect_chain(self, chain: Optional[Chain] = None):
self.chain = chain
if self.chain:
self.chain_id = get_chain_id(self.chain)
self.rpc = get_rpc(self.chain)
self._provider = Web3(Web3.HTTPProvider(self.rpc))
if chain == Chain.BSC:
self._provider.middleware_onion.inject(
geth_poa_middleware, "geth_poa", layer=0
)
else:
self.chain_id = None
self.rpc = None
self._provider = None

def get_public_key(self) -> str:
return "0x" + get_public_key(private_key=self._account.key).hex()
if chain in get_chains_with_super_token() and self._provider:
self.superfluid_connector = Superfluid(self)
else:
self.superfluid_connector = None

@staticmethod
def from_mnemonic(mnemonic: str) -> "ETHAccount":
Account.enable_unaudited_hdwallet_features()
return ETHAccount(private_key=Account.from_mnemonic(mnemonic=mnemonic).key)
def switch_chain(self, chain: Optional[Chain] = None):
self.connect_chain(chain=chain)

def can_transact(self, block=True) -> bool:
balance = self.get_eth_balance()
valid = balance > MIN_ETH_BALANCE_WEI if self.chain else False
if not valid and block:
raise InsufficientFundsError(
required_funds=MIN_ETH_BALANCE,
available_funds=to_human_readable_token(balance),
)
return valid

async def _sign_and_send_transaction(self, tx_params: TxParams) -> str:
"""
Sign and broadcast a transaction using the provided ETHAccount
@param tx_params - Transaction parameters
@returns - str - Transaction hash
"""
self.can_transact()

def sign_and_send() -> TxReceipt:
if self._provider is None:
raise ValueError("Provider not connected")
signed_tx = self._provider.eth.account.sign_transaction(
tx_params, self._account.key
)
tx_hash = self._provider.eth.send_raw_transaction(signed_tx.rawTransaction)
tx_receipt = self._provider.eth.wait_for_transaction_receipt(
tx_hash, settings.TX_TIMEOUT
)
return tx_receipt

loop = asyncio.get_running_loop()
tx_receipt = await loop.run_in_executor(None, sign_and_send)
return tx_receipt["transactionHash"].hex()

def get_eth_balance(self) -> Decimal:
return Decimal(
self._provider.eth.get_balance(self._account.address)
if self._provider
else 0
)

def get_token_balance(self) -> Decimal:
if self.chain and self._provider:
contact_address = get_token_address(self.chain)
if contact_address:
contract = self._provider.eth.contract(
address=contact_address, abi=BALANCEOF_ABI
)
return Decimal(contract.functions.balanceOf(self.get_address()).call())
return Decimal(0)

def get_super_token_balance(self) -> Decimal:
if self.chain and self._provider:
contact_address = get_super_token_address(self.chain)
if contact_address:
contract = self._provider.eth.contract(
address=contact_address, abi=BALANCEOF_ABI
)
return Decimal(contract.functions.balanceOf(self.get_address()).call())
return Decimal(0)

def create_flow(self, receiver: str, flow: Decimal) -> Awaitable[str]:
"""Creat a Superfluid flow between this account and the receiver address."""
if not self.superfluid_connector:
raise ValueError("Superfluid connector is required to create a flow")
return self.superfluid_connector.create_flow(
sender=self.get_address(), receiver=receiver, flow=flow
)
return self.superfluid_connector.create_flow(receiver=receiver, flow=flow)

def get_flow(self, receiver: str) -> Awaitable[Web3FlowInfo]:
"""Get the Superfluid flow between this account and the receiver address."""
Expand All @@ -111,29 +238,19 @@ def update_flow(self, receiver: str, flow: Decimal) -> Awaitable[str]:
"""Update the Superfluid flow between this account and the receiver address."""
if not self.superfluid_connector:
raise ValueError("Superfluid connector is required to update a flow")
return self.superfluid_connector.update_flow(
sender=self.get_address(), receiver=receiver, flow=flow
)
return self.superfluid_connector.update_flow(receiver=receiver, flow=flow)

def delete_flow(self, receiver: str) -> Awaitable[str]:
"""Delete the Superfluid flow between this account and the receiver address."""
if not self.superfluid_connector:
raise ValueError("Superfluid connector is required to delete a flow")
return self.superfluid_connector.delete_flow(
sender=self.get_address(), receiver=receiver
)

def update_superfluid_connector(self, rpc: str, chain_id: int):
"""Update the Superfluid connector after initialisation."""
self.superfluid_connector = Superfluid(
rpc=rpc,
chain_id=chain_id,
account=self._account,
)
return self.superfluid_connector.delete_flow(receiver=receiver)


def get_fallback_account(path: Optional[Path] = None) -> ETHAccount:
return ETHAccount(private_key=get_fallback_private_key(path=path))
def get_fallback_account(
path: Optional[Path] = None, chain: Optional[Chain] = None
) -> ETHAccount:
return ETHAccount(private_key=get_fallback_private_key(path=path), chain=chain)


def verify_signature(
Expand Down
Loading

0 comments on commit b6aa0e4

Please sign in to comment.