Skip to content

Commit

Permalink
Internal: executor selection policy system
Browse files Browse the repository at this point in the history
Problem: for testing, we wish to switch between different ways of
selecting executors (e.g. repeat tests on a specific list of CRNs, or use
local executors for integration tests).

Solution: introduce the `ExecutorSelectionPolicy` class. Using this
class, the caller can parameterize the way executors are selected,
blacklisted, etc.
  • Loading branch information
odesenfans committed Sep 25, 2023
1 parent bb91c12 commit dd9d264
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 82 deletions.
89 changes: 89 additions & 0 deletions src/aleph_vrf/coordinator/executor_selection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import abc
import json
from pathlib import Path
from typing import List, Dict, Any, AsyncIterator
import random

import aiohttp
from aleph_message.models import ItemHash

from aleph_vrf.models import Executor, Node, AlephExecutor, ComputeResourceNode
from aleph_vrf.settings import settings


class ExecutorSelectionPolicy(abc.ABC):
    """Abstract strategy that decides which executors are selected.

    Subclasses implement :meth:`select_executors`; code that only depends on
    this interface can be handed any concrete policy.
    """

    @abc.abstractmethod
    async def select_executors(self, nb_executors: int) -> List[Executor]:
        """Return the executors chosen by this policy.

        :param nb_executors: Number of executors requested by the caller.
        :return: The selected executors.
        """


async def _get_corechannel_aggregate() -> Dict[str, Any]:
    """Fetch the corechannel aggregate from the configured API host.

    The aggregate of ``settings.CORECHANNEL_AGGREGATE_ADDRESS`` is requested,
    restricted to the ``settings.CORECHANNEL_AGGREGATE_KEY`` key.

    :return: The decoded JSON body of the response.
    :raises ValueError: If the API answers with a status other than 200.
    """
    aggregate_path = (
        f"/api/v0/aggregates/{settings.CORECHANNEL_AGGREGATE_ADDRESS}.json?"
        f"keys={settings.CORECHANNEL_AGGREGATE_KEY}"
    )
    async with aiohttp.ClientSession(settings.API_HOST) as client:
        async with client.get(aggregate_path) as reply:
            # Success path first; anything else is reported as unavailable.
            if reply.status == 200:
                return await reply.json()

            raise ValueError(f"CRN list not available")


class ExecuteOnAleph(ExecutorSelectionPolicy):
    """Select executors among the compute resource nodes (CRNs).

    Candidate nodes are read from the corechannel aggregate, filtered on
    their status and score, stripped of the locally blacklisted addresses and
    finally sampled at random.
    """

    def __init__(self, vm_function: ItemHash):
        """Create the policy.

        :param vm_function: Item hash handed to every ``AlephExecutor``
            created by this policy.
        """
        self.vm_function = vm_function

    @staticmethod
    async def _list_compute_nodes() -> AsyncIterator[ComputeResourceNode]:
        """Yield the compute resource nodes eligible for selection.

        A node is eligible when its status is ``"linked"`` and its score is
        strictly greater than 0.9. Trailing/leading slashes are stripped from
        the node address.

        :raises ValueError: If the aggregate has no usable
            ``data -> corechannel -> resource_nodes`` entry.
        """
        content = await _get_corechannel_aggregate()

        # Walk the nested structure with ``.get`` so that a missing level is
        # reported as a format error instead of leaking a bare ``KeyError``.
        corechannel = (content.get("data") or {}).get("corechannel") or {}
        resource_nodes = corechannel.get("resource_nodes")
        if not resource_nodes:
            raise ValueError("Bad CRN list format")

        for resource_node in resource_nodes:
            # Filter nodes by score, with linked status
            if resource_node["status"] != "linked":
                continue
            if not resource_node["score"] > 0.9:
                continue

            yield ComputeResourceNode(
                hash=resource_node["hash"],
                address=resource_node["address"].strip("/"),
                score=resource_node["score"],
            )

    @staticmethod
    def _get_unauthorized_nodes() -> List[str]:
        """Load the blacklist stored next to this module.

        :return: The parsed content of ``unauthorized_node_list.json`` located
            in the same directory as this file, or an empty list when that
            file does not exist.
        """
        unauthorized_nodes_list_path = Path(__file__).with_name(
            "unauthorized_node_list.json"
        )
        if not unauthorized_nodes_list_path.is_file():
            return []

        with open(unauthorized_nodes_list_path, "rb") as fd:
            return json.load(fd)

    async def select_executors(self, nb_executors: int) -> List[Executor]:
        """Pick ``nb_executors`` random executors among the eligible nodes.

        :param nb_executors: Number of executors to return.
        :return: ``nb_executors`` distinct executors, in random order.
        :raises ValueError: If fewer than ``nb_executors`` nodes remain after
            filtering, or if the CRN list cannot be retrieved or parsed.
        """
        # A set gives constant-time membership tests in the filter below.
        blacklisted_nodes = set(self._get_unauthorized_nodes())
        executors = [
            AlephExecutor(node=node, vm_function=self.vm_function)
            async for node in self._list_compute_nodes()
            if node.address not in blacklisted_nodes
        ]

        if len(executors) < nb_executors:
            raise ValueError(
                f"Not enough CRNs linked, only {len(executors)} "
                f"available from {nb_executors} requested"
            )
        return random.sample(executors, nb_executors)
73 changes: 6 additions & 67 deletions src/aleph_vrf/coordinator/vrf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import random
from hashlib import sha3_256
from pathlib import Path
from typing import Any, Dict, List, Type, TypeVar, Union
from typing import Dict, List, Type, TypeVar, Union
from uuid import uuid4

import aiohttp
Expand All @@ -15,9 +15,9 @@
from pydantic import BaseModel
from pydantic.json import pydantic_encoder

from aleph_vrf.coordinator.executor_selection import ExecuteOnAleph
from aleph_vrf.models import (
CRNVRFResponse,
Node,
VRFRequest,
VRFResponse,
PublishedVRFResponseHash,
Expand Down Expand Up @@ -59,73 +59,12 @@ async def post_node_vrf(url: str, model: Type[M]) -> M:
return model.parse_obj(response["data"])


async def _get_corechannel_aggregate() -> Dict[str, Any]:
    """Download the corechannel aggregate from ``settings.API_HOST``.

    :return: The decoded JSON body of the response.
    :raises ValueError: If the response status is not 200.
    """
    query = (
        f"/api/v0/aggregates/{settings.CORECHANNEL_AGGREGATE_ADDRESS}.json?"
        f"keys={settings.CORECHANNEL_AGGREGATE_KEY}"
    )
    async with aiohttp.ClientSession(settings.API_HOST) as http:
        async with http.get(query) as answer:
            status_ok = answer.status == 200
            if not status_ok:
                raise ValueError(f"CRN list not available")

            payload = await answer.json()
            return payload


def _get_unauthorized_node_list() -> List[str]:
unauthorized_nodes_list_path = Path(__file__).with_name(
"unauthorized_node_list.json"
)
if unauthorized_nodes_list_path.is_file():
with open(unauthorized_nodes_list_path, "rb") as fd:
return json.load(fd)

return []


async def select_random_executors(
    node_amount: int, unauthorized_nodes: List[str]
) -> List[Executor]:
    """Pick ``node_amount`` random executors among the eligible CRNs.

    A resource node is eligible when its status is ``"linked"``, its score is
    strictly greater than 0.9 and its address (slashes stripped) is not in
    ``unauthorized_nodes``.

    :param node_amount: Number of executors to return.
    :param unauthorized_nodes: Node addresses that must not be selected.
    :return: ``node_amount`` distinct executors, in random order.
    :raises ValueError: If the aggregate has an empty ``corechannel`` or
        ``resource_nodes`` entry, or if fewer than ``node_amount`` nodes are
        eligible.
    """
    aggregate = await _get_corechannel_aggregate()

    corechannel = aggregate["data"]["corechannel"]
    if not corechannel or not corechannel["resource_nodes"]:
        raise ValueError(f"Bad CRN list format")

    candidates: List[Executor] = []
    for entry in corechannel["resource_nodes"]:
        # Filter nodes by score, with linked status and remove unauthorized nodes
        if entry["status"] != "linked":
            continue
        if not entry["score"] > 0.9:
            continue

        address = entry["address"].strip("/")
        if address in unauthorized_nodes:
            continue

        crn = ComputeResourceNode(
            hash=entry["hash"],
            address=address,
            score=entry["score"],
        )
        candidates.append(AlephExecutor(node=crn, vm_function=settings.FUNCTION))

    available = len(candidates)
    if available < node_amount:
        raise ValueError(
            f"Not enough CRNs linked, only {len(candidates)} available from {node_amount} requested"
        )

    # Randomize node order; the check above guarantees enough candidates.
    return random.sample(candidates, node_amount)


async def generate_vrf(account: ETHAccount) -> VRFResponse:
nb_executors = settings.NB_EXECUTORS
unauthorized_nodes = _get_unauthorized_node_list()
executors = await select_random_executors(nb_executors, unauthorized_nodes)
vm_function = settings.FUNCTION

executor_selection_policy = ExecuteOnAleph(vm_function=vm_function)
executors = await executor_selection_policy.select_executors(nb_executors)
selected_nodes_json = json.dumps(
[executor.node for executor in executors], default=pydantic_encoder
).encode(encoding="utf-8")
Expand Down
5 changes: 4 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from mock_ccn import app as mock_ccn_app


def wait_for_server(host: str, port: int, nb_retries: int = 3, wait_time: int = 0.1):
def wait_for_server(host: str, port: int, nb_retries: int = 10, wait_time: int = 0.1):
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.settimeout(5)

Expand All @@ -33,6 +33,9 @@ def wait_for_server(host: str, port: int, nb_retries: int = 3, wait_time: int =
sock.connect((host, port))
except ConnectionError:
retries += 1
if retries == nb_retries:
raise

sleep(wait_time)
continue

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Any, Dict

import pytest
from aleph_message.models import ItemHash

from aleph_vrf.coordinator.vrf import select_random_executors
from aleph_vrf.coordinator.executor_selection import ExecuteOnAleph


@pytest.fixture
Expand Down Expand Up @@ -136,21 +137,21 @@ def fixture_nodes_aggregate() -> Dict[str, Any]:
@pytest.mark.asyncio
async def test_select_random_nodes(fixture_nodes_aggregate: Dict[str, Any], mocker):
network_fixture = mocker.patch(
"aleph_vrf.coordinator.vrf._get_corechannel_aggregate",
"aleph_vrf.coordinator.executor_selection._get_corechannel_aggregate",
return_value=fixture_nodes_aggregate,
)
executor_selection_policy = ExecuteOnAleph(vm_function=ItemHash("cafe" * 16))

executors = await select_random_executors(3, [])
executors = await executor_selection_policy.select_executors(3)
# Sanity check, avoid network accesses
network_fixture.assert_called_once()

assert len(executors) == 3

resource_nodes = fixture_nodes_aggregate["data"]["corechannel"]["resource_nodes"]
with pytest.raises(ValueError) as exception:
resource_nodes = fixture_nodes_aggregate["data"]["corechannel"][
"resource_nodes"
]
await select_random_executors(len(resource_nodes), [])
await executor_selection_policy.select_executors(len(resource_nodes))

assert (
str(exception.value)
== f"Not enough CRNs linked, only 4 available from 5 requested"
Expand All @@ -162,23 +163,30 @@ async def test_select_random_nodes_with_unauthorized(
fixture_nodes_aggregate: Dict[str, Any], mocker
):
network_fixture = mocker.patch(
"aleph_vrf.coordinator.vrf._get_corechannel_aggregate",
"aleph_vrf.coordinator.executor_selection._get_corechannel_aggregate",
return_value=fixture_nodes_aggregate,
)

blacklist = ["https://aleph2.serverrg.eu"]
executor_selection_policy = ExecuteOnAleph(vm_function=ItemHash("cafe" * 16))
mocker.patch.object(
executor_selection_policy, "_get_unauthorized_nodes", return_value=blacklist
)
executors = await executor_selection_policy.select_executors(3)
# Sanity check, avoid network accesses
assert network_fixture.called_once
network_fixture.assert_called_once()

executors = await select_random_executors(3, ["https://aleph2.serverrg.eu"])
assert len(executors) == 3

for blacklisted_node_address in blacklist:
assert blacklisted_node_address not in [
executor.node.address for executor in executors
]

with pytest.raises(ValueError) as exception:
resource_nodes = fixture_nodes_aggregate["data"]["corechannel"][
"resource_nodes"
]
_ = await select_random_executors(
len(resource_nodes) - 1, ["https://aleph2.serverrg.eu"]
)
_ = await executor_selection_policy.select_executors(len(resource_nodes) - 1)
assert (
str(exception.value)
== f"Not enough CRNs linked, only 3 available from 4 requested"
Expand Down

0 comments on commit dd9d264

Please sign in to comment.