diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index ae8e6e7..464e936 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -3,7 +3,7 @@ name: Unit tests on: [push] jobs: - build: + tests: runs-on: ubuntu-22.04 strategy: @@ -20,6 +20,9 @@ jobs: run: | python -m pip install --upgrade pip pip install -e .[testing] + - name: Check typing with mypy + run: | + mypy src/ tests/ - name: Test with pytest run: | pytest diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..5a5420f --- /dev/null +++ b/mypy.ini @@ -0,0 +1,65 @@ +# Global options: + +[mypy] +mypy_path = src + +exclude = conftest.py + + +show_column_numbers = True + +# Suppressing errors +# Shows errors related to strict None checking, if the global strict_optional flag is enabled +strict_optional = True +no_implicit_optional = True + +# Import discovery +# Suppresses error messages about imports that cannot be resolved +ignore_missing_imports = True +# Forces import to reference the original source file +no_implicit_reexport = True +# show error messages from unrelated files +follow_imports = silent +follow_imports_for_stubs = False + + +# Disallow dynamic typing +# Disallows usage of types that come from unfollowed imports +disallow_any_unimported = False +# Disallows all expressions in the module that have type Any +disallow_any_expr = False +# Disallows functions that have Any in their signature after decorator transformation. +disallow_any_decorated = False +# Disallows explicit Any in type positions such as type annotations and generic type parameters. +disallow_any_explicit = False +# Disallows usage of generic types that do not specify explicit type parameters. +disallow_any_generics = False +# Disallows subclassing a value of type Any. +disallow_subclassing_any = False + +# Untyped definitions and calls +# Disallows calling functions without type annotations from functions with type annotations. +disallow_untyped_calls = False +# Disallows defining functions without type annotations or with incomplete type annotations +disallow_untyped_defs = False +# Disallows defining functions with incomplete type annotations. +check_untyped_defs = False +# Type-checks the interior of functions without type annotations. +disallow_incomplete_defs = False +# Reports an error whenever a function with type annotations is decorated with a decorator without annotations. +disallow_untyped_decorators = False + +# Prohibit comparisons of non-overlapping types (ex: 42 == "no") +strict_equality = True + +# Configuring warnings +# Warns about unneeded # type: ignore comments. +warn_unused_ignores = True +# Shows errors for missing return statements on some execution paths. +warn_no_return = True +# Shows a warning when returning a value with type Any from a function declared with a non- Any return type. +warn_return_any = False + +# Miscellaneous strictness flags +# Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition. +allow_redefinition = True diff --git a/setup.cfg b/setup.cfg index ed8acb6..5498224 100644 --- a/setup.cfg +++ b/setup.cfg @@ -67,11 +67,12 @@ exclude = # Add here test requirements (semicolon/line-separated) testing = - setuptools + mypy==1.5.1 pytest pytest-asyncio pytest-cov pytest-mock + setuptools uvicorn [options.entry_points] diff --git a/src/aleph_vrf/coordinator/main.py b/src/aleph_vrf/coordinator/main.py index dda1d55..6794785 100644 --- a/src/aleph_vrf/coordinator/main.py +++ b/src/aleph_vrf/coordinator/main.py @@ -1,4 +1,5 @@ import logging +from typing import Dict, Union logger = logging.getLogger(__name__) @@ -13,7 +14,7 @@ logger.debug("local imports") from aleph_vrf.coordinator.vrf import generate_vrf -from aleph_vrf.models import APIResponse +from aleph_vrf.models import APIResponse, VRFResponse logger.debug("imports done") @@ -37,6 +38,8 @@ async def receive_vrf() -> APIResponse: private_key = get_fallback_private_key() account = ETHAccount(private_key=private_key) + response: Union[VRFResponse, Dict[str, str]] + try: response = await generate_vrf(account) except Exception as err: diff --git a/src/aleph_vrf/coordinator/vrf.py b/src/aleph_vrf/coordinator/vrf.py index 9cd2d7f..70080c8 100644 --- a/src/aleph_vrf/coordinator/vrf.py +++ b/src/aleph_vrf/coordinator/vrf.py @@ -18,10 +18,10 @@ from aleph_vrf.models import ( CRNVRFResponse, Node, - VRFRandomBytes, VRFRequest, VRFResponse, - VRFResponseHash, + PublishedVRFResponseHash, + PublishedVRFRandomBytes, ) from aleph_vrf.settings import settings from aleph_vrf.utils import ( @@ -43,7 +43,7 @@ M = TypeVar("M", bound=BaseModel) -async def post_node_vrf(url: str, model: Type[M]) -> Union[Exception, M]: +async def post_node_vrf(url: str, model: Type[M]) -> M: async with aiohttp.ClientSession() as session: async with session.post(url, timeout=60) as resp: if resp.status != 200: @@ -143,16 +143,16 @@ async def generate_vrf(account: ETHAccount) -> VRFResponse: logger.debug(f"Generated VRF request with item_hash {request_item_hash}") vrf_generated_result = await send_generate_requests( - selected_nodes, request_item_hash + selected_nodes=selected_nodes, + request_item_hash=request_item_hash, + request_id=vrf_request.request_id, ) logger.debug( f"Received VRF generated requests from {len(vrf_generated_result)} nodes" ) - vrf_publish_result = await send_publish_requests( - vrf_generated_result, vrf_request.request_id - ) + vrf_publish_result = await send_publish_requests(vrf_generated_result) logger.debug( f"Received VRF publish requests from {len(vrf_generated_result)} nodes" @@ -178,50 +178,64 @@ async def generate_vrf(account: ETHAccount) -> VRFResponse: async def send_generate_requests( - selected_nodes: List[Node], request_item_hash: str -) -> Dict[str, Union[Exception, VRFResponseHash]]: + selected_nodes: List[Node], + request_item_hash: str, + request_id: str, +) -> Dict[str, PublishedVRFResponseHash]: generate_tasks = [] nodes: List[str] = [] for node in selected_nodes: nodes.append(node.address) url = f"{node.address}/vm/{settings.FUNCTION}/{VRF_FUNCTION_GENERATE_PATH}/{request_item_hash}" - generate_tasks.append(asyncio.create_task(post_node_vrf(url, VRFResponseHash))) + generate_tasks.append( + asyncio.create_task(post_node_vrf(url, PublishedVRFResponseHash)) + ) vrf_generated_responses = await asyncio.gather( *generate_tasks, return_exceptions=True ) - return dict(zip(nodes, vrf_generated_responses)) + generate_results = dict(zip(nodes, vrf_generated_responses)) + for node, result in generate_results.items(): + if isinstance(result, Exception): + raise ValueError( + f"Generate response not found for Node {node} on request_id {request_id}" + ) + + return generate_results async def send_publish_requests( - vrf_generated_result: Dict[str, VRFResponseHash], - request_id: str, -) -> Dict[str, Union[Exception, VRFRandomBytes]]: + vrf_generated_result: Dict[str, PublishedVRFResponseHash], +) -> Dict[str, PublishedVRFRandomBytes]: publish_tasks = [] nodes: List[str] = [] + for node, vrf_generated_response in vrf_generated_result.items(): nodes.append(node) - if isinstance(vrf_generated_response, Exception): - raise ValueError( - f"Generate response not found for Node {node} on request_id {request_id}" - ) node_message_hash = vrf_generated_response.message_hash url = ( f"{node}/vm/{settings.FUNCTION}" f"/{VRF_FUNCTION_PUBLISH_PATH}/{node_message_hash}" ) - publish_tasks.append(asyncio.create_task(post_node_vrf(url, VRFRandomBytes))) + publish_tasks.append( + asyncio.create_task(post_node_vrf(url, PublishedVRFRandomBytes)) + ) vrf_publish_responses = await asyncio.gather(*publish_tasks, return_exceptions=True) - return dict(zip(nodes, vrf_publish_responses)) + publish_results = dict(zip(nodes, vrf_publish_responses)) + for node, result in publish_results.items(): + if isinstance(result, Exception): + raise ValueError(f"Publish response not found for {node}") + + return publish_results def generate_final_vrf( nb_executors: int, nonce: int, - vrf_generated_result: Dict[str, VRFResponseHash], - vrf_publish_result: Dict[str, VRFRandomBytes], + vrf_generated_result: Dict[str, PublishedVRFResponseHash], + vrf_publish_result: Dict[str, PublishedVRFRandomBytes], vrf_request: VRFRequest, ) -> VRFResponse: nodes_responses = [] diff --git a/src/aleph_vrf/executor/main.py b/src/aleph_vrf/executor/main.py index 66e7fca..aca5559 100644 --- a/src/aleph_vrf/executor/main.py +++ b/src/aleph_vrf/executor/main.py @@ -1,7 +1,5 @@ import logging -from contextlib import asynccontextmanager from typing import Dict, Union, Set -from uuid import UUID import fastapi from aleph.sdk.exceptions import MessageNotFoundError, MultipleMessagesError @@ -28,6 +26,8 @@ VRFResponseHash, generate_request_from_message, generate_response_hash_from_message, + PublishedVRFResponseHash, + PublishedVRFRandomBytes, ) from aleph_vrf.utils import bytes_to_binary, bytes_to_int, generate @@ -39,17 +39,7 @@ ANSWERED_REQUESTS: Set[str] = set() SAVED_GENERATED_BYTES: Dict[str, bytes] = {} - -@asynccontextmanager -async def lifespan(app: FastAPI): - global ANSWERED_REQUESTS, SAVED_GENERATED_BYTES - - ANSWERED_REQUESTS.clear() - SAVED_GENERATED_BYTES.clear() - yield - - -http_app = FastAPI(lifespan=lifespan) +http_app = FastAPI() app = AlephApp(http_app=http_app) @@ -80,7 +70,9 @@ async def _get_message(client: AlephClient, item_hash: ItemHash) -> PostMessage: @app.post("/generate/{vrf_request}") -async def receive_generate(vrf_request: ItemHash) -> APIResponse[VRFResponseHash]: +async def receive_generate( + vrf_request: ItemHash, +) -> APIResponse[PublishedVRFResponseHash]: global SAVED_GENERATED_BYTES, ANSWERED_REQUESTS private_key = get_fallback_private_key() @@ -120,13 +112,17 @@ async def receive_generate(vrf_request: ItemHash) -> APIResponse[VRFResponseHash message_hash = await publish_data(response_hash, ref, account) - response_hash.message_hash = message_hash + published_response_hash = PublishedVRFResponseHash.from_vrf_response_hash( + vrf_response_hash=response_hash, message_hash=message_hash + ) - return APIResponse(data=response_hash) + return APIResponse(data=published_response_hash) @app.post("/publish/{hash_message}") -async def receive_publish(hash_message: ItemHash) -> APIResponse[VRFRandomBytes]: +async def receive_publish( + hash_message: ItemHash, +) -> APIResponse[PublishedVRFRandomBytes]: global SAVED_GENERATED_BYTES private_key = get_fallback_private_key() @@ -155,10 +151,11 @@ async def receive_publish(hash_message: ItemHash) -> APIResponse[VRFRandomBytes] ref = f"vrf_{response_hash.request_id}_{response_hash.execution_id}" message_hash = await publish_data(response_bytes, ref, account) + published_random_bytes = PublishedVRFRandomBytes.from_vrf_random_bytes( + vrf_random_bytes=response_bytes, message_hash=message_hash + ) - response_bytes.message_hash = message_hash - - return APIResponse(data=response_bytes) + return APIResponse(data=published_random_bytes) async def publish_data( diff --git a/src/aleph_vrf/models.py b/src/aleph_vrf/models.py index 417e1dc..4334223 100644 --- a/src/aleph_vrf/models.py +++ b/src/aleph_vrf/models.py @@ -1,4 +1,5 @@ -from typing import List, Optional, TypeVar, Generic +from typing import List, Optional +from typing import TypeVar, Generic from uuid import uuid4 import fastapi @@ -48,10 +49,34 @@ class VRFResponseHash(BaseModel): execution_id: str vrf_request: ItemHash random_bytes_hash: str - message_hash: Optional[str] = None -def generate_response_hash_from_message(message: PostMessage) -> VRFResponseHash: +class PublishedVRFResponseHash(VRFResponseHash): + """ + A VRF response hash already published on aleph.im. + Includes the hash of the message published on aleph.im. + """ + + message_hash: ItemHash + + @classmethod + def from_vrf_response_hash( + cls, vrf_response_hash: VRFResponseHash, message_hash: ItemHash + ) -> "PublishedVRFResponseHash": + return cls( + nb_bytes=vrf_response_hash.nb_bytes, + nonce=vrf_response_hash.nonce, + request_id=vrf_response_hash.request_id, + execution_id=vrf_response_hash.execution_id, + vrf_request=vrf_response_hash.vrf_request, + random_bytes_hash=vrf_response_hash.random_bytes_hash, + message_hash=message_hash, + ) + + +def generate_response_hash_from_message( + message: PostMessage, +) -> PublishedVRFResponseHash: content = message.content.content try: response_hash = VRFResponseHash.parse_obj(content) @@ -61,8 +86,9 @@ def generate_response_hash_from_message(message: PostMessage) -> VRFResponseHash detail=f"Could not parse content of {message.item_hash} as VRF response hash object: {e.json()}", ) - response_hash.message_hash = message.item_hash - return response_hash + return PublishedVRFResponseHash.from_vrf_response_hash( + vrf_response_hash=response_hash, message_hash=message.item_hash + ) class VRFRandomBytes(BaseModel): @@ -72,7 +98,24 @@ class VRFRandomBytes(BaseModel): random_bytes: str random_bytes_hash: str random_number: str - message_hash: Optional[str] = None + + +class PublishedVRFRandomBytes(VRFRandomBytes): + message_hash: ItemHash + + @classmethod + def from_vrf_random_bytes( + cls, vrf_random_bytes: VRFRandomBytes, message_hash: ItemHash + ) -> "PublishedVRFRandomBytes": + return cls( + request_id=vrf_random_bytes.request_id, + execution_id=vrf_random_bytes.execution_id, + vrf_request=vrf_random_bytes.vrf_request, + random_bytes=vrf_random_bytes.random_bytes, + random_bytes_hash=vrf_random_bytes.random_bytes_hash, + random_number=vrf_random_bytes.random_number, + message_hash=message_hash, + ) class CRNVRFResponse(BaseModel): diff --git a/src/aleph_vrf/settings.py b/src/aleph_vrf/settings.py index ffa2ffc..b563d11 100644 --- a/src/aleph_vrf/settings.py +++ b/src/aleph_vrf/settings.py @@ -3,23 +3,23 @@ class Settings(BaseSettings): API_HOST: str = Field( - "https://api2.aleph.im", + default="https://api2.aleph.im", description="URL of the reference aleph.im Core Channel Node.", ) CORECHANNEL_AGGREGATE_ADDRESS = Field( - "0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10", + default="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10", description="Address posting the `corechannel` aggregate.", ) CORECHANNEL_AGGREGATE_KEY = Field( - "corechannel", description="Key for the `corechannel` aggregate." + default="corechannel", description="Key for the `corechannel` aggregate." ) FUNCTION: str = Field( - "4992b4127d296b240bbb73058daea9bca09f717fa94767d6f4dc3ef53b4ef5ce", + default="4992b4127d296b240bbb73058daea9bca09f717fa94767d6f4dc3ef53b4ef5ce", description="VRF function to use.", ) - NB_EXECUTORS: int = Field(32, description="Number of executors to use.") + NB_EXECUTORS: int = Field(default=32, description="Number of executors to use.") NB_BYTES: int = Field( - 32, description="Number of bytes of the generated random number." + default=32, description="Number of bytes of the generated random number." ) class Config: diff --git a/src/aleph_vrf/utils.py b/src/aleph_vrf/utils.py index e6fc628..49438dd 100644 --- a/src/aleph_vrf/utils.py +++ b/src/aleph_vrf/utils.py @@ -1,6 +1,6 @@ from hashlib import sha3_256 from random import randint -from typing import List +from typing import List, Tuple from utilitybelt import dev_urandom_entropy @@ -38,9 +38,9 @@ def generate_nonce() -> int: return randint(0, 100000000) -def generate(n: int, nonce: int) -> (bytes, bytes): +def generate(n: int, nonce: int) -> Tuple[bytes, str]: """Generates a number of random bytes and hashes them with the nonce.""" - random_bytes = dev_urandom_entropy(n) + random_bytes: bytes = dev_urandom_entropy(n) random_hash = sha3_256(random_bytes + int_to_bytes(nonce)).hexdigest() return random_bytes, random_hash diff --git a/tests/executor/test_integration.py b/tests/executor/test_integration.py index 7d0ee34..0d0b60d 100644 --- a/tests/executor/test_integration.py +++ b/tests/executor/test_integration.py @@ -15,7 +15,14 @@ PostMessage, ) -from aleph_vrf.models import VRFRequest, VRFResponseHash, VRFResponse, VRFRandomBytes +from aleph_vrf.models import ( + VRFRequest, + VRFResponseHash, + VRFResponse, + VRFRandomBytes, + PublishedVRFResponseHash, + PublishedVRFRandomBytes, +) from aleph_vrf.utils import binary_to_bytes, verify @@ -91,7 +98,9 @@ async def published_vrf_request( def assert_vrf_hash_matches_request( - response_hash: VRFResponseHash, vrf_request: VRFRequest, request_item_hash: ItemHash + response_hash: PublishedVRFResponseHash, + vrf_request: VRFRequest, + request_item_hash: ItemHash, ): assert response_hash.nb_bytes == vrf_request.nb_bytes assert response_hash.nonce == vrf_request.nonce @@ -99,7 +108,6 @@ def assert_vrf_hash_matches_request( assert response_hash.execution_id # This should be a UUID4 assert response_hash.vrf_request == request_item_hash assert response_hash.random_bytes_hash - assert response_hash.message_hash def assert_random_number_matches_request( @@ -132,7 +140,8 @@ def assert_vrf_response_hash_equal( async def assert_aleph_message_matches_response_hash( - ccn_url: str, response_hash: VRFResponseHash + ccn_url: Any, # aiohttp does not expose its URL type + response_hash: PublishedVRFResponseHash, ) -> PostMessage: assert response_hash.message_hash @@ -160,10 +169,9 @@ def assert_vrf_random_bytes_equal( async def assert_aleph_message_matches_random_bytes( - ccn_url: str, random_bytes: VRFRandomBytes + ccn_url: Any, # aiohttp does not expose its URL type + random_bytes: PublishedVRFRandomBytes, ) -> PostMessage: - assert random_bytes.message_hash - async with AlephClient(api_server=ccn_url) as client: message = await client.get_message( random_bytes.message_hash, message_type=PostMessage @@ -195,7 +203,7 @@ async def test_normal_request_flow( assert resp.status == 200, await resp.text() response_json = await resp.json() - response_hash = VRFResponseHash.parse_obj(response_json["data"]) + response_hash = PublishedVRFResponseHash.parse_obj(response_json["data"]) assert_vrf_hash_matches_request(response_hash, vrf_request, item_hash) random_hash_message = await assert_aleph_message_matches_response_hash( @@ -206,7 +214,7 @@ async def test_normal_request_flow( assert resp.status == 200, await resp.text() response_json = await resp.json() - random_bytes = VRFRandomBytes.parse_obj(response_json["data"]) + random_bytes = PublishedVRFRandomBytes.parse_obj(response_json["data"]) assert_random_number_matches_request( random_bytes=random_bytes, response_hash=response_hash, @@ -238,7 +246,7 @@ async def test_call_publish_twice( assert resp.status == 200, await resp.text() response_json = await resp.json() - response_hash = VRFResponseHash.parse_obj(response_json["data"]) + response_hash = PublishedVRFResponseHash.parse_obj(response_json["data"]) # Call POST /publish a first time resp = await executor_client.post(f"/publish/{response_hash.message_hash}") diff --git a/tests/mock_ccn.py b/tests/mock_ccn.py index 7a5a278..726ed5a 100644 --- a/tests/mock_ccn.py +++ b/tests/mock_ccn.py @@ -3,6 +3,7 @@ from enum import Enum from typing import Optional, Dict, Any, List +from aleph_message.models import ItemHash from aleph_message.status import MessageStatus from fastapi import FastAPI from pydantic import BaseModel, Field @@ -12,12 +13,12 @@ app = FastAPI() -MESSAGES = {} +MESSAGES: Dict[ItemHash, Dict[str, Any]] = {} @app.get("/api/v0/messages.json") async def get_messages(hashes: Optional[str], page: int = 1, pagination: int = 20): - hashes = hashes.split(",") + hashes = [ItemHash(h) for h in hashes.split(",")] if hashes is not None else [] messages = [MESSAGES[item_hash] for item_hash in hashes if item_hash in MESSAGES] paginated_messages = messages[(page - 1) * pagination : page * pagination]