Skip to content

Commit

Permalink
Allow running search cluster tests with nidx (#2627)
Browse files Browse the repository at this point in the history
  • Loading branch information
javitonino authored Nov 13, 2024
1 parent bf98cc8 commit 785e912
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 36 deletions.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion nidx/src/api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pub async fn run(settings: Settings) -> anyhow::Result<()> {
let meta = settings.metadata.clone();

let service = grpc::ApiServer::new(meta).into_service();
let server = GrpcServer::new("localhost:10000").await?;
let server = GrpcServer::new("0.0.0.0:10000").await?;
debug!("Running API at port {}", server.port()?);
server.serve(service).await?;

Expand Down
3 changes: 1 addition & 2 deletions nidx/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@
//

use nidx::{api, indexer, maintenance, searcher, Settings};
use std::collections::HashSet;
use tokio::{main, task::JoinSet};

#[main]
async fn main() -> anyhow::Result<()> {
let args: HashSet<_> = std::env::args().skip(1).collect();
let args: Vec<_> = std::env::args().skip(1).collect();

tracing_subscriber::fmt::init();

Expand Down
2 changes: 2 additions & 0 deletions nidx/src/metadata/shard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ impl Shard {
r#"SELECT kind as "kind: IndexKind", SUM(records)::bigint as "records!" FROM indexes
JOIN segments ON index_id = indexes.id
WHERE shard_id = $1
AND indexes.deleted_at IS NULL
AND segments.delete_at IS NULL
GROUP BY kind"#,
self.id
)
Expand Down
2 changes: 1 addition & 1 deletion nidx/src/searcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub async fn run(settings: Settings) -> anyhow::Result<()> {
let searcher = SyncedSearcher::new(meta.clone(), work_dir.path());

let api = grpc::SearchServer::new(meta.clone(), searcher.index_cache());
let server = GrpcServer::new("localhost:10001").await?;
let server = GrpcServer::new("0.0.0.0:10001").await?;
let api_task = tokio::task::spawn(server.serve(api.into_service()));
let search_task = tokio::task::spawn(async move { searcher.run(storage).await });

Expand Down
2 changes: 2 additions & 0 deletions nucliadb/src/nucliadb/search/lifecycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from nucliadb.common.cluster.utils import setup_cluster, teardown_cluster
from nucliadb.common.maindb.utils import setup_driver
from nucliadb.common.nidx import start_nidx_utility
from nucliadb.ingest.utils import start_ingest, stop_ingest
from nucliadb.search import SERVICE_NAME
from nucliadb.search.predict import start_predict_engine
Expand All @@ -46,6 +47,7 @@ async def lifespan(app: FastAPI):

await setup_driver()
await setup_cluster()
await start_nidx_utility()

await start_audit_utility(SERVICE_NAME)

Expand Down
50 changes: 32 additions & 18 deletions nucliadb/tests/search/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

import asyncio
import os
from enum import Enum
from typing import AsyncIterable, Optional

Expand All @@ -27,6 +28,7 @@

from nucliadb.common.cluster.manager import KBShardManager, get_index_node
from nucliadb.common.maindb.utils import get_driver
from nucliadb.common.nidx import get_nidx_api_client
from nucliadb.ingest.cache import clear_ingest_cache
from nucliadb.search import API_PREFIX
from nucliadb.search.predict import DummyPredictEngine
Expand Down Expand Up @@ -203,25 +205,37 @@ async def wait_for_shard(knowledgebox_ingest: str, count: int) -> str:
await txn.abort()

checks: dict[str, bool] = {}
for replica in shard.replicas:
if replica.shard.id not in checks:
checks[replica.shard.id] = False

for i in range(30):
if os.environ.get("NIDX_ENABLED"):
checks[""] = False
nidx_api = get_nidx_api_client()
req = GetShardRequest()
req.shard_id.id = shard.shard
for i in range(30):
count_shard: Shard = await nidx_api.GetShard(req) # type: ignore
if count_shard.fields >= count:
checks[""] = True
break
await asyncio.sleep(1)
else:
for replica in shard.replicas:
node_obj = get_index_node(replica.node)
if node_obj is not None:
req = GetShardRequest()
req.shard_id.id = replica.shard.id
count_shard: Shard = await node_obj.reader.GetShard(req) # type: ignore
if count_shard.fields >= count:
checks[replica.shard.id] = True
else:
checks[replica.shard.id] = False

if all(checks.values()):
break
await asyncio.sleep(1)
if replica.shard.id not in checks:
checks[replica.shard.id] = False

for i in range(30):
for replica in shard.replicas:
node_obj = get_index_node(replica.node)
if node_obj is not None:
req = GetShardRequest()
req.shard_id.id = replica.shard.id
count_shard: Shard = await node_obj.reader.GetShard(req) # type: ignore
if count_shard.fields >= count:
checks[replica.shard.id] = True
else:
checks[replica.shard.id] = False

if all(checks.values()):
break
await asyncio.sleep(1)

assert all(checks.values())
return knowledgebox_ingest
3 changes: 0 additions & 3 deletions nucliadb/tests/search/integration/api/v1/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
RUNNING_IN_GH_ACTIONS = os.environ.get("CI", "").lower() == "true"


@pytest.mark.flaky(reruns=5)
@pytest.mark.asyncio
async def test_multiple_fuzzy_search_resource_all(
search_api: Callable[..., AsyncClient], multiple_search_resource: str
Expand All @@ -65,7 +64,6 @@ async def test_multiple_fuzzy_search_resource_all(
)


@pytest.mark.flaky(reruns=5)
@pytest.mark.asyncio
async def test_multiple_search_resource_all(
search_api: Callable[..., AsyncClient], multiple_search_resource: str
Expand Down Expand Up @@ -121,7 +119,6 @@ async def test_multiple_search_resource_all(


@pytest.mark.asyncio
@pytest.mark.flaky(reruns=3)
async def test_search_resource_all(
search_api: Callable[..., AsyncClient],
test_search_resource: str,
Expand Down
115 changes: 106 additions & 9 deletions nucliadb/tests/search/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

import dataclasses
import logging
import os
Expand All @@ -25,13 +26,14 @@

import backoff
import docker # type: ignore
import nats
import pytest
from grpc import insecure_channel
from grpc_health.v1 import health_pb2_grpc
from grpc_health.v1.health_pb2 import HealthCheckRequest
from nats.js.api import ConsumerConfig
from pytest_docker_fixtures import images # type: ignore
from pytest_docker_fixtures.containers._base import BaseImage # type: ignore
from pytest_lazy_fixtures import lazy_fixture

from nucliadb.common.cluster.settings import settings as cluster_settings
from nucliadb_protos.nodewriter_pb2 import EmptyQuery, ShardId
Expand Down Expand Up @@ -110,6 +112,29 @@
},
}

images.settings["nidx"] = {
"image": "nidx",
"version": "latest",
"env": {},
"options": {
# A few indexers on purpose for faster indexing
"command": [
"nidx",
"api",
"searcher",
"indexer",
"indexer",
"indexer",
"indexer",
"scheduler",
"worker",
],
"ports": {"10000": ("0.0.0.0", 0), "10001": ("0.0.0.0", 0)},
"publish_all_ports": False,
"platform": "linux/amd64",
},
}


def get_container_host(container_obj):
return container_obj.attrs["NetworkSettings"]["IPAddress"]
Expand Down Expand Up @@ -190,6 +215,10 @@ def check(self):
return False


class NidxImage(BaseImage):
name = "nidx"


nucliadb_node_1_reader = nucliadbNodeReader()
nucliadb_node_1_writer = nucliadbNodeWriter()
nucliadb_node_1_sidecar = nucliadbNodeSidecar()
Expand Down Expand Up @@ -377,20 +406,49 @@ def s3_node_storage(s3):
return NodeS3Storage(server=s3)


def lazy_load_storage_backend():
@pytest.fixture(scope="session")
def node_storage(request):
backend = get_testing_storage_backend()
if backend == "gcs":
return [lazy_fixture.lf("gcs_node_storage")]
return request.getfixturevalue("gcs_node_storage")
elif backend == "s3":
return [lazy_fixture.lf("s3_node_storage")]
return request.getfixturevalue("s3_node_storage")
else:
print(f"Unknown storage backend {backend}, using gcs")
return [lazy_fixture.lf("gcs_node_storage")]
return request.getfixturevalue("gcs_node_storage")


@pytest.fixture(scope="session", params=lazy_load_storage_backend())
def node_storage(request):
return request.param
@pytest.fixture(scope="session")
def gcs_nidx_storage(gcs):
return {
"INDEXER__OBJECT_STORE": "gcs",
"INDEXER__BUCKET": "indexing",
"INDEXER__ENDPOINT": gcs,
"STORAGE__OBJECT_STORE": "gcs",
"STORAGE__ENDPOINT": gcs,
"STORAGE__BUCKET": "nidx",
}


@pytest.fixture(scope="session")
def s3_nidx_storage(s3):
return {
"INDEXER__OBJECT_STORE": "s3",
"INDEXER__BUCKET": "indexing",
"INDEXER__ENDPOINT": s3,
"STORAGE__OBJECT_STORE": "s3",
"STORAGE__ENDPOINT": s3,
"STORAGE__BUCKET": "nidx",
}


@pytest.fixture(scope="session")
def nidx_storage(request):
backend = get_testing_storage_backend()
if backend == "gcs":
return request.getfixturevalue("gcs_nidx_storage")
elif backend == "s3":
return request.getfixturevalue("s3_nidx_storage")


@pytest.fixture(scope="session", autouse=False)
Expand All @@ -406,8 +464,47 @@ def _node(natsd: str, node_storage):
nr.stop()


@pytest.fixture(scope="session")
async def _nidx(natsd, nidx_storage, pg):
if not os.environ.get("NIDX_ENABLED"):
yield
return

# Create needed NATS stream/consumer
nc = await nats.connect(servers=[natsd])
js = nc.jetstream()
await js.add_stream(name="nidx", subjects=["nidx"])
await js.add_consumer(stream="nidx", config=ConsumerConfig(name="nidx"))
await nc.drain()
await nc.close()

# Run nidx
images.settings["nidx"]["env"] = {
"RUST_LOG": "info",
"METADATA__DATABASE_URL": f"postgresql://postgres:[email protected]:{pg[1]}/postgres",
"INDEXER__NATS_SERVER": natsd.replace("localhost", "172.17.0.1"),
**nidx_storage,
}
image = NidxImage()
image.run()

api_port = image.get_port(10000)
searcher_port = image.get_port(10001)

# Configure settings
from nucliadb_utils.settings import indexing_settings

cluster_settings.nidx_api_address = f"localhost:{api_port}"
cluster_settings.nidx_searcher_address = f"localhost:{searcher_port}"
indexing_settings.index_nidx_subject = "nidx"

yield

image.stop()


@pytest.fixture(scope="function")
def node(_node, request):
def node(_nidx, _node, request):
# clean up all shard data before each test
channel1 = insecure_channel(f"{_node['writer1']['host']}:{_node['writer1']['port']}")
channel2 = insecure_channel(f"{_node['writer2']['host']}:{_node['writer2']['port']}")
Expand Down

0 comments on commit 785e912

Please sign in to comment.