Skip to content

Commit

Permalink
Return NER family in entity suggestions (#1729)
Browse files Browse the repository at this point in the history
Return NER family in entity suggestions
  • Loading branch information
javitonino authored Jan 16, 2024
1 parent d487a11 commit 1330a47
Show file tree
Hide file tree
Showing 24 changed files with 925 additions and 909 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ target
.egg-info
*.so
venv/
.venv/
nucliadb/.mypy_cache
lib
share
Expand Down
15 changes: 9 additions & 6 deletions nucliadb/nucliadb/search/search/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#
import datetime
import math
from typing import Any, Optional, Union
from typing import Any, Optional, Set, Union

from nucliadb_protos.nodereader_pb2 import (
DocumentResult,
Expand Down Expand Up @@ -53,6 +53,7 @@
Paragraph,
Paragraphs,
RelatedEntities,
RelatedEntity,
RelationDirection,
RelationNodeTypeMap,
Relations,
Expand Down Expand Up @@ -594,13 +595,15 @@ async def merge_paragraphs_results(
async def merge_suggest_entities_results(
    suggest_responses: list[SuggestResponse],
) -> RelatedEntities:
    """Merge the entity suggestions of several suggest responses into one.

    Every node found in ``response.entity_results.nodes`` is converted into a
    ``RelatedEntity`` (``family`` taken from the node's ``subtype``, ``value``
    from its ``value``). Entities appearing in more than one response, or
    repeated inside a single response, are reported only once.

    Args:
        suggest_responses: Suggest responses to merge.

    Returns:
        A ``RelatedEntities`` holding the unique entities in first-seen order,
        with ``total`` set to the number of unique entities.
    """
    # RelatedEntity is a frozen model, so it can be used as a dict key. A dict
    # (rather than a set) de-duplicates while keeping insertion order, so the
    # order in which the nodes were returned survives the merge.
    # NOTE(review): the node-side tests assert on result position, which
    # suggests that order is meaningful -- confirm before relying on it.
    unique_entities: dict[RelatedEntity, None] = {}
    for response in suggest_responses:
        for node in response.entity_results.nodes:
            entity = RelatedEntity(family=node.subtype, value=node.value)
            unique_entities.setdefault(entity, None)

    entities = list(unique_entities)
    return RelatedEntities(entities=entities, total=len(entities))


async def merge_suggest_results(
Expand Down
2 changes: 1 addition & 1 deletion nucliadb/nucliadb/tests/integration/test_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ async def test_entities_indexing(
assert resp.status_code == 200
body = resp.json()

entities = set(body["entities"]["entities"])
entities = set((e["value"] for e in body["entities"]["entities"]))
# BUG? why is "domestic cat" not appearing in the results?
assert entities == {"dog", "dolphin"}
# assert entities == {"dog", "domestic cat", "dolphin"}
23 changes: 13 additions & 10 deletions nucliadb/nucliadb/tests/integration/test_suggest.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,16 +212,19 @@ async def test_suggest_related_entities(
)
assert resp.status_code == 201

def assert_expected_entities(body, expected):
    """Assert that the suggested entity values in ``body`` equal ``expected``.

    Args:
        body: Decoded JSON suggest response; its ``["entities"]["entities"]``
            list holds objects with a ``"value"`` key.
        expected: Set of entity values that must match exactly (order and
            duplicates are ignored).

    Raises:
        AssertionError: If the returned values differ from ``expected``.
    """
    # Only the value is compared; the entity family is ignored here.
    values = {entity["value"] for entity in body["entities"]["entities"]}
    assert values == expected

# Test simple suggestions
resp = await nucliadb_reader.get(f"/kb/{knowledgebox}/suggest?query=Ann")
assert resp.status_code == 200
body = resp.json()
assert set(body["entities"]["entities"]) == {"Anna", "Anthony"}
assert_expected_entities(body, {"Anna", "Anthony"})

resp = await nucliadb_reader.get(f"/kb/{knowledgebox}/suggest?query=joh")
assert resp.status_code == 200
body = resp.json()
assert set(body["entities"]["entities"]) == {"John"}
assert_expected_entities(body, {"John"})

resp = await nucliadb_reader.get(f"/kb/{knowledgebox}/suggest?query=xxxxx")
assert resp.status_code == 200
Expand All @@ -233,33 +236,33 @@ async def test_suggest_related_entities(
resp = await nucliadb_reader.get(f"/kb/{knowledgebox}/suggest?query=bar")
assert resp.status_code == 200
body = resp.json()
assert set(body["entities"]["entities"]) == {"Barcelona", "Bárcenas"}
assert_expected_entities(body, {"Barcelona", "Bárcenas"})

resp = await nucliadb_reader.get(f"/kb/{knowledgebox}/suggest?query=Bar")
assert resp.status_code == 200
body = resp.json()
assert set(body["entities"]["entities"]) == {"Barcelona", "Bárcenas"}
assert_expected_entities(body, {"Barcelona", "Bárcenas"})

resp = await nucliadb_reader.get(f"/kb/{knowledgebox}/suggest?query=BAR")
assert resp.status_code == 200
body = resp.json()
assert set(body["entities"]["entities"]) == {"Barcelona", "Bárcenas"}
assert_expected_entities(body, {"Barcelona", "Bárcenas"})

resp = await nucliadb_reader.get(f"/kb/{knowledgebox}/suggest?query=BÄR")
assert resp.status_code == 200
body = resp.json()
assert set(body["entities"]["entities"]) == {"Barcelona", "Bárcenas"}
assert_expected_entities(body, {"Barcelona", "Bárcenas"})

resp = await nucliadb_reader.get(f"/kb/{knowledgebox}/suggest?query=BáR")
assert resp.status_code == 200
body = resp.json()
assert set(body["entities"]["entities"]) == {"Barcelona", "Bárcenas"}
assert_expected_entities(body, {"Barcelona", "Bárcenas"})

# Test multiple word suggest and ordering
resp = await nucliadb_reader.get(f"/kb/{knowledgebox}/suggest?query=Solomon+Is")
assert resp.status_code == 200
body = resp.json()
assert set(body["entities"]["entities"]) == {"Solomon Islands", "Israel"}
assert_expected_entities(body, {"Solomon Islands", "Israel"})


@pytest.mark.asyncio
Expand Down Expand Up @@ -350,8 +353,8 @@ def assert_expected_paragraphs(response):

def assert_expected_entities(response):
    """Assert that ``response`` suggests exactly the entities Anna and Anthony.

    Args:
        response: Decoded JSON suggest response. ``response["entities"]`` is
            an object with a ``"total"`` count and an ``"entities"`` list whose
            items carry a ``"value"`` key.

    Raises:
        AssertionError: If the reported total or the set of entity values does
            not match the expected ones.
    """
    expected = {"Anna", "Anthony"}
    related = response["entities"]
    # Check the reported count. ``len(response["entities"])`` would count the
    # keys of the wrapper object ("total" and "entities"), not the suggestions.
    assert related["total"] == len(expected)
    assert {entity["value"] for entity in related["entities"]} == expected

resp = await nucliadb_reader.get(
f"/kb/{knowledgebox}/suggest",
Expand Down
7 changes: 6 additions & 1 deletion nucliadb_models/nucliadb_models/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,14 @@ class Relations(BaseModel):
# graph: List[RelationPath]


class RelatedEntity(BaseModel, frozen=True):
    """A single suggested entity: its family and its value.

    The model is declared frozen so instances can be hashed and therefore
    de-duplicated (e.g. stored in a set or used as dict keys).
    """

    # Family the entity belongs to.
    family: str
    # Text of the entity itself.
    value: str


class RelatedEntities(BaseModel):
    """Entity suggestions returned for a query."""

    # Number of entities reported.
    total: int = 0
    # Suggested entities, each carrying its family and value.
    entities: List[RelatedEntity] = []


class ResourceSearchResults(JsonBaseModel):
Expand Down
17 changes: 7 additions & 10 deletions nucliadb_node/src/shards/shard_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use nucliadb_core::protos::shard_created::{
};
use nucliadb_core::protos::{
DocumentSearchRequest, DocumentSearchResponse, EdgeList, GetShardRequest,
ParagraphSearchRequest, ParagraphSearchResponse, RelatedEntities, RelationPrefixSearchRequest,
ParagraphSearchRequest, ParagraphSearchResponse, RelationPrefixSearchRequest,
RelationSearchRequest, RelationSearchResponse, SearchRequest, SearchResponse, Shard, ShardFile,
ShardFileChunk, ShardFileList, StreamRequest, SuggestFeatures, SuggestRequest, SuggestResponse,
TypeList, VectorSearchRequest, VectorSearchResponse,
Expand All @@ -38,8 +38,9 @@ use nucliadb_core::query_planner::QueryPlan;
use nucliadb_core::thread::*;
use nucliadb_core::tracing::{self, *};
use nucliadb_procs::measure;
use nucliadb_protos::nodereader::RelationNodeFilter;
use nucliadb_protos::nodereader::{RelationNodeFilter, RelationPrefixSearchResponse};
use nucliadb_protos::utils::relation_node::NodeType;
use nucliadb_relations2::reader::HashedRelationNode;

use crate::disk_structure::*;
use crate::shards::metadata::ShardMetadata;
Expand Down Expand Up @@ -362,13 +363,12 @@ impl ShardReader {
.into_iter()
.flatten() // unwrap errors and continue with successful results
.flat_map(|response| response.prefix)
.flat_map(|prefix_response| prefix_response.nodes.into_iter())
.map(|node| node.value);
.flat_map(|prefix_response| prefix_response.nodes.into_iter());

// remove duplicate entities
let mut seen = HashSet::new();
let mut seen: HashSet<HashedRelationNode> = HashSet::new();
let mut ent_result = entities.collect::<Vec<_>>();
ent_result.retain(|e| seen.insert(e.clone()));
ent_result.retain(|e| seen.insert(e.clone().into()));

ent_result
};
Expand Down Expand Up @@ -403,10 +403,7 @@ impl ShardReader {
};

if let Some(entities) = entities {
response.entities = Some(RelatedEntities {
total: entities.len() as u32,
entities,
})
response.entity_results = Some(RelationPrefixSearchResponse { nodes: entities });
}

Ok(response)
Expand Down
23 changes: 12 additions & 11 deletions nucliadb_node/tests/test_suggest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,10 +203,10 @@ async fn test_suggest_entities(

// multiple words and result ordering
let response = suggest_entities(&mut reader, &shard.id, "Solomon Isa").await;
assert!(response.entities.is_some());
assert_eq!(response.entities.as_ref().unwrap().total, 2);
assert!(response.entities.as_ref().unwrap().entities[0] == *"Solomon Islands");
assert!(response.entities.as_ref().unwrap().entities[1] == *"Israel");
assert!(response.entity_results.is_some());
assert_eq!(response.entity_results.as_ref().unwrap().nodes.len(), 2);
assert!(response.entity_results.as_ref().unwrap().nodes[0].value == *"Solomon Islands");
assert!(response.entity_results.as_ref().unwrap().nodes[1].value == *"Israel");

// Does not find resources by UUID prefix
let pap_uuid = &shard.resources["pap"];
Expand Down Expand Up @@ -238,7 +238,7 @@ async fn test_suggest_features(
let mut reader = node.reader_client();

let response = suggest_paragraphs(&mut reader, &shard.id, "ann").await;
assert!(response.entities.is_none());
assert!(response.entity_results.is_none());
expect_paragraphs(
&response,
&[(&shard.resources["little prince"], "/a/summary")],
Expand Down Expand Up @@ -358,20 +358,21 @@ async fn suggest_entities(
}

fn expect_entities(response: &SuggestResponse, expected: &[&str]) {
assert!(response.entities.is_some());
assert!(response.entity_results.is_some());
assert_eq!(
response.entities.as_ref().unwrap().total as usize,
response.entity_results.as_ref().unwrap().nodes.len(),
expected.len(),
"Response entities don't match expected ones: {:?} != {:?}",
response.entities,
response.entity_results,
expected,
);
for entity in expected {
assert!(response
.entities
.entity_results
.as_ref()
.unwrap()
.entities
.contains(&entity.to_string()));
.nodes
.iter()
.any(|e| &e.value == entity));
}
}
2 changes: 1 addition & 1 deletion nucliadb_protos/nodereader.proto
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,7 @@ message SuggestResponse {
repeated string ematches = 4;

// Entities related with the query
RelatedEntities entities = 5;
RelationPrefixSearchResponse entity_results = 6;
}

message SearchResponse {
Expand Down
40 changes: 20 additions & 20 deletions nucliadb_protos/python/nucliadb_protos/audit_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

3 comments on commit 1330a47

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 1330a47 Previous: 5a633b0 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 12827.701189266023 iter/sec (stddev: 8.198057301654928e-7) 12745.686329086004 iter/sec (stddev: 1.7317806991721728e-7) 0.99

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 1330a47 Previous: 5a633b0 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 12796.517241657673 iter/sec (stddev: 2.1287412265489572e-7) 12745.686329086004 iter/sec (stddev: 1.7317806991721728e-7) 1.00

This comment was automatically generated by workflow using github-action-benchmark.

@github-actions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Benchmark

Benchmark suite Current: 1330a47 Previous: 5a633b0 Ratio
nucliadb/search/tests/unit/search/test_fetch.py::test_highligh_error 13172.524737154905 iter/sec (stddev: 2.714951867974406e-7) 12745.686329086004 iter/sec (stddev: 1.7317806991721728e-7) 0.97

This comment was automatically generated by workflow using github-action-benchmark.

Please sign in to comment.