Skip to content

Commit

Permalink
Merge pull request #427 from aurelio-labs/vittorio/qdrant-query
Browse files Browse the repository at this point in the history
refactor: Use Qdrant's Query API
  • Loading branch information
jamescalam authored Sep 21, 2024
2 parents 3e6bd22 + 3e02edb commit b361dc9
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 14 deletions.
2 changes: 1 addition & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ torchvision = { version = ">=0.17.0,<0.18.0", optional = true}
pillow = { version = ">=10.2.0,<11.0.0", optional = true}
tiktoken = ">=0.6.0,<1.0.0"
matplotlib = { version = "^3.8.3", optional = true}
qdrant-client = {version = "^1.8.0", optional = true}
qdrant-client = {version = "^1.11.1", optional = true}
google-cloud-aiplatform = {version = "^1.45.0", optional = true}
requests-mock = "^1.12.1"
boto3 = { version = "^1.34.98", optional = true }
Expand Down
27 changes: 15 additions & 12 deletions semantic_router/index/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from semantic_router.index.base import BaseIndex
from semantic_router.schema import Metric

from semantic_router.utils.logger import logger

DEFAULT_COLLECTION_NAME = "semantic-router-index"
Expand Down Expand Up @@ -97,7 +96,7 @@ def __init__(self, **kwargs):

def _initialize_clients(self):
try:
from qdrant_client import QdrantClient, AsyncQdrantClient
from qdrant_client import AsyncQdrantClient, QdrantClient

sync_client = QdrantClient(
location=self.location,
Expand Down Expand Up @@ -264,7 +263,7 @@ def query(
top_k: int = 5,
route_filter: Optional[List[str]] = None,
) -> Tuple[np.ndarray, List[str]]:
from qdrant_client import models, QdrantClient
from qdrant_client import QdrantClient, models

self.client: QdrantClient
filter = None
Expand All @@ -278,15 +277,17 @@ def query(
]
)

results = self.client.search(
results = self.client.query_points(
self.index_name,
query_vector=vector,
query=vector,
limit=top_k,
with_payload=True,
query_filter=filter,
)
scores = [result.score for result in results]
route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
scores = [result.score for result in results.points]
route_names = [
result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results.points
]
return np.array(scores), route_names

async def aquery(
Expand All @@ -295,7 +296,7 @@ async def aquery(
top_k: int = 5,
route_filter: Optional[List[str]] = None,
) -> Tuple[np.ndarray, List[str]]:
from qdrant_client import models, AsyncQdrantClient
from qdrant_client import AsyncQdrantClient, models

self.aclient: Optional[AsyncQdrantClient]
if self.aclient is None:
Expand All @@ -313,15 +314,17 @@ async def aquery(
]
)

results = await self.aclient.search(
results = await self.aclient.query_points(
self.index_name,
query_vector=vector,
query=vector,
limit=top_k,
with_payload=True,
query_filter=filter,
)
scores = [result.score for result in results]
route_names = [result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results]
scores = [result.score for result in results.points]
route_names = [
result.payload[SR_ROUTE_PAYLOAD_KEY] for result in results.points
]
return np.array(scores), route_names

def aget_routes(self):
Expand Down

0 comments on commit b361dc9

Please sign in to comment.