diff --git a/server/api/debug/redis.py b/server/api/debug/redis.py index 241c0c8..2107d19 100644 --- a/server/api/debug/redis.py +++ b/server/api/debug/redis.py @@ -6,8 +6,8 @@ from server.config import Config from server.databases.redis.wrapper import RedisAsync -from server.dependencies import redis_client -from server.features.embeddings import Embedding +from server.dependencies import embedder, redis_client +from server.features.embeddings import Embedder from server.schemas.v1 import Query @@ -19,7 +19,10 @@ class RedisController(Controller): """ path = '/redis' - dependencies = {'redis': Provide(redis_client)} + dependencies = { + 'redis': Provide(redis_client), + 'embedder': Provide(embedder), + } @delete() async def delete_index(self, redis: Annotated[RedisAsync, Dependency()], recreate: bool = False) -> None: @@ -42,6 +45,7 @@ async def delete_index(self, redis: Annotated[RedisAsync, Dependency()], recreat async def search( self, redis: Annotated[RedisAsync, Dependency()], + embedder: Annotated[Embedder, Dependency()], chat_id: str, data: Query, search_size: Annotated[int, Parameter(gt=0)] = 1, @@ -51,4 +55,4 @@ async def search( ------- an endpoint for searching the Redis vector database """ - return await redis.search(chat_id, Embedding().encode_query(data.query), search_size) + return await redis.search(chat_id, embedder.encode_query(data.query), search_size) diff --git a/server/api/v1/chat.py b/server/api/v1/chat.py index 7b9288d..bd299ba 100644 --- a/server/api/v1/chat.py +++ b/server/api/v1/chat.py @@ -10,9 +10,9 @@ from server.databases.redis.features import store_chunks from server.databases.redis.wrapper import RedisAsync -from server.dependencies.redis import redis_client +from server.dependencies import embedder, redis_client from server.features.chunking import SentenceSplitter, chunk_document -from server.features.embeddings import Embedding +from server.features.embeddings import Embedder from server.features.extraction import 
extract_documents_from_pdfs from server.features.question_answering import question_answering from server.schemas.v1 import Answer, Chat, Files, Query @@ -27,7 +27,10 @@ class ChatController(Controller): """ path = '/chats' - dependencies = {'redis': Provide(redis_client)} + dependencies = { + 'redis': Provide(redis_client), + 'embedder': Provide(embedder), + } @get() async def create_chat(self) -> Chat: @@ -75,6 +78,7 @@ async def upload_files( self, state: AppState, redis: Annotated[RedisAsync, Dependency()], + embedder: Annotated[Embedder, Dependency()], chat_id: str, data: Annotated[list[UploadFile], Body(media_type=RequestEncodingType.MULTI_PART)], ) -> Files: @@ -83,7 +87,6 @@ async def upload_files( ------- an endpoint for uploading files to a chat """ - embedder = Embedding() text_splitter = SentenceSplitter(state.chat.tokeniser, chunk_size=128, chunk_overlap=0) responses = [] @@ -109,6 +112,7 @@ async def query( self, state: AppState, redis: Annotated[RedisAsync, Dependency()], + embedder: Annotated[Embedder, Dependency()], chat_id: str, data: Query, search_size: Annotated[int, Parameter(ge=0)] = 0, @@ -119,9 +123,7 @@ async def query( ------- the `/query` route provides an endpoint for performning retrieval-augmented generation """ - context = ( - '' if not search_size else await redis.search(chat_id, Embedding().encode_query(data.query), search_size) - ) + context = '' if not search_size else await redis.search(chat_id, embedder.encode_query(data.query), search_size) message_history = await redis.get_messages(chat_id) messages = await question_answering(data.query, context, message_history, state.chat.query) diff --git a/server/dependencies/__init__.py b/server/dependencies/__init__.py index d0351b2..417b725 100644 --- a/server/dependencies/__init__.py +++ b/server/dependencies/__init__.py @@ -1 +1,2 @@ +from server.dependencies.embedder import embedder as embedder from server.dependencies.redis import redis_client as redis_client diff --git 
a/server/dependencies/embedder.py b/server/dependencies/embedder.py new file mode 100644 index 0000000..8ffd9ec --- /dev/null +++ b/server/dependencies/embedder.py @@ -0,0 +1,21 @@ +from typing import Iterator + +from server.features.embeddings import Embedder + + +def embedder() -> Iterator[Embedder]: + """ + Summary + ------- + load the embeddings model + + Returns + ------- + embedder (Embedder): the embeddings model + """ + embedder = Embedder() + + try: + yield embedder + finally: + del embedder diff --git a/server/features/chat/model.py b/server/features/chat/model.py index c067de8..7fede68 100644 --- a/server/features/chat/model.py +++ b/server/features/chat/model.py @@ -3,7 +3,7 @@ from server.config import Config from server.features.chat.types import Message -from server.helpers import huggingface_download +from server.utils import huggingface_download class ChatModel: diff --git a/server/features/embeddings/__init__.py b/server/features/embeddings/__init__.py index 4b936d4..9c4d990 100644 --- a/server/features/embeddings/__init__.py +++ b/server/features/embeddings/__init__.py @@ -1 +1 @@ -from server.features.embeddings.embedding import Embedding as Embedding +from server.features.embeddings.embedding import Embedder as Embedder diff --git a/server/features/embeddings/embedding.py b/server/features/embeddings/embedding.py index 16849b5..ea4658f 100644 --- a/server/features/embeddings/embedding.py +++ b/server/features/embeddings/embedding.py @@ -1,11 +1,11 @@ -from huggingface_hub import snapshot_download from sentence_transformers import SentenceTransformer from torch import device from server.features.embeddings.flag_embedding import FlagEmbedding +from server.utils import huggingface_download -class Embedding(SentenceTransformer): +class Embedder(SentenceTransformer): """ Summary ------- @@ -20,12 +20,12 @@ class Embedding(SentenceTransformer): encode a sentence for searching relevant passages - def __init__(self, *, force_download: bool = 
False): + def __init__(self): model_name = 'bge-base-en-v1.5' super().__init__(f'BAAI/{model_name}') self.cached_device = super().device # type: ignore - model_path = snapshot_download(f'winstxnhdw/{model_name}-ct2', local_files_only=not force_download) + model_path = huggingface_download(f'winstxnhdw/{model_name}-ct2') self[0] = FlagEmbedding(self[0], model_path, 'auto') @property diff --git a/server/helpers/__init__.py b/server/helpers/__init__.py deleted file mode 100644 index b1fb832..0000000 --- a/server/helpers/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from server.helpers.network import huggingface_download as huggingface_download diff --git a/server/helpers/network/__init__.py b/server/helpers/network/__init__.py deleted file mode 100644 index 960f665..0000000 --- a/server/helpers/network/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from server.helpers.network.huggingface_download import ( - huggingface_download as huggingface_download, -) diff --git a/server/lifespans/download_embeddings.py b/server/lifespans/download_embeddings.py index db0aad5..4d5514f 100644 --- a/server/lifespans/download_embeddings.py +++ b/server/lifespans/download_embeddings.py @@ -3,11 +3,11 @@ from litestar import Litestar -from server.helpers import huggingface_download +from server.utils import huggingface_download @asynccontextmanager -async def download_embeddings(app: Litestar) -> AsyncIterator[None]: +async def download_embeddings(_: Litestar) -> AsyncIterator[None]: """ Summary ------- diff --git a/server/utils/__init__.py b/server/utils/__init__.py new file mode 100644 index 0000000..b0bdaec --- /dev/null +++ b/server/utils/__init__.py @@ -0,0 +1 @@ +from server.utils.network import huggingface_download as huggingface_download diff --git a/server/utils/network/__init__.py b/server/utils/network/__init__.py new file mode 100644 index 0000000..2d5eaaa --- /dev/null +++ b/server/utils/network/__init__.py @@ -0,0 +1,3 @@ +from server.utils.network.huggingface_download import ( + 
huggingface_download as huggingface_download, +) diff --git a/server/helpers/network/has_internet_access.py b/server/utils/network/has_internet_access.py similarity index 100% rename from server/helpers/network/has_internet_access.py rename to server/utils/network/has_internet_access.py diff --git a/server/helpers/network/huggingface_download.py b/server/utils/network/huggingface_download.py similarity index 86% rename from server/helpers/network/huggingface_download.py rename to server/utils/network/huggingface_download.py index 1616cb6..7b52185 100644 --- a/server/helpers/network/huggingface_download.py +++ b/server/utils/network/huggingface_download.py @@ -1,6 +1,6 @@ from huggingface_hub import snapshot_download -from server.helpers.network.has_internet_access import has_internet_access +from server.utils.network.has_internet_access import has_internet_access def huggingface_download(repository: str) -> str: diff --git a/tests/test_embedding.py b/tests/test_embedding.py index 765a0ef..ca664cf 100644 --- a/tests/test_embedding.py +++ b/tests/test_embedding.py @@ -5,14 +5,14 @@ from numpy import array_equal from pytest import fixture -from server.features.embeddings import Embedding +from server.features.embeddings import Embedder type Text = Literal['Hello world!'] @fixture() def embedding(): - yield Embedding(force_download=True) + yield Embedder() @fixture() @@ -20,13 +20,13 @@ def text(): yield 'Hello world!' 
-def test_encodings(embedding: Embedding, text: Text): +def test_encodings(embedding: Embedder, text: Text): assert array_equal(embedding.encode_query(text), embedding.encode_normalise(text)) is False -def test_encode_query(embedding: Embedding, text: Text): +def test_encode_query(embedding: Embedder, text: Text): assert len(embedding.encode_query(text)) > 0 -def test_encode_normalise(embedding: Embedding, text: Text): +def test_encode_normalise(embedding: Embedder, text: Text): assert len(embedding.encode_normalise(text)) > 0 diff --git a/tests/test_has_internet_access.py b/tests/test_has_internet_access.py index a7baa28..d6c7604 100644 --- a/tests/test_has_internet_access.py +++ b/tests/test_has_internet_access.py @@ -2,7 +2,7 @@ from pytest import mark -from server.helpers.network.has_internet_access import has_internet_access +from server.utils.network.has_internet_access import has_internet_access @mark.parametrize(