Skip to content

Commit

Permalink
Internal: check typing with mypy
Browse files Browse the repository at this point in the history
Problem: the codebase has type hints but we do not check them in the CI
pipeline.

Solution: add a type check step with mypy.

Fixed the following issues:
- the generate/publish steps now check for exceptions before returning
  instead of checking in the next function. This allows to simplify the
  type hints for return values.
- The response hash/random bytes are now split in published/unpublished
  versions to ensure that the message hash is present in the API
  responses.
  • Loading branch information
odesenfans committed Sep 22, 2023
1 parent 48cf72d commit f35a71b
Show file tree
Hide file tree
Showing 9 changed files with 179 additions and 41 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ name: Unit tests
on: [push]

jobs:
build:
tests:

runs-on: ubuntu-22.04
strategy:
Expand All @@ -20,6 +20,9 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e .[testing]
- name: Check typing with mypy
run: |
mypy src/ tests/
- name: Test with pytest
run: |
pytest
69 changes: 69 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Global options:

[mypy]
mypy_path = src

exclude = conftest.py


show_column_numbers = True

# Suppressing errors
# Shows errors related to strict None checking, if the global strict_optional flag is enabled
strict_optional = True
no_implicit_optional = True

# Import discovery
# Suppresses error messages about imports that cannot be resolved
ignore_missing_imports = True
# Forces import to reference the original source file
no_implicit_reexport = True
# show error messages from unrelated files
follow_imports = silent
follow_imports_for_stubs = False


# Disallow dynamic typing
# Disallows usage of types that come from unfollowed imports
disallow_any_unimported = False
# Disallows all expressions in the module that have type Any
disallow_any_expr = False
# Disallows functions that have Any in their signature after decorator transformation.
disallow_any_decorated = False
# Disallows explicit Any in type positions such as type annotations and generic type parameters.
disallow_any_explicit = False
# Disallows usage of generic types that do not specify explicit type parameters.
disallow_any_generics = False
# Disallows subclassing a value of type Any.
disallow_subclassing_any = False

# Untyped definitions and calls
# Disallows calling functions without type annotations from functions with type annotations.
disallow_untyped_calls = False
# Disallows defining functions without type annotations or with incomplete type annotations
disallow_untyped_defs = False
# Disallows defining functions with incomplete type annotations.
check_untyped_defs = False
# Type-checks the interior of functions without type annotations.
disallow_incomplete_defs = False
# Reports an error whenever a function with type annotations is decorated with a decorator without annotations.
disallow_untyped_decorators = False

# Prohibit comparisons of non-overlapping types (ex: 42 == "no")
strict_equality = True

# Configuring warnings
# Warns about unneeded # type: ignore comments.
warn_unused_ignores = True
# Shows errors for missing return statements on some execution paths.
warn_no_return = True
# Shows a warning when returning a value with type Any from a function declared with a non- Any return type.
warn_return_any = False

# Miscellaneous strictness flags
# Allows variables to be redefined with an arbitrary type, as long as the redefinition is in the same block and nesting level as the original definition.
allow_redefinition = True

# Ignore the imported code from py-libp2p
[mypy-aleph.toolkit.libp2p_stubs.*]
ignore_errors = True
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ testing =
pytest-asyncio
pytest-cov
pytest-mock
mypy==1.5.1

[options.entry_points]
# Add here console scripts like:
Expand Down
5 changes: 4 additions & 1 deletion src/aleph_vrf/coordinator/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Dict, Union

logger = logging.getLogger(__name__)

Expand All @@ -13,7 +14,7 @@

logger.debug("local imports")
from aleph_vrf.coordinator.vrf import generate_vrf
from aleph_vrf.models import APIResponse
from aleph_vrf.models import APIResponse, VRFResponse

logger.debug("imports done")

Expand All @@ -37,6 +38,8 @@ async def receive_vrf() -> APIResponse:
private_key = get_fallback_private_key()
account = ETHAccount(private_key=private_key)

response: Union[VRFResponse, Dict[str, str]]

try:
response = await generate_vrf(account)
except Exception as err:
Expand Down
56 changes: 36 additions & 20 deletions src/aleph_vrf/coordinator/vrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
VRFRequest,
VRFResponse,
VRFResponseHash,
PublishedVRFResponseHash,
PublishedVRFRandomBytes,
)
from aleph_vrf.settings import settings
from aleph_vrf.utils import (
Expand All @@ -43,7 +45,7 @@
M = TypeVar("M", bound=BaseModel)


async def post_node_vrf(url: str, model: Type[M]) -> Union[Exception, M]:
async def post_node_vrf(url: str, model: Type[M]) -> M:
async with aiohttp.ClientSession() as session:
async with session.post(url, timeout=60) as resp:
if resp.status != 200:
Expand Down Expand Up @@ -143,16 +145,16 @@ async def generate_vrf(account: ETHAccount) -> VRFResponse:
logger.debug(f"Generated VRF request with item_hash {request_item_hash}")

vrf_generated_result = await send_generate_requests(
selected_nodes, request_item_hash
selected_nodes=selected_nodes,
request_item_hash=request_item_hash,
request_id=vrf_request.request_id,
)

logger.debug(
f"Received VRF generated requests from {len(vrf_generated_result)} nodes"
)

vrf_publish_result = await send_publish_requests(
vrf_generated_result, vrf_request.request_id
)
vrf_publish_result = await send_publish_requests(vrf_generated_result)

logger.debug(
f"Received VRF publish requests from {len(vrf_generated_result)} nodes"
Expand All @@ -178,50 +180,64 @@ async def generate_vrf(account: ETHAccount) -> VRFResponse:


async def send_generate_requests(
selected_nodes: List[Node], request_item_hash: str
) -> Dict[str, Union[Exception, VRFResponseHash]]:
selected_nodes: List[Node],
request_item_hash: str,
request_id: str,
) -> Dict[str, PublishedVRFResponseHash]:
generate_tasks = []
nodes: List[str] = []
for node in selected_nodes:
nodes.append(node.address)
url = f"{node.address}/vm/{settings.FUNCTION}/{VRF_FUNCTION_GENERATE_PATH}/{request_item_hash}"
generate_tasks.append(asyncio.create_task(post_node_vrf(url, VRFResponseHash)))
generate_tasks.append(
asyncio.create_task(post_node_vrf(url, PublishedVRFResponseHash))
)

vrf_generated_responses = await asyncio.gather(
*generate_tasks, return_exceptions=True
)
return dict(zip(nodes, vrf_generated_responses))
generate_results = dict(zip(nodes, vrf_generated_responses))
for node, result in generate_results.items():
if isinstance(result, Exception):
raise ValueError(
f"Generate response not found for Node {node} on request_id {request_id}"
)

return generate_results


async def send_publish_requests(
vrf_generated_result: Dict[str, VRFResponseHash],
request_id: str,
) -> Dict[str, Union[Exception, VRFRandomBytes]]:
vrf_generated_result: Dict[str, PublishedVRFResponseHash],
) -> Dict[str, PublishedVRFRandomBytes]:
publish_tasks = []
nodes: List[str] = []

for node, vrf_generated_response in vrf_generated_result.items():
nodes.append(node)
if isinstance(vrf_generated_response, Exception):
raise ValueError(
f"Generate response not found for Node {node} on request_id {request_id}"
)

node_message_hash = vrf_generated_response.message_hash
url = (
f"{node}/vm/{settings.FUNCTION}"
f"/{VRF_FUNCTION_PUBLISH_PATH}/{node_message_hash}"
)
publish_tasks.append(asyncio.create_task(post_node_vrf(url, VRFRandomBytes)))
publish_tasks.append(
asyncio.create_task(post_node_vrf(url, PublishedVRFRandomBytes))
)

vrf_publish_responses = await asyncio.gather(*publish_tasks, return_exceptions=True)
return dict(zip(nodes, vrf_publish_responses))
publish_results = dict(zip(nodes, vrf_publish_responses))
for node, result in publish_results.items():
if isinstance(result, Exception):
raise ValueError(f"Publish response not found for {node}")

return publish_results


def generate_final_vrf(
nb_executors: int,
nonce: int,
vrf_generated_result: Dict[str, VRFResponseHash],
vrf_publish_result: Dict[str, VRFRandomBytes],
vrf_generated_result: Dict[str, PublishedVRFResponseHash],
vrf_publish_result: Dict[str, PublishedVRFRandomBytes],
vrf_request: VRFRequest,
) -> VRFResponse:
nodes_responses = []
Expand Down
15 changes: 10 additions & 5 deletions src/aleph_vrf/executor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
VRFResponseHash,
generate_request_from_message,
generate_response_hash_from_message,
PublishedVRFResponseHash,
PublishedVRFRandomBytes,
)
from aleph_vrf.utils import bytes_to_binary, bytes_to_int, generate

Expand Down Expand Up @@ -79,9 +81,11 @@ async def receive_generate(vrf_request: str) -> APIResponse:

message_hash = await publish_data(response_hash, ref, account)

response_hash.message_hash = message_hash
published_response_hash = PublishedVRFResponseHash.from_vrf_response_hash(
vrf_response_hash=response_hash, message_hash=message_hash
)

return APIResponse(data=response_hash)
return APIResponse(data=published_response_hash)


@app.post("/publish/{hash_message}")
Expand Down Expand Up @@ -114,10 +118,11 @@ async def receive_publish(hash_message: str) -> APIResponse:
ref = f"vrf_{response_hash.request_id}_{response_hash.execution_id}"

message_hash = await publish_data(response_bytes, ref, account)
published_random_bytes = PublishedVRFRandomBytes.from_vrf_random_bytes(
vrf_random_bytes=response_bytes, message_hash=message_hash
)

response_bytes.message_hash = message_hash

return APIResponse(data=response_bytes)
return APIResponse(data=published_random_bytes)


async def publish_data(
Expand Down
51 changes: 46 additions & 5 deletions src/aleph_vrf/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,43 @@ class VRFResponseHash(BaseModel):
execution_id: str
vrf_request: ItemHash
random_bytes_hash: str
message_hash: Optional[str] = None


def generate_response_hash_from_message(message: PostMessage) -> VRFResponseHash:
class PublishedVRFResponseHash(VRFResponseHash):
"""
A VRF response hash already published on aleph.im.
Includes the hash of the message published on aleph.im.
"""

message_hash: ItemHash

@classmethod
def from_vrf_response_hash(
cls, vrf_response_hash: VRFResponseHash, message_hash: ItemHash
) -> "PublishedVRFResponseHash":
return cls(
nb_bytes=vrf_response_hash.nb_bytes,
nonce=vrf_response_hash.nonce,
request_id=vrf_response_hash.request_id,
execution_id=vrf_response_hash.execution_id,
vrf_request=vrf_response_hash.request_id,
random_bytes_hash=vrf_response_hash.random_bytes_hash,
message_hash=message_hash,
)


def generate_response_hash_from_message(
message: PostMessage,
) -> PublishedVRFResponseHash:
content = message.content.content
return VRFResponseHash(
return PublishedVRFResponseHash(
nb_bytes=content["nb_bytes"],
nonce=content["nonce"],
request_id=content["request_id"],
execution_id=content["execution_id"],
vrf_request=ItemHash(content["vrf_request"]),
random_bytes_hash=content["random_bytes_hash"],
message_hash=content["message_hash"],
message_hash=message.item_hash,
)


Expand All @@ -69,7 +93,24 @@ class VRFRandomBytes(BaseModel):
random_bytes: str
random_bytes_hash: str
random_number: str
message_hash: Optional[str] = None


class PublishedVRFRandomBytes(VRFRandomBytes):
message_hash: ItemHash

@classmethod
def from_vrf_random_bytes(
cls, vrf_random_bytes: VRFRandomBytes, message_hash: ItemHash
) -> "PublishedVRFRandomBytes":
return cls(
request_id=vrf_random_bytes.request_id,
execution_id=vrf_random_bytes.execution_id,
vrf_request=vrf_random_bytes.request_id,
random_bytes=vrf_random_bytes.random_bytes,
random_bytes_hash=vrf_random_bytes.random_bytes_hash,
random_number=vrf_random_bytes.random_number,
message_hash=message_hash,
)


class CRNVRFResponse(BaseModel):
Expand Down
12 changes: 6 additions & 6 deletions src/aleph_vrf/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@

class Settings(BaseSettings):
API_HOST: str = Field(
"https://api2.aleph.im",
default="https://api2.aleph.im",
description="URL of the reference aleph.im Core Channel Node.",
)
CORECHANNEL_AGGREGATE_ADDRESS = Field(
"0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10",
default="0xa1B3bb7d2332383D96b7796B908fB7f7F3c2Be10",
description="Address posting the `corechannel` aggregate.",
)
CORECHANNEL_AGGREGATE_KEY = Field(
"corechannel", description="Key for the `corechannel` aggregate."
default="corechannel", description="Key for the `corechannel` aggregate."
)
FUNCTION: str = Field(
"4992b4127d296b240bbb73058daea9bca09f717fa94767d6f4dc3ef53b4ef5ce",
default="4992b4127d296b240bbb73058daea9bca09f717fa94767d6f4dc3ef53b4ef5ce",
description="VRF function to use.",
)
NB_EXECUTORS: int = Field(32, description="Number of executors to use.")
NB_EXECUTORS: int = Field(default=32, description="Number of executors to use.")
NB_BYTES: int = Field(
32, description="Number of bytes of the generated random number."
default=32, description="Number of bytes of the generated random number."
)

class Config:
Expand Down
Loading

0 comments on commit f35a71b

Please sign in to comment.