Skip to content

Commit

Permalink
Internal: introduce an Executor class
Browse files Browse the repository at this point in the history
Problem: the `Node` class assumes that the coordinator will use the
aleph.im compute resource nodes in all cases. We wish to let the user
specify dedicated executor servers as well.

Solution: introduce an `Executor` class that takes a node and an
optional VM function. If running on aleph.im, the API URL of the
executor will be https://{node_url}/vm/{function}; otherwise it will
default to https://{node_url}.

The result dictionaries now use executors as keys, since the `Executor`
model is now hashable.
  • Loading branch information
odesenfans committed Sep 25, 2023
1 parent 681ae5f commit c03a823
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 60 deletions.
83 changes: 38 additions & 45 deletions src/aleph_vrf/coordinator/vrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
VRFResponse,
PublishedVRFResponseHash,
PublishedVRFRandomBytes,
VRFResponseHash,
Executor,
)
from aleph_vrf.settings import settings
from aleph_vrf.types import RequestId, Nonce
Expand Down Expand Up @@ -79,10 +81,10 @@ def _get_unauthorized_node_list() -> List[str]:
return []


async def select_random_nodes(
async def select_random_executors(
node_amount: int, unauthorized_nodes: List[str]
) -> List[Node]:
node_list: List[Node] = []
) -> List[Executor]:
node_list: List[Executor] = []

content = await _get_corechannel_aggregate()

Expand All @@ -107,7 +109,7 @@ async def select_random_nodes(
address=node_address,
score=resource_node["score"],
)
node_list.append(node)
node_list.append(Executor(node=node, vm_function=settings.FUNCTION))

if len(node_list) < node_amount:
raise ValueError(
Expand All @@ -121,7 +123,7 @@ async def select_random_nodes(
async def generate_vrf(account: ETHAccount) -> VRFResponse:
nb_executors = settings.NB_EXECUTORS
unauthorized_nodes = _get_unauthorized_node_list()
selected_nodes = await select_random_nodes(nb_executors, unauthorized_nodes)
selected_nodes = await select_random_executors(nb_executors, unauthorized_nodes)
selected_node_list = json.dumps(selected_nodes, default=pydantic_encoder).encode(
encoding="utf-8"
)
Expand All @@ -144,19 +146,19 @@ async def generate_vrf(account: ETHAccount) -> VRFResponse:
logger.debug(f"Generated VRF request with item_hash {request_item_hash}")

vrf_generated_result = await send_generate_requests(
selected_nodes=selected_nodes,
executors=selected_nodes,
request_item_hash=request_item_hash,
request_id=vrf_request.request_id,
)

logger.debug(
f"Received VRF generated requests from {len(vrf_generated_result)} nodes"
f"Received VRF generated requests from {len(vrf_generated_result)} executors"
)

vrf_publish_result = await send_publish_requests(vrf_generated_result)

logger.debug(
f"Received VRF publish requests from {len(vrf_generated_result)} nodes"
f"Received VRF publish requests from {len(vrf_generated_result)} executors"
)

vrf_response = generate_final_vrf(
Expand All @@ -179,74 +181,65 @@ async def generate_vrf(account: ETHAccount) -> VRFResponse:


async def send_generate_requests(
selected_nodes: List[Node],
executors: List[Executor],
request_item_hash: ItemHash,
request_id: RequestId,
) -> Dict[str, PublishedVRFResponseHash]:
) -> Dict[Executor, PublishedVRFResponseHash]:

generate_tasks = []
nodes: List[str] = []
for node in selected_nodes:
nodes.append(node.address)
url = f"{node.address}/vm/{settings.FUNCTION}/{VRF_FUNCTION_GENERATE_PATH}/{request_item_hash}"
generate_tasks.append(
asyncio.create_task(post_node_vrf(url, PublishedVRFResponseHash))
)
for executor in executors:
url = f"{executor.api_url}/{VRF_FUNCTION_GENERATE_PATH}/{request_item_hash}"
generate_tasks.append(asyncio.create_task(post_node_vrf(url, VRFResponseHash)))

vrf_generated_responses = await asyncio.gather(
*generate_tasks, return_exceptions=True
)
generate_results = dict(zip(nodes, vrf_generated_responses))
for node, result in generate_results.items():
generate_results = dict(zip(executors, vrf_generated_responses))

for executor, result in generate_results.items():
if isinstance(result, Exception):
raise ValueError(
f"Generate response not found for Node {node} on request_id {request_id}"
f"Generate response not found for executor {executor} on request_id {request_id}"
)

return generate_results


async def send_publish_requests(
vrf_generated_result: Dict[str, PublishedVRFResponseHash],
) -> Dict[str, PublishedVRFRandomBytes]:
vrf_generated_result: Dict[Executor, PublishedVRFResponseHash],
) -> Dict[Executor, PublishedVRFRandomBytes]:
publish_tasks = []
nodes: List[str] = []
executors: List[Executor] = []

for node, vrf_generated_response in vrf_generated_result.items():
nodes.append(node)
for executor, vrf_generated_response in vrf_generated_result.items():
executors.append(executor)

node_message_hash = vrf_generated_response.message_hash
url = (
f"{node}/vm/{settings.FUNCTION}"
f"/{VRF_FUNCTION_PUBLISH_PATH}/{node_message_hash}"
)
publish_tasks.append(
asyncio.create_task(post_node_vrf(url, PublishedVRFRandomBytes))
)
url = f"{executor.api_url}/{VRF_FUNCTION_PUBLISH_PATH}/{node_message_hash}"
publish_tasks.append(asyncio.create_task(post_node_vrf(url, PublishedVRFRandomBytes)))

vrf_publish_responses = await asyncio.gather(*publish_tasks, return_exceptions=True)
publish_results = dict(zip(nodes, vrf_publish_responses))
for node, result in publish_results.items():
publish_results = dict(zip(executors, vrf_publish_responses))

for executor, result in publish_results.items():
if isinstance(result, Exception):
raise ValueError(f"Publish response not found for {node}")
raise ValueError(f"Publish response not found for {executor}")

return publish_results


def generate_final_vrf(
nb_executors: int,
nonce: Nonce,
vrf_generated_result: Dict[str, PublishedVRFResponseHash],
vrf_publish_result: Dict[str, PublishedVRFRandomBytes],
vrf_generated_result: Dict[Executor, PublishedVRFResponseHash],
vrf_publish_result: Dict[Executor, PublishedVRFRandomBytes],
vrf_request: VRFRequest,
) -> VRFResponse:
nodes_responses = []
random_numbers_list = []
for node, vrf_publish_response in vrf_publish_result.items():
if isinstance(vrf_publish_response, Exception):
raise ValueError(f"Publish response not found for {node}")

for executor, vrf_publish_response in vrf_publish_result.items():
if (
vrf_generated_result[node].random_bytes_hash
vrf_generated_result[executor].random_bytes_hash
!= vrf_publish_response.random_bytes_hash
):
generated_hash = vrf_publish_response.random_bytes_hash
Expand All @@ -270,12 +263,12 @@ def generate_final_vrf(
)

node_response = CRNVRFResponse(
url=node,
url=executor.node.address,
execution_id=vrf_publish_response.execution_id,
random_number=str(vrf_publish_response.random_number),
random_bytes=vrf_publish_response.random_bytes,
random_bytes_hash=vrf_generated_result[node].random_bytes_hash,
generation_message_hash=vrf_generated_result[node].message_hash,
random_bytes_hash=vrf_generated_result[executor].random_bytes_hash,
generation_message_hash=vrf_generated_result[executor].message_hash,
publish_message_hash=vrf_publish_response.message_hash,
)
nodes_responses.append(node_response)
Expand Down
49 changes: 45 additions & 4 deletions src/aleph_vrf/models.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,62 @@
from typing import List, Optional
from typing import List, Optional, overload
from typing import TypeVar, Generic
from uuid import uuid4

import fastapi
from aleph_message.models import ItemHash, PostMessage
from pydantic import BaseModel, ValidationError, Field
from aleph_message.models.abstract import HashableModel
from pydantic import BaseModel
from pydantic import ValidationError, Field
from pydantic.generics import GenericModel

from aleph_vrf.types import Nonce, RequestId, ExecutionId


class Node(BaseModel):
hash: str
class Node(HashableModel):
address: str


class LocalNode(Node):
port: int


class ComputeResourceNode(Node):
hash: str
score: float


class Executor(HashableModel):
class Config:
exclude = {"vm_function"}

node: Node
vm_function: Optional[str]

@classmethod
@overload
def from_node(cls, node: ComputeResourceNode, vm_function: str) -> "Executor":
...

@classmethod
@overload
def from_node(cls, node: LocalNode) -> "Executor":
...

@classmethod
def from_node(cls, node, vm_function=None) -> "Executor":
if isinstance(node, ComputeResourceNode):
return cls(node=node, vm_function=vm_function)
return cls(node=node)

@property
def api_url(self) -> str:
url = self.node.address
if self.vm_function:
url += f"/vm/{self.vm_function}"

return url


class VRFRequest(BaseModel):
nb_bytes: int
nb_executors: int
Expand Down
29 changes: 18 additions & 11 deletions tests/coordinator/test_vrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from aleph_vrf.coordinator.vrf import select_random_nodes
from aleph_vrf.coordinator.vrf import select_random_executors


@pytest.fixture
Expand Down Expand Up @@ -140,39 +140,46 @@ async def test_select_random_nodes(fixture_nodes_aggregate: Dict[str, Any], mock
return_value=fixture_nodes_aggregate,
)

nodes = await select_random_nodes(3, [])
executors = await select_random_executors(3, [])
# Sanity check, avoid network accesses
network_fixture.assert_called_once()
assert len(nodes) == 3

assert len(executors) == 3

with pytest.raises(ValueError) as exception:
resource_nodes = fixture_nodes_aggregate["data"]["corechannel"][
"resource_nodes"
]
await select_random_nodes(len(resource_nodes), [])
await select_random_executors(len(resource_nodes), [])
assert (
str(exception.value)
== f"Not enough CRNs linked, only 4 available from 5 requested"
)


@pytest.mark.asyncio
async def test_select_random_nodes_with_unauthorized(fixture_nodes_aggregate: Dict[str, Any], mocker):
async def test_select_random_nodes_with_unauthorized(
fixture_nodes_aggregate: Dict[str, Any], mocker
):
network_fixture = mocker.patch(
"aleph_vrf.coordinator.vrf._get_corechannel_aggregate",
return_value=fixture_nodes_aggregate,
)

nodes = await select_random_nodes(3, ["https://aleph2.serverrg.eu"])
# Sanity check, avoid network accesses
network_fixture.assert_called_once()
assert len(nodes) == 3
assert network_fixture.called_once

executors = await select_random_executors(3, ["https://aleph2.serverrg.eu"])
assert len(executors) == 3

with pytest.raises(ValueError) as exception:
resource_nodes = fixture_nodes_aggregate["data"]["corechannel"][
"resource_nodes"
]
await select_random_nodes(len(resource_nodes) - 1, ["https://aleph2.serverrg.eu"])
_ = await select_random_executors(
len(resource_nodes) - 1, ["https://aleph2.serverrg.eu"]
)
assert (
str(exception.value)
== f"Not enough CRNs linked, only 3 available from 4 requested"
str(exception.value)
== f"Not enough CRNs linked, only 3 available from 4 requested"
)

0 comments on commit c03a823

Please sign in to comment.