Skip to content

Commit

Permalink
Merge pull request #358 from aurelio-labs/vittorio/357-sync-feature-n…
Browse files Browse the repository at this point in the history
…ot-handling-non-existing-local-routes-present-remotely

fix: handling non-existing local routes in sync
  • Loading branch information
jamescalam authored Jul 31, 2024
2 parents cbb685f + ac97a3c commit 6da452a
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 6 deletions.
43 changes: 39 additions & 4 deletions semantic_router/index/pinecone.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from semantic_router.index.base import BaseIndex
from semantic_router.utils.logger import logger
from semantic_router.route import Route


def clean_route_name(route_name: str) -> str:
Expand Down Expand Up @@ -203,6 +204,7 @@ async def _init_async_index(self, force_create: bool = False):

def _sync_index(self, local_routes: dict):
remote_routes = self.get_routes()

remote_dict: dict = {route: set() for route, _ in remote_routes}
for route, utterance in remote_routes:
remote_dict[route].add(utterance)
Expand All @@ -215,19 +217,27 @@ def _sync_index(self, local_routes: dict):

routes_to_add = []
routes_to_delete = []
layer_routes = {}

for route in all_routes:
local_utterances = local_dict.get(route, set())
remote_utterances = remote_dict.get(route, set())

if not local_utterances and not remote_utterances:
continue

if self.sync == "error":
if local_utterances != remote_utterances:
raise ValueError(
f"Synchronization error: Differences found in route '{route}'"
)
utterances_to_include: set = set()
if local_utterances:
layer_routes[route] = list(local_utterances)
elif self.sync == "remote":
utterances_to_include = set()
if remote_utterances:
layer_routes[route] = list(remote_utterances)
elif self.sync == "local":
utterances_to_include = local_utterances - remote_utterances
routes_to_delete.extend(
Expand All @@ -237,11 +247,17 @@ def _sync_index(self, local_routes: dict):
if utterance not in local_utterances
]
)
if local_utterances:
layer_routes[route] = list(local_utterances)
elif self.sync == "merge-force-remote":
if route in local_dict and route not in remote_dict:
utterances_to_include = local_utterances
if local_utterances:
layer_routes[route] = list(local_utterances)
else:
utterances_to_include = set()
if remote_utterances:
layer_routes[route] = list(remote_utterances)
elif self.sync == "merge-force-local":
if route in local_dict:
utterances_to_include = local_utterances - remote_utterances
Expand All @@ -252,10 +268,18 @@ def _sync_index(self, local_routes: dict):
if utterance not in local_utterances
]
)
if local_utterances:
layer_routes[route] = local_utterances
else:
utterances_to_include = set()
if remote_utterances:
layer_routes[route] = list(remote_utterances)
elif self.sync == "merge":
utterances_to_include = local_utterances - remote_utterances
if local_utterances or remote_utterances:
layer_routes[route] = list(
remote_utterances.union(local_utterances)
)
else:
raise ValueError("Invalid sync mode specified")

Expand All @@ -272,7 +296,7 @@ def _sync_index(self, local_routes: dict):
]
)

return routes_to_add, routes_to_delete
return routes_to_add, routes_to_delete, layer_routes

def _batch_upsert(self, batch: List[Dict]):
"""Helper method for upserting a single batch of records."""
Expand Down Expand Up @@ -308,8 +332,8 @@ def _add_and_sync(
routes: List[str],
utterances: List[str],
batch_size: int = 100,
):
"""Add vectors to Pinecone in batches."""
) -> List[Route]:
"""Add vectors to Pinecone in batches and return the overall updated list of Route objects."""
if self.index is None:
self.dimensions = self.dimensions or len(embeddings[0])
self.index = self._init_index(force_create=True)
Expand All @@ -320,7 +344,15 @@ def _add_and_sync(
"embeddings": embeddings,
}
if self.sync is not None:
data_to_upsert, data_to_delete = self._sync_index(local_routes=local_routes)
data_to_upsert, data_to_delete, layer_routes_dict = self._sync_index(
local_routes=local_routes
)

layer_routes = [
Route(name=route, utterances=layer_routes_dict[route])
for route in layer_routes_dict.keys()
]

routes_to_delete: dict = {}
for route, utterance in data_to_delete:
routes_to_delete.setdefault(route, []).append(utterance)
Expand All @@ -335,6 +367,7 @@ def _add_and_sync(
]
if ids_to_delete and self.index:
self.index.delete(ids=ids_to_delete)

else:
data_to_upsert = [
(vector, route, utterance)
Expand All @@ -350,6 +383,8 @@ def _add_and_sync(
batch = vectors_to_upsert[i : i + batch_size]
self._batch_upsert(batch)

return layer_routes

def _get_route_ids(self, route_name: str):
clean_route = clean_route_name(route_name)
ids, _ = self._get_all(prefix=f"{clean_route}#")
Expand Down
40 changes: 38 additions & 2 deletions semantic_router/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,8 +217,20 @@ def __init__(
if route.score_threshold is None:
route.score_threshold = self.score_threshold
# if routes list has been passed, we initialize index now
if len(self.routes) > 0:
if self.index.sync:
# initialize index now
if len(self.routes) > 0:
self._add_and_sync_routes(routes=self.routes)
else:
dummy_embedding = self.encoder(["dummy"])

layer_routes = self.index._add_and_sync(
embeddings=dummy_embedding,
routes=[],
utterances=[],
)
self._set_layer_routes(layer_routes)
elif len(self.routes) > 0:
self._add_routes(routes=self.routes)

def check_for_matching_routes(self, top_class: str) -> Optional[Route]:
Expand Down Expand Up @@ -385,6 +397,14 @@ def _check_threshold(self, scores: List[float], route: Optional[Route]) -> bool:
)
return self._pass_threshold(scores, threshold)

def _set_layer_routes(self, new_routes: List[Route]):
"""
Set and override the current routes with a new list of routes.
:param new_routes: List of Route objects to set as the current routes.
"""
self.routes = new_routes

def __str__(self):
return (
f"RouteLayer(encoder={self.encoder}, "
Expand Down Expand Up @@ -471,11 +491,27 @@ def _add_routes(self, routes: List[Route]):
# create route array
route_names = [route.name for route in routes for _ in route.utterances]
# add everything to the index
self.index._add_and_sync(
self.index.add(
embeddings=embedded_utterances,
routes=route_names,
utterances=all_utterances,
)

def _add_and_sync_routes(self, routes: List[Route]):
# create embeddings for all routes and sync at startup with remote ones based on sync setting
all_utterances = [
utterance for route in routes for utterance in route.utterances
]
embedded_utterances = self.encoder(all_utterances)
# create route array
route_names = [route.name for route in routes for _ in route.utterances]
# add everything to the index
layer_routes = self.index._add_and_sync(
embeddings=embedded_utterances,
routes=route_names,
utterances=all_utterances,
)
self._set_layer_routes(layer_routes)

def _encode(self, text: str) -> Any:
"""Given some text, encode it."""
Expand Down

0 comments on commit 6da452a

Please sign in to comment.