Skip to content

Commit

Permalink
WIP: Use pydantic classes in control endpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
hoh committed Nov 14, 2023
1 parent 308fe59 commit b89239f
Showing 1 changed file with 133 additions and 61 deletions.
194 changes: 133 additions & 61 deletions src/aleph/vm/orchestrator/views/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from collections.abc import Awaitable
from datetime import datetime, timedelta, timezone
from typing import Callable, Literal, Dict, Any, Union
from typing import Any, Callable, Literal, Union

import aiohttp.web_exceptions
from aiohttp import web
Expand All @@ -14,10 +14,11 @@
from eth_account import Account
from eth_account.messages import encode_defunct
from jwskate import Jwk
from pydantic import root_validator, validator
from pydantic.main import BaseModel

from ...models import VmExecution
from ...pool import VmPool
from aleph.vm.models import VmExecution
from aleph.vm.pool import VmPool

logger = logging.getLogger(__name__)

Expand All @@ -38,38 +39,111 @@ def verify_wallet_signature(signature, message, address):
Verifies a signature issued by a wallet
"""
enc_msg = encode_defunct(hexstr=message)
# Todo: Catch exception
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/root/aleph-vm-remote/src/aleph/vm/orchestrator/views/operator.py", line 40, in verify_wallet_signature
# computed_address = Account.recover_message(enc_msg, signature=signature)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/opt/aleph-vm/eth_utils/decorators.py", line 20, in _wrapper
# return self.method(objtype, *args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/opt/aleph-vm/eth_account/account.py", line 463, in recover_message
# return cast(ChecksumAddress, self._recover_hash(message_hash, vrs, signature))
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/opt/aleph-vm/eth_utils/decorators.py", line 20, in _wrapper
# return self.method(objtype, *args, **kwargs)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/opt/aleph-vm/eth_account/account.py", line 481, in _recover_hash
# signature_bytes_standard = to_standard_signature_bytes(signature_bytes)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/opt/aleph-vm/eth_account/_utils/signing.py", line 105, in to_standard_signature_bytes
# standard_v = to_standard_v(v)
# ^^^^^^^^^^^^^^^^
# File "/opt/aleph-vm/eth_account/_utils/signing.py", line 110, in to_standard_v
# (_chain, chain_naive_v) = extract_chain_id(enhanced_v)
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
# File "/opt/aleph-vm/eth_account/_utils/signing.py", line 96, in extract_chain_id
# raise ValueError("v %r is invalid, must be one of: 0, 1, 27, 28, 35+")
# ValueError: v %r is invalid, must be one of: 0, 1, 27, 28, 35+
computed_address = Account.recover_message(enc_msg, signature=signature)

return computed_address.lower() == address.lower()


def get_json_from_hex(str: str):
"""
Converts a hex string to a json object
"""
return json.loads(bytes.fromhex(str).decode("utf-8"))


class SignedPubKeyHeader(BaseModel):
signature: str # hexadecimal
payload: str # hexadecimal of SignedPubKeyPayload


class SignedPubKeyPayload(BaseModel):
"""This payload is signed by the wallet of the user to authorize an ephemeral key to act on his behalf."""
# pubkey: Jwk
pubkey: Dict[str, Any]
# {'pubkey': {'alg': 'ES256', 'crv': 'P-256', 'ext': True, 'key_ops': ['verify'], 'kty': 'EC', 'x': '4blJBYpltvQLFgRvLE-2H7dsMr5O0ImHkgOnjUbG2AU', 'y': '5VHnq_hUSogZBbVgsXMs0CjrVfMy4Pa3Uv2BEBqfrN4'}

pubkey: dict[str, Any]
# {'pubkey': {'alg': 'ES256', 'crv': 'P-256', 'ext': True, 'key_ops': ['verify'], 'kty': 'EC',
# 'x': '4blJBYpltvQLFgRvLE-2H7dsMr5O0ImHkgOnjUbG2AU', 'y': '5VHnq_hUSogZBbVgsXMs0CjrVfMy4Pa3Uv2BEBqfrN4'}
# alg: Literal["ECDSA"]
domain: str
address: str
expires: float # timestamp # TODO: move to ISO 8601

@property
def json_web_key(self) -> Jwk:
"""Return the ephemeral public key as Json Web Key"""
return Jwk(self.pubkey)


class SignedPubKeyHeader(BaseModel):
signature: bytes
payload: bytes

@validator("signature")
def signature_must_be_hex(cls, v: bytes) -> bytes:
"""Convert the signature from hexadecimal to bytes"""
v = v.strip(b"0x")
return bytes.fromhex(v.decode())

@validator("payload")
def payload_must_be_hex(cls, v: bytes) -> bytes:
"""Convert the payload from hexadecimal to bytes"""
return bytes.fromhex(v.decode())

@root_validator(pre=False, skip_on_failure=True)
def check_expiry(cls, values):
"""Check that the token has not expired"""
payload: bytes = values["payload"]
content = SignedPubKeyPayload.parse_raw(payload)
if not is_token_still_valid(content.expires):
msg = "Token expired"
raise ValueError(msg)
return values

@root_validator(pre=False, skip_on_failure=True)
def check_signature(cls, values):
"""Check that the signature is valid"""
signature: bytes = values["signature"]
payload: bytes = values["payload"]
content = SignedPubKeyPayload.parse_raw(payload)
if not verify_wallet_signature(signature, payload.hex(), content.address):
msg = "Invalid signature"
raise ValueError(msg)
return values

@property
def content(self) -> SignedPubKeyPayload:
"""Return the content of the header"""
return SignedPubKeyPayload.parse_raw(self.payload)


class SignedOperation(BaseModel):
"""This payload is signed by the ephemeral key authorized above."""
signature: str # hexadecimal
payload: str # hexadecimal of SignedOperationPayload

signature: bytes
payload: bytes

@validator("signature")
def signature_must_be_hex(cls, v) -> bytes:
"""Convert the signature from hexadecimal to bytes"""
v = v.strip(b"0x")
return bytes.fromhex(v.decode())

@validator("payload")
def payload_must_be_hex(cls, v) -> bytes:
"""Convert the payload from hexadecimal to bytes"""
return bytes.fromhex(v.decode())


class SignedOperationPayload(BaseModel):
Expand All @@ -79,52 +153,50 @@ class SignedOperationPayload(BaseModel):
# body_sha256: str # disabled since there is no body


async def authenticate_jwk(request: web.Request) -> str:
# The ephemeral public key that is signed by the wallet.
signed_pubkey = request.headers.get("X-SignedPubKey")
if not signed_pubkey:
def get_signed_pubkey(request: web.Request) -> SignedPubKeyHeader:
"""Get the ephemeral public key that is signed by the wallet from the request headers."""
signed_pubkey_header = request.headers.get("X-SignedPubKey")
if not signed_pubkey_header:
raise web.HTTPBadRequest(reason="Missing X-SignedPubKey header")

# Check that the header is a valid JSON object and deserialize it
try:
pubkey_dict = json.loads(signed_pubkey)
pubkey_body = SignedPubKeyHeader.parse_obj(pubkey_dict)
except (json.JSONDecodeError, KeyError):
raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey format")

# Deserialize the payload from the header
return SignedPubKeyHeader.parse_raw(signed_pubkey_header)
except KeyError as error:
logger.debug(f"Missing X-SignedPubKey header: {error}")
raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey fields") from error
except json.JSONDecodeError as error:
raise web.HTTPBadRequest(reason="Invalid X-SignedPubKey format") from error
except ValueError as error:
if error.args == ("Token expired",):
raise web.HTTPUnauthorized(reason="Token expired") from error
elif error.args == ("Invalid signature",):
raise web.HTTPUnauthorized(reason="Invalid signature") from error
else:
raise error


def get_signed_operation(request: web.Request) -> SignedOperation:
"""Get the signed operation public key that is signed by the ephemeral key from the request headers."""
try:
json_payload = SignedPubKeyPayload.parse_obj(get_json_from_hex(pubkey_body.payload))
except json.JSONDecodeError:
raise web.HTTPBadRequest(reason="Not valid JSON payload")

wallet_address: str = json_payload.address

if not verify_wallet_signature(pubkey_body.signature, pubkey_body.payload, wallet_address):
raise web.HTTPUnauthorized(reason="Invalid signature")
signed_operation = request.headers["X-SignedOperation"]
return SignedOperation.parse_raw(signed_operation)
except KeyError as error:
raise web.HTTPBadRequest(reason="Missing X-SignedOperation header") from error
except json.JSONDecodeError as error:
raise web.HTTPBadRequest(reason="Invalid X-SignedOperation format") from error

expires = json_payload.expires
if not expires or not is_token_still_valid(expires):
raise web.HTTPUnauthorized(reason="Token expired")

signed_operation = request.headers.get("X-SignedOperation")
if not signed_operation:
raise web.HTTPBadRequest(reason="Missing X-SignedOperation header")
async def authenticate_jwk(request: web.Request) -> str:
signed_pubkey = get_signed_pubkey(request)
signed_operation = get_signed_operation(request)

json_web_key = Jwk(json_payload.pubkey)
try:
payload_json = json.loads(signed_operation)
request_object = SignedOperation.parse_obj(payload_json)
except json.JSONDecodeError:
raise web.HTTPBadRequest(reason="Could not decode X-SignedOperation")

if json_web_key.verify(
data=bytes.fromhex(request_object.payload),
signature=bytes.fromhex(request_object.signature),
if signed_pubkey.content.json_web_key.verify(
data=signed_operation.payload,
signature=signed_operation.signature,
alg="ES256",
):
logger.debug("Signature verified")
return wallet_address
return signed_pubkey.content.address
else:
raise web.HTTPUnauthorized(reason="Signature could not verified")

Expand All @@ -150,12 +222,12 @@ async def wrapper(request):
def get_itemhash_or_400(match_info: UrlMappingMatchInfo) -> ItemHash:
try:
ref = match_info["ref"]
except KeyError:
raise aiohttp.web_exceptions.HTTPBadRequest(body="Missing field: 'ref'")
except KeyError as error:
raise aiohttp.web_exceptions.HTTPBadRequest(body="Missing field: 'ref'") from error
try:
return ItemHash(ref)
except UnknownHashError:
raise aiohttp.web_exceptions.HTTPBadRequest(body=f"Invalid ref: '{ref}'")
except UnknownHashError as error:
raise aiohttp.web_exceptions.HTTPBadRequest(body=f"Invalid ref: '{ref}'") from error


def get_execution_or_404(ref: ItemHash, pool: VmPool) -> VmExecution:
Expand Down

0 comments on commit b89239f

Please sign in to comment.