Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
lferran committed Dec 1, 2023
1 parent af31163 commit 575e438
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,27 @@
make_kbid_request,
print_errors,
)
from nucliadb_performance.utils.saved_requests import load_all_saved_requests
from nucliadb_performance.utils.vectors import compute_vector

FIND_WEIGHT = 1
SEARCH_WEIGHT = 1
SUGGEST_WEIGHT = 1


def precompute_vectors():
"""
Precompute vectors for all saved requests at the beginning of the test.
"""
saved_requests = load_all_saved_requests(settings.saved_requests_file)
for request_set in saved_requests.sets.values():
for saved_request in request_set.requests:
if saved_request.request.payload is None:
continue
if "query" not in saved_request.request.payload:
continue
compute_vector(saved_request.request.payload["query"])


@cache
def get_test_kb():
Expand All @@ -27,9 +46,10 @@ def get_test_kb():
@global_setup()
def init_test(args):
get_test_kb()
precompute_vectors()


@scenario(weight=1)
@scenario(weight=FIND_WEIGHT)
async def test_find(session):
kbid, slug = get_test_kb()
request = get_request(slug, endpoint="find", with_tags=settings.request_tags)
Expand All @@ -41,20 +61,38 @@ async def test_find(session):
request.url.format(kbid=kbid),
json=request.payload,
)
# TODO: assert that there are results!
assert len(resp["resources"]) > 0


@scenario(weight=SEARCH_WEIGHT)
async def test_search(session):
kbid, slug = get_test_kb()
request = get_request(slug, endpoint="search", with_tags=settings.request_tags)
request.payload["vector"] = compute_vector(request.payload["query"])
resp = await make_kbid_request(
session,
kbid,
request.method.upper(),
request.url.format(kbid=kbid),
json=request.payload,
)
assert len(resp["resources"]) > 0


@scenario(weight=2)
@scenario(weight=SUGGEST_WEIGHT)
async def test_suggest(session):
kbid, slug = get_test_kb()
request = get_request(slug, endpoint="suggest")
await make_kbid_request(
resp = await make_kbid_request(
session,
kbid,
request.method.upper(),
request.url.format(kbid=kbid),
params=request.params,
)
assert (
len(resp["paragraphs"]["results"]) > 0 or len(resp["entities"]["entities"]) > 0
)


@global_teardown()
Expand Down
36 changes: 19 additions & 17 deletions nucliadb_performance/nucliadb_performance/utils/saved_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,22 @@ class SavedRequests(BaseModel):
def load_saved_request(
saved_requests_file: str, kbid_or_slug: str, endpoint: str, with_tags=None
) -> list[Request]:
try:
saved_requests = SavedRequests.parse_file(saved_requests_file)
kb_requests = []
for rs in saved_requests.sets.values():
if kbid_or_slug not in rs.kbs:
continue
kb_requests.extend([r for r in rs.requests if r.endpoint == endpoint])
if with_tags is None:
return [kb_req.request for kb_req in kb_requests]
else:
return [
kb_req.request
for kb_req in kb_requests
if set(with_tags).issubset(kb_req.tags)
]
except (FileNotFoundError, KeyError):
return []
saved_requests = load_all_saved_requests(saved_requests_file)
kb_requests = []
for rs in saved_requests.sets.values():
if kbid_or_slug not in rs.kbs:
continue
kb_requests.extend([r for r in rs.requests if r.endpoint == endpoint])
if with_tags is None:
return [kb_req.request for kb_req in kb_requests]
else:
return [
kb_req.request
for kb_req in kb_requests
if set(with_tags).issubset(kb_req.tags)
]


@cache
def load_all_saved_requests(saved_requests_file: str) -> SavedRequests:
return SavedRequests.parse_file(saved_requests_file)
Loading

0 comments on commit 575e438

Please sign in to comment.