refactor: add chat model to application state
winstxnhdw committed Sep 22, 2024
1 parent 688c3d9 commit 771815c
Showing 11 changed files with 205 additions and 200 deletions.
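The theme of the commit: route handlers stop reaching into a class-level Chat singleton and instead read a chat model off the application state. The diff does not include server/state.py, so the following is only a minimal sketch of what it might contain, assuming Litestar's custom-state pattern; the ChatModelProtocol name and its members are inferred from how state.chat is used in the handlers below.

# Hypothetical sketch of server/state.py -- not part of this diff.
# Litestar lets handlers receive the application state through a typed
# `state` parameter when you subclass litestar.datastructures.State.
from typing import Any, Protocol

from litestar.datastructures import State


class ChatModelProtocol(Protocol):
    """Inferred interface: the handlers use tokeniser, generate, query and len()."""

    tokeniser: Any  # a Hugging Face-style tokeniser with apply_chat_template

    async def generate(self, tokens: list[str]) -> str: ...

    async def query(self, *args: Any, **kwargs: Any) -> Any: ...  # signature assumed

    def __len__(self) -> int: ...


class AppState(State):
    chat: ChatModelProtocol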
26 changes: 14 additions & 12 deletions server/api/debug/generate.py
@@ -2,9 +2,9 @@
 
 from litestar import Controller, post
 
-from server.features.chat import Chat
 from server.features.chat.types import Message
 from server.schemas.v1 import Benchmark, Generate, Query
+from server.state import AppState
 
 
 class LLMController(Controller):
@@ -17,38 +17,40 @@ class LLMController(Controller):
 
     path = '/llm'
 
     @post()
-    async def generate(self, request: Generate) -> str:
+    async def generate(self, state: AppState, data: Generate) -> str:
         """
         Summary
         -------
         an endpoint for generating text directly from the LLM model
         """
-        prompt = Chat.tokeniser.apply_chat_template(
-            [{'role': 'user', 'content': request.instruction}],
+        chat = state.chat
+
+        prompt = chat.tokeniser.apply_chat_template(
+            [{'role': 'user', 'content': data.instruction}],
             tokenize=False,
             add_generation_prompt=True,
         )
 
-        return await Chat.generate(Chat.tokeniser(prompt).tokens())
+        return await chat.generate(chat.tokeniser(prompt).tokens())
 
     @post('/benchmark')
-    async def benchmark(self, data: Query) -> Benchmark:
+    async def benchmark(self, state: AppState, data: Query) -> Benchmark:
         """
         Summary
         -------
         an endpoint for benchmarking the LLM model
         """
+        chat = state.chat
         message: Message = {'role': 'user', 'content': data.query}
 
-        prompt = Chat.tokeniser.apply_chat_template([message], add_generation_prompt=True, tokenize=False)
-        tokenised_prompt = Chat.tokeniser(prompt).tokens()
+        prompt = chat.tokeniser.apply_chat_template([message], add_generation_prompt=True, tokenize=False)
+        tokenised_prompt = chat.tokeniser(prompt).tokens()
 
         start = perf_counter()
-        response = await Chat.generate(tokenised_prompt)
+        response = await chat.generate(tokenised_prompt)
         total_time = perf_counter() - start
 
-        output_tokens = Chat.tokeniser(response).tokens()
-        total_tokens = len(tokenised_prompt) + len(Chat.static_prompt) + len(output_tokens)
+        output_tokens = chat.tokeniser(response).tokens()
+        total_tokens = len(tokenised_prompt) + len(chat) + len(output_tokens)
 
         return Benchmark(
             response=response,
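One detail worth flagging in the benchmark handler: the token count switches from len(Chat.static_prompt) to len(chat), which implies the chat model now implements __len__. Neither the deleted chat.py nor its replacement is rendered in this view, so the following is only a plausible reading, not the actual implementation:

class ChatModel:
    """Assumed shape; the real model class is not part of the rendered diff."""

    def __init__(self, static_prompt: list[int]) -> None:
        self.static_prompt = static_prompt

    def __len__(self) -> int:
        # len(chat) replaces the old len(Chat.static_prompt), so __len__
        # presumably reports the token length of the static prompt
        return len(self.static_prompt)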
8 changes: 5 additions & 3 deletions server/api/v1/chat.py
@@ -11,12 +11,12 @@
 from server.databases.redis.features import store_chunks
 from server.databases.redis.wrapper import RedisAsync
 from server.dependencies.redis import redis_client
-from server.features.chat import Chat
 from server.features.chunking import SentenceSplitter, chunk_document
 from server.features.embeddings import Embedding
 from server.features.extraction import extract_documents_from_pdfs
 from server.features.question_answering import question_answering
 from server.schemas.v1 import Answer, Chat, Files, Query
+from server.state import AppState
 
 
 class ChatController(Controller):
@@ -73,6 +73,7 @@ async def delete_chat_file(
     @put('/{chat_id:str}/files')
     async def upload_files(
         self,
+        state: AppState,
         redis: Annotated[RedisAsync, Dependency()],
         chat_id: str,
         data: Annotated[list[UploadFile], Body(media_type=RequestEncodingType.MULTI_PART)],
@@ -83,7 +84,7 @@ async def upload_files(
         an endpoint for uploading files to a chat
         """
         embedder = Embedding()
-        text_splitter = SentenceSplitter(Chat.tokeniser, chunk_size=128, chunk_overlap=0)
+        text_splitter = SentenceSplitter(state.chat.tokeniser, chunk_size=128, chunk_overlap=0)
         responses = []
 
         chunk_generator = store_chunks(
@@ -106,6 +107,7 @@ async def upload_files(
     @post('/{chat_id:str}/query')
     async def query(
         self,
+        state: AppState,
         redis: Annotated[RedisAsync, Dependency()],
         chat_id: str,
         data: Query,
@@ -122,7 +124,7 @@ async def query(
         )
 
         message_history = await redis.get_messages(chat_id)
-        messages = await question_answering(data.query, context, message_history, Chat.query)
+        messages = await question_answering(data.query, context, message_history, state.chat.query)
 
         if store_query:
             await redis.save_messages(chat_id, messages)
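Note that no explicit dependency wiring is added for state: in Litestar, state is a reserved handler parameter that the framework injects automatically, and annotating it as AppState merely narrows the type for the checker. A minimal standalone example of the same pattern, with names that are illustrative rather than from this repository:

from litestar import Litestar, get
from litestar.datastructures import State


class AppState(State):
    greeting: str


@get('/')
async def read_greeting(state: AppState) -> str:
    # `state` is a reserved kwarg: Litestar injects the application state
    # into any handler that declares it
    return state.greeting


app = Litestar(route_handlers=[read_greeting], state=State({'greeting': 'hello'}))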
2 changes: 1 addition & 1 deletion server/features/chat/__init__.py
@@ -1 +1 @@
-from server.features.chat.chat import Chat as Chat
+from server.features.chat.model import get_chat_model as get_chat_model
153 changes: 0 additions & 153 deletions server/features/chat/chat.py

This file was deleted.
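With the Chat class gone, the package now re-exports a get_chat_model factory instead, which suggests the model is built once at startup and parked on the application state. The commit does not show that wiring, so this lifespan sketch is an assumption built from the pieces that are visible:

# Hypothetical startup wiring; server/features/chat/model.py and the
# lifespan hook are not shown in this commit.
from contextlib import asynccontextmanager
from typing import AsyncIterator

from litestar import Litestar

from server.features.chat import get_chat_model


@asynccontextmanager
async def chat_model_lifespan(app: Litestar) -> AsyncIterator[None]:
    # build the model once and store it on the shared application state
    app.state.chat = get_chat_model()
    yield


app = Litestar(route_handlers=[], lifespan=[chat_model_lifespan])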

(diffs for the remaining 8 changed files did not load)
