Merge branch 'main' into predict-is-compatible-with-vectorsets
javitonino authored Aug 1, 2024
2 parents 5b7379b + 95f0d93 commit fcc7435
Showing 7 changed files with 170 additions and 93 deletions.
48 changes: 33 additions & 15 deletions nucliadb/src/nucliadb/common/datamanagers/kb.py
@@ -108,21 +108,39 @@ async def get_matryoshka_vector_dimension(
     vectorset_id: Optional[str] = None,
 ) -> Optional[int]:
     """Return vector dimension for matryoshka models"""
-    model_metadata = await get_model_metadata(txn, kbid=kbid)
-    dimension = None
-    if len(model_metadata.matryoshka_dimensions) > 0 and model_metadata.vector_dimension:
-        if model_metadata.vector_dimension in model_metadata.matryoshka_dimensions:
-            dimension = model_metadata.vector_dimension
-        else:
-            logger.error(
-                "KB has an invalid matryoshka dimension!",
-                extra={
-                    "kbid": kbid,
-                    "vector_dimension": model_metadata.vector_dimension,
-                    "matryoshka_dimensions": model_metadata.matryoshka_dimensions,
-                },
-            )
-    return dimension
+    from . import vectorsets
+
+    async for _, vs in vectorsets.iter(txn, kbid=kbid):
+        if len(vs.matryoshka_dimensions) > 0 and vs.vectorset_index_config.vector_dimension:
+            if vs.vectorset_index_config.vector_dimension in vs.matryoshka_dimensions:
+                return vs.vectorset_index_config.vector_dimension
+            else:
+                logger.error(
+                    "KB has an invalid matryoshka dimension!",
+                    extra={
+                        "kbid": kbid,
+                        "vector_dimension": vs.vectorset_index_config.vector_dimension,
+                        "matryoshka_dimensions": vs.matryoshka_dimensions,
+                    },
+                )
+        return None
+    else:
+        # fallback for KBs that don't have vectorset
+        model_metadata = await get_model_metadata(txn, kbid=kbid)
+        dimension = None
+        if len(model_metadata.matryoshka_dimensions) > 0 and model_metadata.vector_dimension:
+            if model_metadata.vector_dimension in model_metadata.matryoshka_dimensions:
+                dimension = model_metadata.vector_dimension
+            else:
+                logger.error(
+                    "KB has an invalid matryoshka dimension!",
+                    extra={
+                        "kbid": kbid,
+                        "vector_dimension": model_metadata.vector_dimension,
+                        "matryoshka_dimensions": model_metadata.matryoshka_dimensions,
+                    },
+                )
+        return dimension
 
 
 async def get_external_index_provider_metadata(
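The rewritten function leans on Python's for/else: the else block runs only when the loop finishes without hitting a break. Since every iteration here ends in a return, the legacy get_model_metadata fallback in the else branch is reached exactly when the KB has no vectorsets at all. A minimal sketch of that pattern, with hypothetical stand-in names rather than nucliadb APIs:

# Minimal sketch of the async for/else control flow used above;
# iter_vectorsets and the dict items are illustrative, not nucliadb code.
import asyncio


async def iter_vectorsets(items):
    for item in items:
        yield item


async def first_dimension(items):
    async for vs in iter_vectorsets(items):
        # every iteration returns, so the else never runs when items exist
        return vs.get("dimension")
    else:
        # runs only when the loop completed without break: here, zero items
        return "legacy-fallback"


print(asyncio.run(first_dimension([{"dimension": 512}])))  # 512
print(asyncio.run(first_dimension([])))                    # legacy-fallback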
17 changes: 15 additions & 2 deletions nucliadb_core/src/vectors.rs
@@ -107,22 +107,29 @@ pub trait VectorWriter: std::fmt::Debug + Send + Sync {
 pub struct ResourceWrapper<'a> {
     resource: &'a noderesources::Resource,
     vectorset: Option<String>,
+    fallback_to_default_vectorset: bool,
 }
 
 impl<'a> From<&'a noderesources::Resource> for ResourceWrapper<'a> {
     fn from(value: &'a noderesources::Resource) -> Self {
         Self {
             resource: value,
             vectorset: None,
+            fallback_to_default_vectorset: false,
         }
     }
 }
 
 impl<'a> ResourceWrapper<'a> {
-    pub fn new_vectorset_resource(resource: &'a noderesources::Resource, vectorset: &str) -> Self {
+    pub fn new_vectorset_resource(
+        resource: &'a noderesources::Resource,
+        vectorset: &str,
+        fallback_to_default_vectorset: bool,
+    ) -> Self {
         Self {
             resource,
             vectorset: Some(vectorset.to_string()),
+            fallback_to_default_vectorset,
         }
     }
 
@@ -136,7 +143,13 @@ impl<'a> ResourceWrapper<'a> {
             let sentences = if let Some(vectorset) = &self.vectorset {
                 // indexing a vectorset, we should return only paragraphs from this vectorset.
                 // If vectorset is not found, we'll skip this paragraph
-                paragraph.vectorsets_sentences.get(vectorset).map(|x| &x.sentences)
+                if let Some(vectorset_sentences) = paragraph.vectorsets_sentences.get(vectorset) {
+                    Some(&vectorset_sentences.sentences)
+                } else if self.fallback_to_default_vectorset {
+                    Some(&paragraph.sentences)
+                } else {
+                    None
+                }
             } else {
                 // Default vectors index (no vectorset)
                 Some(&paragraph.sentences)
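Spelled out, ResourceWrapper's new selection rules are: prefer the vectorset's own sentences; if the vectorset key is missing, fall back to the legacy default-field sentences only when the flag is set; otherwise skip the paragraph. A rough Python sketch of those rules, with illustrative names and dict shapes rather than the actual Rust types:

# Hypothetical sketch of ResourceWrapper's selection logic; the function
# name and dict layout are stand-ins, not nucliadb structures.
def select_sentences(paragraph, vectorset=None, fallback_to_default=False):
    if vectorset is None:
        # Default vectors index (no vectorset): always the legacy field.
        return paragraph["sentences"]
    per_vectorset = paragraph["vectorsets_sentences"]
    if vectorset in per_vectorset:
        return per_vectorset[vectorset]
    if fallback_to_default:
        # Payload predates vectorsets: vectors live in the legacy field.
        return paragraph["sentences"]
    return None  # unknown vectorset and no fallback: skip this paragraph


paragraph = {
    "sentences": ["legacy"],
    "vectorsets_sentences": {"multilingual": ["vs-specific"]},
}
assert select_sentences(paragraph, "multilingual") == ["vs-specific"]
assert select_sentences(paragraph, "other", fallback_to_default=True) == ["legacy"]
assert select_sentences(paragraph, "other") is None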
10 changes: 7 additions & 3 deletions nucliadb_node/src/shards/shard_writer.rs
@@ -361,15 +361,19 @@ impl ShardWriter {
         };
 
         let mut vector_tasks = vec![];
+        let vectorset_count = indexes.vectors_indexes.len();
         for (vectorset, vector_writer) in indexes.vectors_indexes.iter_mut() {
             vector_tasks.push(|| {
                 run_with_telemetry(info_span!(parent: &span, "vector set_resource"), || {
                     debug!("Vector service starts set_resource");
 
                     let vectorset_resource = match vectorset.as_str() {
                         "" | DEFAULT_VECTORS_INDEX_NAME => (&resource).into(),
-                        vectorset => {
-                            nucliadb_core::vectors::ResourceWrapper::new_vectorset_resource(&resource, vectorset)
-                        }
+                        vectorset => nucliadb_core::vectors::ResourceWrapper::new_vectorset_resource(
+                            &resource,
+                            vectorset,
+                            vectorset_count == 1,
+                        ),
                     };
                     let result = vector_writer.set_resource(vectorset_resource);
                     debug!("Vector service ends set_resource");
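Note the third argument, vectorset_count == 1: the legacy fallback is enabled only when the shard has exactly one vectors index, presumably because with a single vectorset there is no ambiguity about which index sentences from an old-style, pre-vectorset payload belong to.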
18 changes: 13 additions & 5 deletions nucliadb_telemetry/src/nucliadb_telemetry/fastapi/tracing.py
@@ -82,15 +82,17 @@ def get(self, carrier: dict, key: str) -> typing.Optional[typing.List[str]]:
         # ASGI header keys are in lower case
         key = key.lower()
         decoded = [
-            _value.decode("utf8") for (_key, _value) in headers if _key.decode("utf8").lower() == key
+            _value.decode("utf8", errors="replace")
+            for (_key, _value) in headers
+            if _key.decode("utf8", errors="replace").lower() == key
         ]
         if not decoded:
             return None
         return decoded
 
     def keys(self, carrier: dict) -> typing.List[str]:
         headers = carrier.get("headers") or []
-        return [_key.decode("utf8") for (_key, _) in headers]
+        return [_key.decode("utf8", errors="replace") for (_key, _) in headers]
 
 
 asgi_getter = ASGIGetter()

@@ -125,7 +127,7 @@ def collect_request_attributes(scope):
     query_string = scope.get("query_string")
     if query_string and http_url:
         if isinstance(query_string, bytes):
-            query_string = query_string.decode("utf8")
+            query_string = query_string.decode("utf8", errors="replace")
         http_url += "?" + urllib.parse.unquote(query_string)
 
     result = {

@@ -167,7 +169,10 @@ def collect_custom_request_headers_attributes(scope):
     )
 
     # Decode headers before processing.
-    headers = {_key.decode("utf8"): _value.decode("utf8") for (_key, _value) in scope.get("headers")}
+    headers = {
+        _key.decode("utf8", errors="replace"): _value.decode("utf8", errors="replace")
+        for (_key, _value) in scope.get("headers")
+    }
 
     return sanitize.sanitize_header_values(
         headers,

@@ -186,7 +191,10 @@ def collect_custom_response_headers_attributes(message):
     )
 
     # Decode headers before processing.
-    headers = {_key.decode("utf8"): _value.decode("utf8") for (_key, _value) in message.get("headers")}
+    headers = {
+        _key.decode("utf8", errors="replace"): _value.decode("utf8", errors="replace")
+        for (_key, _value) in message.get("headers")
+    }
 
     return sanitize.sanitize_header_values(
         headers,
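The common thread in these tracing changes: ASGI header names, header values, and query strings arrive as raw bytes with no guarantee of valid UTF-8, so a strict decode("utf8") can raise UnicodeDecodeError and break tracing for the whole request. With errors="replace", undecodable bytes become the U+FFFD replacement character instead. A quick demonstration using the byte string from the new test below:

# 0xe8 is "è" in latin-1, but it is not a valid UTF-8 sequence here.
raw = b"Synth\xe8ses\\3229-navigation.pdf"

try:
    raw.decode("utf8")
except UnicodeDecodeError as exc:
    print(f"strict decode raises: {exc}")

# Lossy but safe: the bad byte becomes the replacement character.
print(raw.decode("utf8", errors="replace"))  # Synth�ses\3229-navigation.pdf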
10 changes: 9 additions & 1 deletion nucliadb_telemetry/tests/unit/fastapi/test_tracing.py
@@ -21,7 +21,10 @@
 import pytest
 from opentelemetry.trace import format_trace_id
 
-from nucliadb_telemetry.fastapi.tracing import CaptureTraceIdMiddleware
+from nucliadb_telemetry.fastapi.tracing import (
+    CaptureTraceIdMiddleware,
+    collect_custom_request_headers_attributes,
+)
 
 
 @pytest.fixture(scope="function")

@@ -59,3 +62,8 @@ async def test_capture_trace_id_middleware_appends_trace_id_header_to_exposed(tr
     response = await mdw.dispatch(request, call_next)
 
     assert response.headers["Access-Control-Expose-Headers"] == "Foo-Bar,X-Header,X-NUCLIA-TRACE-ID"
+
+
+def test_collect_custom_request_headers_attributes():
+    scope = {"headers": [[b"x-filename", b"Synth\xe8ses\\3229-navigation.pdf"]]}
+    collect_custom_request_headers_attributes(scope)
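The new test exercises exactly that failure mode: 0xe8 followed by an ASCII byte cannot be decoded as UTF-8. It carries no assertion because its point is that collect_custom_request_headers_attributes no longer raises UnicodeDecodeError on such headers.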
2 changes: 1 addition & 1 deletion nucliadb_utils/requirements.txt
@@ -8,7 +8,7 @@ nats-py[nkeys]>=2.6.0
 PyNaCl
 pyjwt>=2.4.0
 memorylru>=1.1.2
-mrflagly
+mrflagly>=0.2.9
 
 # automatically bumped during release
 nucliadb-protos