diff --git a/src/aleph_vrf/coordinator/executor_selection.py b/src/aleph_vrf/coordinator/executor_selection.py
new file mode 100644
index 0000000..d93b956
--- /dev/null
+++ b/src/aleph_vrf/coordinator/executor_selection.py
@@ -0,0 +1,111 @@
+import abc
+import json
+from pathlib import Path
+from typing import List, Dict, Any, AsyncIterator
+import random
+
+import aiohttp
+from aleph_message.models import ItemHash
+
+from aleph_vrf.models import Executor, Node, AlephExecutor, ComputeResourceNode
+from aleph_vrf.settings import settings
+
+
+class ExecutorSelectionPolicy(abc.ABC):
+    """Strategy interface: how the coordinator picks the executors of a VRF run."""
+
+    @abc.abstractmethod
+    async def select_executors(self, nb_executors: int) -> List[Executor]:
+        """Return the `nb_executors` executors chosen by this policy."""
+        ...
+
+
+async def _get_corechannel_aggregate() -> Dict[str, Any]:
+    """Fetch the corechannel aggregate from the API host defined in the settings.
+
+    Raises ValueError if the API does not answer with HTTP 200.
+    """
+    async with aiohttp.ClientSession(settings.API_HOST) as session:
+        url = (
+            f"/api/v0/aggregates/{settings.CORECHANNEL_AGGREGATE_ADDRESS}.json?"
+            f"keys={settings.CORECHANNEL_AGGREGATE_KEY}"
+        )
+        async with session.get(url) as response:
+            if response.status != 200:
+                raise ValueError("CRN list not available")
+
+            return await response.json()
+
+
+class ExecuteOnAleph(ExecutorSelectionPolicy):
+    """Select random executors among the eligible compute resource nodes."""
+
+    def __init__(self, vm_function: ItemHash):
+        # VM function handed to every AlephExecutor built by this policy.
+        self.vm_function = vm_function
+
+    @staticmethod
+    async def _list_compute_nodes() -> AsyncIterator[ComputeResourceNode]:
+        """Yield the linked compute nodes whose score is above 0.9.
+
+        Raises ValueError if the aggregate holds no resource node list.
+        """
+        content = await _get_corechannel_aggregate()
+
+        # Use .get() so that a missing key is reported as a format error too,
+        # instead of escaping as a KeyError.
+        corechannel = (content.get("data") or {}).get("corechannel") or {}
+        resource_nodes = corechannel.get("resource_nodes")
+        if not resource_nodes:
+            raise ValueError("Bad CRN list format")
+
+        for resource_node in resource_nodes:
+            # Filter nodes by score, with linked status
+            if resource_node["status"] == "linked" and resource_node["score"] > 0.9:
+                node_address = resource_node["address"].strip("/")
+                node = ComputeResourceNode(
+                    hash=resource_node["hash"],
+                    address=node_address,
+                    score=resource_node["score"],
+                )
+                yield node
+
+    @staticmethod
+    def _get_unauthorized_nodes() -> List[str]:
+        """Load unauthorized_node_list.json, stored next to this module.
+
+        Returns an empty list when the file does not exist.
+        """
+        unauthorized_nodes_list_path = Path(__file__).with_name(
+            "unauthorized_node_list.json"
+        )
+        if unauthorized_nodes_list_path.is_file():
+            with open(unauthorized_nodes_list_path, "rb") as fd:
+                return json.load(fd)
+
+        return []
+
+    async def select_executors(self, nb_executors: int) -> List[Executor]:
+        """Pick `nb_executors` random executors among the authorized nodes.
+
+        Raises ValueError if fewer than `nb_executors` nodes are eligible.
+        """
+        compute_nodes = self._list_compute_nodes()
+        # A set makes the membership test below O(1) per node.
+        blacklisted_nodes = frozenset(self._get_unauthorized_nodes())
+        whitelisted_nodes = (
+            node
+            async for node in compute_nodes
+            if node.address not in blacklisted_nodes
+        )
+        executors = [
+            AlephExecutor(node=node, vm_function=self.vm_function)
+            async for node in whitelisted_nodes
+        ]
+
+        if len(executors) < nb_executors:
+            raise ValueError(
+                f"Not enough CRNs linked, only {len(executors)} "
+                f"available from {nb_executors} requested"
+            )
+        return random.sample(executors, nb_executors)
diff --git a/src/aleph_vrf/coordinator/vrf.py b/src/aleph_vrf/coordinator/vrf.py
index 08f2500..373a35b 100644
--- a/src/aleph_vrf/coordinator/vrf.py
+++ b/src/aleph_vrf/coordinator/vrf.py
@@ -4,7 +4,7 @@
 import random
 from hashlib import sha3_256
 from pathlib import Path
-from typing import Any, Dict, List, Type, TypeVar, Union
+from typing import Dict, List, Type, TypeVar, Union
 from uuid import uuid4
 
 import aiohttp
@@ -15,9 +15,9 @@
 from pydantic import BaseModel
 from pydantic.json import pydantic_encoder
 
+from aleph_vrf.coordinator.executor_selection import ExecuteOnAleph
 from aleph_vrf.models import (
     CRNVRFResponse,
-    Node,
     VRFRequest,
     VRFResponse,
     PublishedVRFResponseHash,
@@ -59,73 +59,12 @@ async def post_node_vrf(url: str, model: Type[M]) -> M:
     return model.parse_obj(response["data"])
 
 
-async def _get_corechannel_aggregate() -> Dict[str, Any]:
-    async with aiohttp.ClientSession(settings.API_HOST) as session:
-        url = (
-            f"/api/v0/aggregates/{settings.CORECHANNEL_AGGREGATE_ADDRESS}.json?"
-            f"keys={settings.CORECHANNEL_AGGREGATE_KEY}"
-        )
-        async with session.get(url) as response:
-            if response.status != 200:
-                raise ValueError(f"CRN list not available")
-
-            return await response.json()
-
-
-def _get_unauthorized_node_list() -> List[str]:
-    unauthorized_nodes_list_path = Path(__file__).with_name(
-        "unauthorized_node_list.json"
-    )
-    if unauthorized_nodes_list_path.is_file():
-        with open(unauthorized_nodes_list_path, "rb") as fd:
-            return json.load(fd)
-
-    return []
-
-
-async def select_random_executors(
-    node_amount: int, unauthorized_nodes: List[str]
-) -> List[Executor]:
-    node_list: List[Executor] = []
-
-    content = await _get_corechannel_aggregate()
-
-    if (
-        not content["data"]["corechannel"]
-        or not content["data"]["corechannel"]["resource_nodes"]
-    ):
-        raise ValueError(f"Bad CRN list format")
-
-    resource_nodes = content["data"]["corechannel"]["resource_nodes"]
-
-    for resource_node in resource_nodes:
-        # Filter nodes by score, with linked status and remove unauthorized nodes
-        if (
-            resource_node["status"] == "linked"
-            and resource_node["score"] > 0.9
-            and resource_node["address"].strip("/") not in unauthorized_nodes
-        ):
-            node_address = resource_node["address"].strip("/")
-            node = ComputeResourceNode(
-                hash=resource_node["hash"],
-                address=node_address,
-                score=resource_node["score"],
-            )
-            node_list.append(AlephExecutor(node=node, vm_function=settings.FUNCTION))
-
-    if len(node_list) < node_amount:
-        raise ValueError(
-            f"Not enough CRNs linked, only {len(node_list)} available from {node_amount} requested"
-        )
-
-    # Randomize node order
-    return random.sample(node_list, min(node_amount, len(node_list)))
-
-
 async def generate_vrf(account: ETHAccount) -> VRFResponse:
     nb_executors = settings.NB_EXECUTORS
-    unauthorized_nodes = _get_unauthorized_node_list()
-    executors = await select_random_executors(nb_executors, unauthorized_nodes)
+    vm_function = settings.FUNCTION
+
+    executor_selection_policy = ExecuteOnAleph(vm_function=vm_function)
+    executors = await executor_selection_policy.select_executors(nb_executors)
     selected_nodes_json = json.dumps(
         [executor.node for executor in executors], default=pydantic_encoder
     ).encode(encoding="utf-8")
diff --git a/tests/conftest.py b/tests/conftest.py
index a5cb210..687b823 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -23,7 +23,7 @@
 from mock_ccn import app as mock_ccn_app
 
 
-def wait_for_server(host: str, port: int, nb_retries: int = 3, wait_time: int = 0.1):
+def wait_for_server(host: str, port: int, nb_retries: int = 10, wait_time: float = 0.1):
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.settimeout(5)
 
@@ -33,6 +33,9 @@ def wait_for_server(host: str, port: int, nb_retries: int = 3, wait_time: int =
             sock.connect((host, port))
         except ConnectionError:
             retries += 1
+            if retries == nb_retries:
+                raise
+
             sleep(wait_time)
             continue
 
diff --git a/tests/coordinator/test_vrf.py b/tests/coordinator/test_executor_selection.py
similarity index 86%
rename from tests/coordinator/test_vrf.py
rename to tests/coordinator/test_executor_selection.py
index 7cc0a4a..153b84b 100644
--- a/tests/coordinator/test_vrf.py
+++ b/tests/coordinator/test_executor_selection.py
@@ -1,8 +1,9 @@
 from typing import Any, Dict
 
 import pytest
+from aleph_message.models import ItemHash
 
-from aleph_vrf.coordinator.vrf import select_random_executors
+from aleph_vrf.coordinator.executor_selection import ExecuteOnAleph
 
 
 @pytest.fixture
@@ -136,21 +137,21 @@ def fixture_nodes_aggregate() -> Dict[str, Any]:
 @pytest.mark.asyncio
 async def test_select_random_nodes(fixture_nodes_aggregate: Dict[str, Any], mocker):
     network_fixture = mocker.patch(
-        "aleph_vrf.coordinator.vrf._get_corechannel_aggregate",
+        "aleph_vrf.coordinator.executor_selection._get_corechannel_aggregate",
         return_value=fixture_nodes_aggregate,
     )
+    executor_selection_policy = ExecuteOnAleph(vm_function=ItemHash("cafe" * 16))
 
-    executors = await select_random_executors(3, [])
+    executors = await executor_selection_policy.select_executors(3)
     # Sanity check, avoid network accesses
     network_fixture.assert_called_once()
 
     assert len(executors) == 3
 
+    resource_nodes = fixture_nodes_aggregate["data"]["corechannel"]["resource_nodes"]
     with pytest.raises(ValueError) as exception:
-        resource_nodes = fixture_nodes_aggregate["data"]["corechannel"][
-            "resource_nodes"
-        ]
-        await select_random_executors(len(resource_nodes), [])
+        await executor_selection_policy.select_executors(len(resource_nodes))
+
     assert (
         str(exception.value)
         == f"Not enough CRNs linked, only 4 available from 5 requested"
@@ -162,23 +163,30 @@ async def test_select_random_nodes_with_unauthorized(
     fixture_nodes_aggregate: Dict[str, Any], mocker
 ):
     network_fixture = mocker.patch(
-        "aleph_vrf.coordinator.vrf._get_corechannel_aggregate",
+        "aleph_vrf.coordinator.executor_selection._get_corechannel_aggregate",
         return_value=fixture_nodes_aggregate,
     )
-
+    blacklist = ["https://aleph2.serverrg.eu"]
+    executor_selection_policy = ExecuteOnAleph(vm_function=ItemHash("cafe" * 16))
+    mocker.patch.object(
+        executor_selection_policy, "_get_unauthorized_nodes", return_value=blacklist
+    )
+    executors = await executor_selection_policy.select_executors(3)
     # Sanity check, avoid network accesses
-    assert network_fixture.called_once
+    network_fixture.assert_called_once()
 
-    executors = await select_random_executors(3, ["https://aleph2.serverrg.eu"])
     assert len(executors) == 3
 
+    for blacklisted_node_address in blacklist:
+        assert blacklisted_node_address not in [
+            executor.node.address for executor in executors
+        ]
+
     with pytest.raises(ValueError) as exception:
         resource_nodes = fixture_nodes_aggregate["data"]["corechannel"][
             "resource_nodes"
         ]
-        _ = await select_random_executors(
-            len(resource_nodes) - 1, ["https://aleph2.serverrg.eu"]
-        )
+        _ = await executor_selection_policy.select_executors(len(resource_nodes) - 1)
     assert (
         str(exception.value)
         == f"Not enough CRNs linked, only 3 available from 4 requested"