Skip to content

Commit

Permalink
Internal: check typing with mypy
Browse files Browse the repository at this point in the history
Problem: the codebase has type hints but we do not check them in the CI
pipeline.

Solution: add a type check step with mypy.

Fixed the following issues:
- the generate/publish steps now check for exceptions before returning
  instead of checking in the next function. This allows to simplify the
  type hints for return values.
- The response hash/random bytes are now split in published/unpublished
  versions to ensure that the message hash is present in the API
  responses.
  • Loading branch information
odesenfans committed Sep 22, 2023
1 parent e833fab commit 36282b7
Show file tree
Hide file tree
Showing 11 changed files with 207 additions and 72 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Unit tests
on: [push]

jobs:
build:
tests:

runs-on: ubuntu-22.04
strategy:
Expand All @@ -20,6 +20,9 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e .[testing]
- name: Check typing with mypy
run: |
mypy src/ tests/
- name: Test with pytest
run: |
pytest
65 changes: 65 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Global options:

[mypy]
mypy_path = src

exclude = conftest.py


show_column_numbers = True

# Suppressing errors
# Shows errors related to strict None checking, if the global strict_optional flag is enabled
strict_optional = True
no_implicit_optional = True

# Import discovery
# Suppresses error messages about imports that cannot be resolved
ignore_missing_imports = True
# Forces import to reference the original source file
no_implicit_reexport = True
# show error messages from unrelated files
follow_imports = silent
follow_imports_for_stubs = False


# Disallow dynamic typing
# Disallows usage of types that come from unfollowed imports
disallow_any_unimported = False
# Disallows all expressions in the module that have type Any
disallow_any_expr = False
# Disallows functions that have Any in their signature after decorator transformation.
disallow_any_decorated = False
# Disallows explicit Any in type positions such as type annotations and generic type parameters.
disallow_any_explicit = False
# Disallows usage of generic types that do not specify explicit type parameters.
disallow_any_generics = False
# Disallows subclassing a value of type Any.
disallow_subclassing_any = False

# Untyped definitions and calls
# Disallows calling functions without type annotations from functions with type annotations.
disallow_untyped_calls = False
# Disallows defining functions without type annotations or with incomplete type annotations
disallow_untyped_defs = False
# Disallows defining functions with incomplete type annotations.
check_untyped_defs = False
# Type-checks the interior of functions without type annotations.
disallow_incomplete_defs = False
# Reports an error whenever a function with type annotations is decorated with a decorator without annotations.
disallow_untyped_decorators = False

# Prohibit comparisons of non-overlapping types (ex: 42 == "no")
strict_equality = True

# Configuring warnings
# Warns about unneeded # type: ignore comments.
warn_unused_ignores = True
# Shows errors for missing return statements on some execution paths.
warn_no_return = True
# Shows a warning when returning a value with type Any from a function declared with a non- Any return type.
warn_return_any = False

# Miscellaneous strictness flags
# Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition.
allow_redefinition = True
3 changes: 2 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,12 @@ exclude =

# Add here test requirements (semicolon/line-separated)
testing =
setuptools
mypy==1.5.1
pytest
pytest-asyncio
pytest-cov
pytest-mock
setuptools
uvicorn

[options.entry_points]
Expand Down
5 changes: 4 additions & 1 deletion src/aleph_vrf/coordinator/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Dict, Union

logger = logging.getLogger(__name__)

Expand All @@ -13,7 +14,7 @@

logger.debug("local imports")
from aleph_vrf.coordinator.vrf import generate_vrf
from aleph_vrf.models import APIResponse
from aleph_vrf.models import APIResponse, VRFResponse

logger.debug("imports done")

Expand All @@ -37,6 +38,8 @@ async def receive_vrf() -> APIResponse:
private_key = get_fallback_private_key()
account = ETHAccount(private_key=private_key)

response: Union[VRFResponse, Dict[str, str]]

try:
response = await generate_vrf(account)
except Exception as err:
Expand Down
58 changes: 36 additions & 22 deletions src/aleph_vrf/coordinator/vrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from aleph_vrf.models import (
CRNVRFResponse,
Node,
VRFRandomBytes,
VRFRequest,
VRFResponse,
VRFResponseHash,
PublishedVRFResponseHash,
PublishedVRFRandomBytes,
)
from aleph_vrf.settings import settings
from aleph_vrf.utils import (
Expand All @@ -43,7 +43,7 @@
M = TypeVar("M", bound=BaseModel)


async def post_node_vrf(url: str, model: Type[M]) -> Union[Exception, M]:
async def post_node_vrf(url: str, model: Type[M]) -> M:
async with aiohttp.ClientSession() as session:
async with session.post(url, timeout=60) as resp:
if resp.status != 200:
Expand Down Expand Up @@ -143,16 +143,16 @@ async def generate_vrf(account: ETHAccount) -> VRFResponse:
logger.debug(f"Generated VRF request with item_hash {request_item_hash}")

vrf_generated_result = await send_generate_requests(
selected_nodes, request_item_hash
selected_nodes=selected_nodes,
request_item_hash=request_item_hash,
request_id=vrf_request.request_id,
)

logger.debug(
f"Received VRF generated requests from {len(vrf_generated_result)} nodes"
)

vrf_publish_result = await send_publish_requests(
vrf_generated_result, vrf_request.request_id
)
vrf_publish_result = await send_publish_requests(vrf_generated_result)

logger.debug(
f"Received VRF publish requests from {len(vrf_generated_result)} nodes"
Expand All @@ -178,50 +178,64 @@ async def generate_vrf(account: ETHAccount) -> VRFResponse:


async def send_generate_requests(
selected_nodes: List[Node], request_item_hash: str
) -> Dict[str, Union[Exception, VRFResponseHash]]:
selected_nodes: List[Node],
request_item_hash: str,
request_id: str,
) -> Dict[str, PublishedVRFResponseHash]:
generate_tasks = []
nodes: List[str] = []
for node in selected_nodes:
nodes.append(node.address)
url = f"{node.address}/vm/{settings.FUNCTION}/{VRF_FUNCTION_GENERATE_PATH}/{request_item_hash}"
generate_tasks.append(asyncio.create_task(post_node_vrf(url, VRFResponseHash)))
generate_tasks.append(
asyncio.create_task(post_node_vrf(url, PublishedVRFResponseHash))
)

vrf_generated_responses = await asyncio.gather(
*generate_tasks, return_exceptions=True
)
return dict(zip(nodes, vrf_generated_responses))
generate_results = dict(zip(nodes, vrf_generated_responses))
for node, result in generate_results.items():
if isinstance(result, Exception):
raise ValueError(
f"Generate response not found for Node {node} on request_id {request_id}"
)

return generate_results


async def send_publish_requests(
vrf_generated_result: Dict[str, VRFResponseHash],
request_id: str,
) -> Dict[str, Union[Exception, VRFRandomBytes]]:
vrf_generated_result: Dict[str, PublishedVRFResponseHash],
) -> Dict[str, PublishedVRFRandomBytes]:
publish_tasks = []
nodes: List[str] = []

for node, vrf_generated_response in vrf_generated_result.items():
nodes.append(node)
if isinstance(vrf_generated_response, Exception):
raise ValueError(
f"Generate response not found for Node {node} on request_id {request_id}"
)

node_message_hash = vrf_generated_response.message_hash
url = (
f"{node}/vm/{settings.FUNCTION}"
f"/{VRF_FUNCTION_PUBLISH_PATH}/{node_message_hash}"
)
publish_tasks.append(asyncio.create_task(post_node_vrf(url, VRFRandomBytes)))
publish_tasks.append(
asyncio.create_task(post_node_vrf(url, PublishedVRFRandomBytes))
)

vrf_publish_responses = await asyncio.gather(*publish_tasks, return_exceptions=True)
return dict(zip(nodes, vrf_publish_responses))
publish_results = dict(zip(nodes, vrf_publish_responses))
for node, result in publish_results.items():
if isinstance(result, Exception):
raise ValueError(f"Publish response not found for {node}")

return publish_results


def generate_final_vrf(
nb_executors: int,
nonce: int,
vrf_generated_result: Dict[str, VRFResponseHash],
vrf_publish_result: Dict[str, VRFRandomBytes],
vrf_generated_result: Dict[str, PublishedVRFResponseHash],
vrf_publish_result: Dict[str, PublishedVRFRandomBytes],
vrf_request: VRFRequest,
) -> VRFResponse:
nodes_responses = []
Expand Down
37 changes: 17 additions & 20 deletions src/aleph_vrf/executor/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import logging
from contextlib import asynccontextmanager
from typing import Dict, Union, Set
from uuid import UUID

import fastapi
from aleph.sdk.exceptions import MessageNotFoundError, MultipleMessagesError
Expand All @@ -28,6 +26,8 @@
VRFResponseHash,
generate_request_from_message,
generate_response_hash_from_message,
PublishedVRFResponseHash,
PublishedVRFRandomBytes,
)
from aleph_vrf.utils import bytes_to_binary, bytes_to_int, generate

Expand All @@ -39,17 +39,7 @@
ANSWERED_REQUESTS: Set[str] = set()
SAVED_GENERATED_BYTES: Dict[str, bytes] = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
global ANSWERED_REQUESTS, SAVED_GENERATED_BYTES

ANSWERED_REQUESTS.clear()
SAVED_GENERATED_BYTES.clear()
yield


http_app = FastAPI(lifespan=lifespan)
http_app = FastAPI()
app = AlephApp(http_app=http_app)


Expand Down Expand Up @@ -80,7 +70,9 @@ async def _get_message(client: AlephClient, item_hash: ItemHash) -> PostMessage:


@app.post("/generate/{vrf_request}")
async def receive_generate(vrf_request: ItemHash) -> APIResponse[VRFResponseHash]:
async def receive_generate(
vrf_request: ItemHash,
) -> APIResponse[PublishedVRFResponseHash]:
global SAVED_GENERATED_BYTES, ANSWERED_REQUESTS

private_key = get_fallback_private_key()
Expand Down Expand Up @@ -120,13 +112,17 @@ async def receive_generate(vrf_request: ItemHash) -> APIResponse[VRFResponseHash

message_hash = await publish_data(response_hash, ref, account)

response_hash.message_hash = message_hash
published_response_hash = PublishedVRFResponseHash.from_vrf_response_hash(
vrf_response_hash=response_hash, message_hash=message_hash
)

return APIResponse(data=response_hash)
return APIResponse(data=published_response_hash)


@app.post("/publish/{hash_message}")
async def receive_publish(hash_message: ItemHash) -> APIResponse[VRFRandomBytes]:
async def receive_publish(
hash_message: ItemHash,
) -> APIResponse[PublishedVRFRandomBytes]:
global SAVED_GENERATED_BYTES

private_key = get_fallback_private_key()
Expand Down Expand Up @@ -155,10 +151,11 @@ async def receive_publish(hash_message: ItemHash) -> APIResponse[VRFRandomBytes]
ref = f"vrf_{response_hash.request_id}_{response_hash.execution_id}"

message_hash = await publish_data(response_bytes, ref, account)
published_random_bytes = PublishedVRFRandomBytes.from_vrf_random_bytes(
vrf_random_bytes=response_bytes, message_hash=message_hash
)

response_bytes.message_hash = message_hash

return APIResponse(data=response_bytes)
return APIResponse(data=published_random_bytes)


async def publish_data(
Expand Down
Loading

0 comments on commit 36282b7

Please sign in to comment.