
Commit

Merge pull request #326 from aurelio-labs/james/async-op
feat: async support for pinecone and openai
jamescalam authored Jun 13, 2024
2 parents 1916fec + 80b6fc1 commit 4da3da1
Showing 12 changed files with 1,388 additions and 14 deletions.
574 changes: 574 additions & 0 deletions docs/indexes/pinecone_async.ipynb

Large diffs are not rendered by default.

430 changes: 429 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "semantic-router"
-version = "0.0.46"
+version = "0.0.47"
 description = "Super fast semantic router for AI decision making"
 authors = [
     "James Briggs <[email protected]>",
@@ -41,6 +41,7 @@ google-cloud-aiplatform = {version = "^1.45.0", optional = true}
 requests-mock = "^1.12.1"
 boto3 = { version = "^1.34.98", optional = true }
 botocore = {version = "^1.34.110", optional = true}
+aiohttp = "^3.9.5"

 [tool.poetry.extras]
 hybrid = ["pinecone-text"]
13 changes: 10 additions & 3 deletions semantic_router/encoders/base.py
@@ -1,15 +1,22 @@
-from typing import Any, List
+from typing import Any, Coroutine, List, Optional

-from pydantic.v1 import BaseModel, Field
+from pydantic.v1 import BaseModel, Field, validator


 class BaseEncoder(BaseModel):
     name: str
-    score_threshold: float
+    score_threshold: Optional[float] = None
     type: str = Field(default="base")

     class Config:
         arbitrary_types_allowed = True

+    @validator("score_threshold", pre=True, always=True)
+    def set_score_threshold(cls, v):
+        return float(v) if v is not None else None
+
     def __call__(self, docs: List[Any]) -> List[List[float]]:
         raise NotImplementedError("Subclasses must implement this method")
+
+    def acall(self, docs: List[Any]) -> Coroutine[Any, Any, List[List[float]]]:
+        raise NotImplementedError("Subclasses must implement this method")
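
For orientation, a minimal sketch (not part of this commit) of a concrete encoder satisfying both the sync and the new async interface; FakeEncoder and its fixed vectors are hypothetical stand-ins:

import asyncio
from typing import Any, List

from semantic_router.encoders.base import BaseEncoder


class FakeEncoder(BaseEncoder):
    # hypothetical subclass: returns a constant 3-dimensional vector per document
    def __call__(self, docs: List[Any]) -> List[List[float]]:
        return [[0.1, 0.2, 0.3] for _ in docs]

    async def acall(self, docs: List[Any]) -> List[List[float]]:
        # a real implementation would await a non-blocking client here
        return self.__call__(docs)


encoder = FakeEncoder(name="fake", score_threshold=0.5)
print(asyncio.run(encoder.acall(["hello", "world"])))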
72 changes: 67 additions & 5 deletions semantic_router/encoders/openai.py
@@ -1,3 +1,4 @@
+from asyncio import sleep as asleep
 import os
 from time import sleep
 from typing import Any, List, Optional, Union
@@ -17,19 +18,26 @@

 model_configs = {
     "text-embedding-ada-002": EncoderInfo(
-        name="text-embedding-ada-002", token_limit=8192
+        name="text-embedding-ada-002",
+        token_limit=8192,
+        threshold=0.82,
     ),
     "text-embedding-3-small": EncoderInfo(
-        name="text-embedding-3-small", token_limit=8192
+        name="text-embedding-3-small",
+        token_limit=8192,
+        threshold=0.3,
     ),
     "text-embedding-3-large": EncoderInfo(
-        name="text-embedding-3-large", token_limit=8192
+        name="text-embedding-3-large",
+        token_limit=8192,
+        threshold=0.3,
     ),
 }


 class OpenAIEncoder(BaseEncoder):
     client: Optional[openai.Client]
+    async_client: Optional[openai.AsyncClient]
     dimensions: Union[int, NotGiven] = NotGiven()
     token_limit: int = 8192  # default value, should be replaced by config
     _token_encoder: Any = PrivateAttr()
@@ -41,12 +49,24 @@ def __init__(
         openai_base_url: Optional[str] = None,
         openai_api_key: Optional[str] = None,
         openai_org_id: Optional[str] = None,
-        score_threshold: float = 0.82,
+        score_threshold: Optional[float] = None,
         dimensions: Union[int, NotGiven] = NotGiven(),
     ):
         if name is None:
             name = EncoderDefault.OPENAI.value["embedding_model"]
-        super().__init__(name=name, score_threshold=score_threshold)
+        if score_threshold is None and name in model_configs:
+            set_score_threshold = model_configs[name].threshold
+        elif score_threshold is None:
+            logger.warning(
+                f"Score threshold not set for model: {name}. Using default value."
+            )
+            set_score_threshold = 0.82
+        else:
+            set_score_threshold = score_threshold
+        super().__init__(
+            name=name,
+            score_threshold=set_score_threshold,
+        )
         api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
         base_url = openai_base_url or os.getenv("OPENAI_BASE_URL")
         openai_org_id = openai_org_id or os.getenv("OPENAI_ORG_ID")
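
A quick illustration of the new threshold resolution (a sketch, not part of the diff; assumes a valid OPENAI_API_KEY in the environment): with no explicit score_threshold the value now comes from model_configs for known models, and falls back to 0.82 with a warning otherwise.

from semantic_router.encoders.openai import OpenAIEncoder

# known model: picks up threshold=0.3 from model_configs
encoder = OpenAIEncoder(name="text-embedding-3-small")
print(encoder.score_threshold)  # 0.3

# an explicit value still wins over the per-model default
encoder = OpenAIEncoder(name="text-embedding-3-small", score_threshold=0.5)
print(encoder.score_threshold)  # 0.5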
@@ -56,6 +76,9 @@ def __init__(
             self.client = openai.Client(
                 base_url=base_url, api_key=api_key, organization=openai_org_id
             )
+            self.async_client = openai.AsyncClient(
+                base_url=base_url, api_key=api_key, organization=openai_org_id
+            )
         except Exception as e:
             raise ValueError(
                 f"OpenAI API client failed to initialize. Error: {e}"
@@ -126,3 +149,42 @@ def _truncate(self, text: str) -> str:
             logger.info(f"Trunc length: {len(self._token_encoder.encode(text))}")
             return text
         return text
+
+    async def acall(self, docs: List[str], truncate: bool = True) -> List[List[float]]:
+        if self.async_client is None:
+            raise ValueError("OpenAI async client is not initialized.")
+        embeds = None
+        error_message = ""
+
+        if truncate:
+            # check if any document exceeds token limit and truncate if so
+            docs = [self._truncate(doc) for doc in docs]
+
+        # Exponential backoff
+        for j in range(1, 7):
+            try:
+                embeds = await self.async_client.embeddings.create(
+                    input=docs,
+                    model=self.name,
+                    dimensions=self.dimensions,
+                )
+                if embeds.data:
+                    break
+            except OpenAIError as e:
+                await asleep(2**j)
+                error_message = str(e)
+                logger.warning(f"Retrying in {2**j} seconds...")
+            except Exception as e:
+                logger.error(f"OpenAI API call failed. Error: {error_message}")
+                raise ValueError(f"OpenAI API call failed. Error: {e}") from e
+
+        if (
+            not embeds
+            or not isinstance(embeds, CreateEmbeddingResponse)
+            or not embeds.data
+        ):
+            logger.info(f"Returned embeddings: {embeds}")
+            raise ValueError(f"No embeddings returned. Error: {error_message}")
+
+        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
+        return embeddings
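
A minimal usage sketch of the new async path (not part of the diff; assumes OPENAI_API_KEY is set and network access is available):

import asyncio

from semantic_router.encoders.openai import OpenAIEncoder


async def main():
    encoder = OpenAIEncoder(name="text-embedding-3-small")
    # sync and async paths both return List[List[float]]
    sync_vectors = encoder(["hello world"])
    async_vectors = await encoder.acall(["hello world"])
    print(len(sync_vectors[0]), len(async_vectors[0]))


asyncio.run(main())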
51 changes: 49 additions & 2 deletions semantic_router/encoders/zure.py
@@ -1,3 +1,4 @@
+from asyncio import sleep as asleep
 import os
 from time import sleep
 from typing import List, Optional, Union
@@ -14,6 +15,7 @@

 class AzureOpenAIEncoder(BaseEncoder):
     client: Optional[openai.AzureOpenAI] = None
+    async_client: Optional[openai.AsyncAzureOpenAI] = None
     dimensions: Union[int, NotGiven] = NotGiven()
     type: str = "azure"
     api_key: Optional[str] = None
@@ -77,7 +79,14 @@ def __init__(
                 api_key=str(self.api_key),
                 azure_endpoint=str(self.azure_endpoint),
                 api_version=str(self.api_version),
-                # _strict_response_validation=True,
             )
+            self.async_client = openai.AsyncAzureOpenAI(
+                azure_deployment=(
+                    str(self.deployment_name) if self.deployment_name else None
+                ),
+                api_key=str(self.api_key),
+                azure_endpoint=str(self.azure_endpoint),
+                api_version=str(self.api_version),
+            )
         except Exception as e:
             raise ValueError(
@@ -86,7 +95,7 @@ def __init__(

     def __call__(self, docs: List[str]) -> List[List[float]]:
         if self.client is None:
-            raise ValueError("OpenAI client is not initialized.")
+            raise ValueError("Azure OpenAI client is not initialized.")
         embeds = None
         error_message = ""

@@ -121,3 +130,41 @@ def __call__(self, docs: List[str]) -> List[List[float]]:

         embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
         return embeddings
+
+    async def acall(self, docs: List[str]) -> List[List[float]]:
+        if self.async_client is None:
+            raise ValueError("Azure OpenAI async client is not initialized.")
+        embeds = None
+        error_message = ""
+
+        # Exponential backoff
+        for j in range(3):
+            try:
+                embeds = await self.async_client.embeddings.create(
+                    input=docs,
+                    model=str(self.model),
+                    dimensions=self.dimensions,
+                )
+                if embeds.data:
+                    break
+            except OpenAIError as e:
+                # print full traceback
+                import traceback
+
+                traceback.print_exc()
+                await asleep(2**j)
+                error_message = str(e)
+                logger.warning(f"Retrying in {2**j} seconds...")
+            except Exception as e:
+                logger.error(f"Azure OpenAI API call failed. Error: {error_message}")
+                raise ValueError(f"Azure OpenAI API call failed. Error: {e}") from e
+
+        if (
+            not embeds
+            or not isinstance(embeds, CreateEmbeddingResponse)
+            or not embeds.data
+        ):
+            raise ValueError(f"No embeddings returned. Error: {error_message}")
+
+        embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
+        return embeddings
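
From the caller's side the Azure async path looks the same (a sketch, not part of the diff; the keyword arguments follow the field names visible above and the credential values are placeholders):

import asyncio

from semantic_router.encoders.zure import AzureOpenAIEncoder


async def main():
    encoder = AzureOpenAIEncoder(
        api_key="<azure-openai-key>",
        azure_endpoint="https://<resource>.openai.azure.com/",
        api_version="2023-07-01-preview",
        deployment_name="<embedding-deployment>",
        model="text-embedding-ada-002",
    )
    vectors = await encoder.acall(["hello world"])
    print(len(vectors))


asyncio.run(main())

Note the retry policies differ: this encoder retries three times with 1, 2, and 4 second waits (2**j for j in range(3)), while the OpenAI encoder above retries six times starting at 2 seconds (2**j for j in range(1, 7)).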
5 changes: 5 additions & 0 deletions semantic_router/hybrid_layer.py
@@ -28,6 +28,11 @@ def __init__(
         aggregation: str = "sum",
     ):
         self.encoder = encoder
+        if self.encoder.score_threshold is None:
+            raise ValueError(
+                "No score threshold provided for encoder. Please set the score threshold "
+                "in the encoder config."
+            )
         self.score_threshold = self.encoder.score_threshold

         if sparse_encoder is None:
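
In practice the new guard makes a hybrid layer fail fast when its encoder has no resolved threshold. A sketch of the behaviour (not part of the diff; the HybridRouteLayer keyword names are assumed from its public signature, and the bare BaseEncoder is only for illustration):

from semantic_router.encoders.base import BaseEncoder
from semantic_router.hybrid_layer import HybridRouteLayer

# BaseEncoder.score_threshold is now Optional and defaults to None,
# so an encoder configured without a threshold is rejected up front
encoder = BaseEncoder(name="no-threshold-encoder")
try:
    HybridRouteLayer(encoder=encoder, routes=[])
except ValueError as e:
    print(e)  # "No score threshold provided for encoder. ..."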
12 changes: 12 additions & 0 deletions semantic_router/index/base.py
@@ -55,6 +55,18 @@ def query(
         """
         raise NotImplementedError("This method should be implemented by subclasses.")

+    async def aquery(
+        self,
+        vector: np.ndarray,
+        top_k: int = 5,
+        route_filter: Optional[List[str]] = None,
+    ) -> Tuple[np.ndarray, List[str]]:
+        """
+        Search the index for the query_vector and return top_k results.
+        This method should be implemented by subclasses.
+        """
+        raise NotImplementedError("This method should be implemented by subclasses.")
+
     def delete_index(self):
         """
         Deletes or resets the index.
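
For reference, a toy subclass showing what implementing the new async hook could look like (not part of the diff; ListIndex is hypothetical, and a real backend such as the Pinecone index would issue a non-blocking query here):

from typing import List, Optional, Tuple

import numpy as np

from semantic_router.index.base import BaseIndex


class ListIndex(BaseIndex):
    # hypothetical in-memory index, used only to illustrate the interface
    async def aquery(
        self,
        vector: np.ndarray,
        top_k: int = 5,
        route_filter: Optional[List[str]] = None,
    ) -> Tuple[np.ndarray, List[str]]:
        # pretend every stored vector matches and belongs to one "default" route
        scores = np.zeros(top_k)
        routes = ["default"] * top_k
        return scores, routes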
