Skip to content

Commit

Permalink
Feature: Account Handler (#175)
Browse files Browse the repository at this point in the history
* Feature: Internal account management + fix on _load_account to handle SolAccount

* fixup! Feature: Internal account management + fix on _load_account to handle SolAccount

* Fix: chains_config wasn't using settings.CONFIG_HOME for locations

* Fix: blakc issue

* Fix: rename CHAINS_CONFIG_FILE to CONFIG_FILE to avoid getting issue by conf of chain

* Fix: base58 and pynacl is now needed for build

* Fix: f string without nay placeholders

* Fix: black error

* Refactor: we now store single account at the time

* Fix: ruff issue

* fix: debug stuff remove

* Fix: Improve code structure in pair-programming with Lyam

---------

Co-authored-by: Andres D. Molins <[email protected]>
  • Loading branch information
1yam and nesitor authored Oct 2, 2024
1 parent cf70462 commit 2790df9
Show file tree
Hide file tree
Showing 6 changed files with 263 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ MANIFEST
**/device.key

# environment variables
.env
.config.json
.env.local

.gitsigners
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ dependencies = [
"aleph-superfluid>=0.2.1",
"eth_typing==4.3.1",
"web3==6.3.0",
"base58==2.1.1", # Needed now as default with _load_account changement
"pynacl==1.5.0" # Needed now as default with _load_account changement
]

[project.optional-dependencies]
Expand Down
52 changes: 43 additions & 9 deletions src/aleph/sdk/account.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,32 @@
import asyncio
import logging
from pathlib import Path
from typing import Optional, Type, TypeVar
from typing import Dict, Optional, Type, TypeVar

from aleph_message.models import Chain

from aleph.sdk.chains.common import get_fallback_private_key
from aleph.sdk.chains.ethereum import ETHAccount
from aleph.sdk.chains.remote import RemoteAccount
from aleph.sdk.conf import settings
from aleph.sdk.chains.solana import SOLAccount
from aleph.sdk.conf import load_main_configuration, settings
from aleph.sdk.types import AccountFromPrivateKey

logger = logging.getLogger(__name__)

T = TypeVar("T", bound=AccountFromPrivateKey)


def load_chain_account_type(chain: Chain) -> Type[AccountFromPrivateKey]:
chain_account_map: Dict[Chain, Type[AccountFromPrivateKey]] = {
Chain.ETH: ETHAccount,
Chain.AVAX: ETHAccount,
Chain.SOL: SOLAccount,
Chain.BASE: ETHAccount,
}
return chain_account_map.get(chain) or ETHAccount


def account_from_hex_string(private_key_str: str, account_type: Type[T]) -> T:
if private_key_str.startswith("0x"):
private_key_str = private_key_str[2:]
Expand All @@ -28,16 +41,36 @@ def account_from_file(private_key_path: Path, account_type: Type[T]) -> T:
def _load_account(
private_key_str: Optional[str] = None,
private_key_path: Optional[Path] = None,
account_type: Type[AccountFromPrivateKey] = ETHAccount,
account_type: Optional[Type[AccountFromPrivateKey]] = None,
) -> AccountFromPrivateKey:
"""Load private key from a string or a file. takes the string argument in priority"""
if private_key_str or (private_key_path and private_key_path.is_file()):
if account_type:
if private_key_path and private_key_path.is_file():
return account_from_file(private_key_path, account_type)
elif private_key_str:
return account_from_hex_string(private_key_str, account_type)
else:
raise ValueError("Any private key specified")
else:
main_configuration = load_main_configuration(settings.CONFIG_FILE)
if main_configuration:
account_type = load_chain_account_type(main_configuration.chain)
logger.debug(
f"Detected {main_configuration.chain} account for path {settings.CONFIG_FILE}"
)
else:
account_type = ETHAccount # Defaults to ETHAccount
logger.warning(
f"No main configuration data found in {settings.CONFIG_FILE}, defaulting to {account_type.__name__}"
)
if private_key_path and private_key_path.is_file():
return account_from_file(private_key_path, account_type)
elif private_key_str:
return account_from_hex_string(private_key_str, account_type)
else:
raise ValueError("Any private key specified")

if private_key_str:
logger.debug("Using account from string")
return account_from_hex_string(private_key_str, account_type)
elif private_key_path and private_key_path.is_file():
logger.debug("Using account from file")
return account_from_file(private_key_path, account_type)
elif settings.REMOTE_CRYPTO_HOST:
logger.debug("Using remote account")
loop = asyncio.get_event_loop()
Expand All @@ -48,6 +81,7 @@ def _load_account(
)
)
else:
account_type = ETHAccount # Defaults to ETHAccount
new_private_key = get_fallback_private_key()
account = account_type(private_key=new_private_key)
logger.info(
Expand Down
94 changes: 91 additions & 3 deletions src/aleph/sdk/chains/solana.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Dict, List, Optional, Union

import base58
from nacl.exceptions import BadSignatureError as NaclBadSignatureError
Expand All @@ -22,7 +22,7 @@ class SOLAccount(BaseAccount):
_private_key: PrivateKey

def __init__(self, private_key: bytes):
self.private_key = private_key
self.private_key = parse_private_key(private_key_from_bytes(private_key))
self._signing_key = SigningKey(self.private_key)
self._private_key = self._signing_key.to_curve25519_private_key()

Expand Down Expand Up @@ -79,7 +79,7 @@ def verify_signature(
public_key: The public key to use for verification. Can be a base58 encoded string or bytes.
message: The message to verify. Can be an utf-8 string or bytes.
Raises:
BadSignatureError: If the signature is invalid.
BadSignatureError: If the signature is invalid.!
"""
if isinstance(signature, str):
signature = base58.b58decode(signature)
Expand All @@ -91,3 +91,91 @@ def verify_signature(
VerifyKey(public_key).verify(message, signature)
except NaclBadSignatureError as e:
raise BadSignatureError from e


def private_key_from_bytes(
private_key_bytes: bytes, output_format: str = "base58"
) -> Union[str, List[int], bytes]:
"""
Convert a Solana private key in bytes back to different formats (base58 string, uint8 list, or raw bytes).
- For base58 string: Encode the bytes into a base58 string.
- For uint8 list: Convert the bytes into a list of integers.
- For raw bytes: Return as-is.
Args:
private_key_bytes (bytes): The private key in byte format.
output_format (str): The format to return ('base58', 'list', 'bytes').
Returns:
The private key in the requested format.
Raises:
ValueError: If the output_format is not recognized or the private key length is invalid.
"""
if not isinstance(private_key_bytes, bytes):
raise ValueError("Expected the private key in bytes.")

if len(private_key_bytes) != 32:
raise ValueError("Solana private key must be exactly 32 bytes long.")

if output_format == "base58":
return base58.b58encode(private_key_bytes).decode("utf-8")

elif output_format == "list":
return list(private_key_bytes)

elif output_format == "bytes":
return private_key_bytes

else:
raise ValueError("Invalid output format. Choose 'base58', 'list', or 'bytes'.")


def parse_private_key(private_key: Union[str, List[int], bytes]) -> bytes:
"""
Parse the private key which could be either:
- a base58-encoded string (which may contain both private and public key)
- a list of uint8 integers (which may contain both private and public key)
- a byte array (exactly 32 bytes)
Returns:
bytes: The private key in byte format (32 bytes).
Raises:
ValueError: If the private key format is invalid or the length is incorrect.
"""
# If the private key is already in byte format
if isinstance(private_key, bytes):
if len(private_key) != 32:
raise ValueError("The private key in bytes must be exactly 32 bytes long.")
return private_key

# If the private key is a base58-encoded string
elif isinstance(private_key, str):
try:
decoded_key = base58.b58decode(private_key)
if len(decoded_key) not in [32, 64]:
raise ValueError(
"The base58 decoded private key must be either 32 or 64 bytes long."
)
return decoded_key[:32]
except Exception as e:
raise ValueError(f"Invalid base58 encoded private key: {e}")

# If the private key is a list of uint8 integers
elif isinstance(private_key, list):
if all(isinstance(i, int) and 0 <= i <= 255 for i in private_key):
byte_key = bytes(private_key)
if len(byte_key) < 32:
raise ValueError("The uint8 array must contain at least 32 elements.")
return byte_key[:32] # Take the first 32 bytes (private key)
else:
raise ValueError(
"Invalid uint8 array, must contain integers between 0 and 255."
)

else:
raise ValueError(
"Unsupported private key format. Must be a base58 string, bytes, or a list of uint8 integers."
)
68 changes: 67 additions & 1 deletion src/aleph/sdk/conf.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,27 @@
import json
import logging
import os
from pathlib import Path
from shutil import which
from typing import Dict, Optional, Union

from aleph_message.models import Chain
from aleph_message.models.execution.environment import HypervisorType
from pydantic import BaseSettings, Field
from pydantic import BaseModel, BaseSettings, Field

from aleph.sdk.types import ChainInfo

logger = logging.getLogger(__name__)


class Settings(BaseSettings):
CONFIG_HOME: Optional[str] = None

CONFIG_FILE: Path = Field(
default=Path("config.json"),
description="Path to the JSON file containing chain account configurations",
)

# In case the user does not want to bother with handling private keys himself,
# do an ugly and insecure write and read from disk to this file.
PRIVATE_KEY_FILE: Path = Field(
Expand Down Expand Up @@ -139,6 +148,18 @@ class Config:
env_file = ".env"


class MainConfiguration(BaseModel):
"""
Intern Chain Management with Account.
"""

path: Path
chain: Chain

class Config:
use_enum_values = True


# Settings singleton
settings = Settings()

Expand All @@ -162,6 +183,19 @@ class Config:
settings.PRIVATE_MNEMONIC_FILE = Path(
settings.CONFIG_HOME, "private-keys", "substrate.mnemonic"
)
if str(settings.CONFIG_FILE) == "config.json":
settings.CONFIG_FILE = Path(settings.CONFIG_HOME, "config.json")
# If Config file exist and well filled we update the PRIVATE_KEY_FILE default
if settings.CONFIG_FILE.exists():
try:
with open(settings.CONFIG_FILE, "r", encoding="utf-8") as f:
config_data = json.load(f)

if "path" in config_data:
settings.PRIVATE_KEY_FILE = Path(config_data["path"])
except json.JSONDecodeError:
pass


# Update CHAINS settings and remove placeholders
CHAINS_ENV = [(key[7:], value) for key, value in settings if key.startswith("CHAINS_")]
Expand All @@ -172,3 +206,35 @@ class Config:
field = field.lower()
settings.CHAINS[chain].__dict__[field] = value
settings.__delattr__(f"CHAINS_{fields}")


def save_main_configuration(file_path: Path, data: MainConfiguration):
"""
Synchronously save a single ChainAccount object as JSON to a file.
"""
with file_path.open("w") as file:
data_serializable = data.dict()
data_serializable["path"] = str(data_serializable["path"])
json.dump(data_serializable, file, indent=4)


def load_main_configuration(file_path: Path) -> Optional[MainConfiguration]:
"""
Synchronously load the private key and chain type from a file.
If the file does not exist or is empty, return None.
"""
if not file_path.exists() or file_path.stat().st_size == 0:
logger.debug(f"File {file_path} does not exist or is empty. Returning None.")
return None

try:
with file_path.open("rb") as file:
content = file.read()
data = json.loads(content.decode("utf-8"))
return MainConfiguration(**data)
except UnicodeDecodeError as e:
logger.error(f"Unable to decode {file_path} as UTF-8: {e}")
except json.JSONDecodeError:
logger.error(f"Invalid JSON format in {file_path}.")

return None
60 changes: 59 additions & 1 deletion tests/unit/test_chain_solana.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
from nacl.signing import VerifyKey

from aleph.sdk.chains.common import get_verification_buffer
from aleph.sdk.chains.solana import SOLAccount, get_fallback_account, verify_signature
from aleph.sdk.chains.solana import (
SOLAccount,
get_fallback_account,
parse_private_key,
verify_signature,
)
from aleph.sdk.exceptions import BadSignatureError


Expand Down Expand Up @@ -136,3 +141,56 @@ async def test_sign_raw(solana_account):
assert isinstance(signature, bytes)

verify_signature(signature, solana_account.get_address(), buffer)


def test_parse_solana_private_key_bytes():
# Valid 32-byte private key
private_key_bytes = bytes(range(32))
parsed_key = parse_private_key(private_key_bytes)
assert isinstance(parsed_key, bytes)
assert len(parsed_key) == 32
assert parsed_key == private_key_bytes

# Invalid private key (too short)
with pytest.raises(
ValueError, match="The private key in bytes must be exactly 32 bytes long."
):
parse_private_key(bytes(range(31)))


def test_parse_solana_private_key_base58():
# Valid base58 private key (32 bytes)
base58_key = base58.b58encode(bytes(range(32))).decode("utf-8")
parsed_key = parse_private_key(base58_key)
assert isinstance(parsed_key, bytes)
assert len(parsed_key) == 32

# Invalid base58 key (not decodable)
with pytest.raises(ValueError, match="Invalid base58 encoded private key"):
parse_private_key("invalid_base58_key")

# Invalid base58 key (wrong length)
with pytest.raises(
ValueError,
match="The base58 decoded private key must be either 32 or 64 bytes long.",
):
parse_private_key(base58.b58encode(bytes(range(31))).decode("utf-8"))


def test_parse_solana_private_key_list():
# Valid list of uint8 integers (64 elements, but we only take the first 32 for private key)
uint8_list = list(range(64))
parsed_key = parse_private_key(uint8_list)
assert isinstance(parsed_key, bytes)
assert len(parsed_key) == 32
assert parsed_key == bytes(range(32))

# Invalid list (contains non-integers)
with pytest.raises(ValueError, match="Invalid uint8 array"):
parse_private_key([1, 2, "not an int", 4]) # type: ignore # Ignore type check for string

# Invalid list (less than 32 elements)
with pytest.raises(
ValueError, match="The uint8 array must contain at least 32 elements."
):
parse_private_key(list(range(31)))

0 comments on commit 2790df9

Please sign in to comment.