Commit b734844 (1 parent 55a139f)
Showing 8 changed files with 320 additions and 0 deletions.
@@ -0,0 +1 @@
.env
@@ -0,0 +1,21 @@ | ||
# a dockerfile for a fastapi server | ||
|
||
# Use the official Python image. | ||
# https://hub.docker.com/_/python | ||
FROM python:3.8-slim | ||
|
||
# Copy local code to the container image. | ||
ENV APP_HOME /app | ||
WORKDIR $APP_HOME | ||
ARG COMMITHASH | ||
ENV COMMITHASH=$COMMITHASH | ||
COPY requirements.txt ./ | ||
|
||
# Install production dependencies. | ||
RUN pip install --upgrade pip | ||
RUN pip install -r requirements.txt | ||
|
||
COPY . ./ | ||
|
||
# Run the fastapi service | ||
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "4000"] |
@@ -0,0 +1,2 @@
# Private LLM API
@@ -0,0 +1,7 @@
#!/bin/sh
. ./.env  # "." is the POSIX-compatible form of "source"
REPOSITORY=private-llm-fastapi-server
IMAGE_TAG=latest
docker build --build-arg="COMMITHASH=localtest" -t "$REPOSITORY:$IMAGE_TAG" .

docker run --rm -p 4001:4000 --env-file ./.env "$REPOSITORY:$IMAGE_TAG"
@@ -0,0 +1,41 @@
"""
contextualize.py
Maintainer: Wes Kennedy
Description: The contextualize module allows us to write app-specific queries to help build our database.
"""

import db
from sqlalchemy import text


class Contextualizer:
    def __init__(self):
        pass

    def customer_lookup_byid(self, customer_id):
        """
        Takes a customer id and returns the customer's name.
        """
        pass

    def customer_lookup_byname(self, customer_name):
        """
        Takes a customer name and returns the customer's id.
        """
        pass

    def customer_lookup_byemail(self, customer_email):
        """
        Takes a customer email and returns the customer's id.
        """
        pass

    def customer_previous_orders(self, db_conn, customer_id):
        """
        Takes a customer id and returns a list of previous orders.
        """
        # Bind the id as a parameter rather than interpolating it into the SQL string.
        query = text("SELECT * FROM orders WHERE customer_id = :customer_id").bindparams(
            customer_id=customer_id
        )
        orders = db.query_wrapper(db_conn, query)
        return orders
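
A minimal usage sketch for the module above, assuming the MYSQL_* environment variables are set and that an orders table already exists (it is not created by this commit); the customer id is a made-up example:

import db
from contextualize import Contextualizer

# Assumes db.init() can reach the MySQL server and an `orders` table exists.
db_conn = db.init()
ctx = Contextualizer()
orders = ctx.customer_previous_orders(db_conn, customer_id=1234)
print(orders)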
@@ -0,0 +1,54 @@
import os
# Add the ability to query a MySQL server using SQLAlchemy
from sqlalchemy import create_engine, text
from sqlalchemy_utils import database_exists

MYSQL_HOST = os.getenv("MYSQL_HOST")
MYSQL_PORT = os.getenv("MYSQL_PORT")
MYSQL_USER = os.getenv("MYSQL_USER")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD")
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE")


def init_connect():
    # Server-level connection string (no database selected yet)
    connection_string = f"mysql+pymysql://{MYSQL_USER}:{MYSQL_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}"
    db_url = f"{connection_string}/{MYSQL_DATABASE}"
    if not database_exists(db_url):
        # Create the database through a server-level engine first
        init_db(create_engine(connection_string))
    # Return an engine bound to the application database so every new
    # connection starts with the right schema selected
    return create_engine(db_url)


# Create the engine
def simple_connect():
    pass


def init():
    db_conn = init_connect()
    use_db(db_conn)
    init_chat_table(db_conn)
    return db_conn


def init_db(db_conn):
    query_create_db = text(f"CREATE DATABASE {MYSQL_DATABASE}")  # create db
    with db_conn.connect() as conn:
        conn.execute(query_create_db)


def use_db(db_conn):
    # Note: USE only affects the single connection it runs on; the engine returned
    # by init_connect() is already bound to MYSQL_DATABASE.
    use_db_stmt = text(f"USE {MYSQL_DATABASE}")
    with db_conn.connect() as conn:
        conn.execute(use_db_stmt)


def init_chat_table(db_conn):
    create_chat_table = text(
        "CREATE TABLE IF NOT EXISTS messages ("
        "_id INT AUTO_INCREMENT, "
        "conversation_id CHAR(255), "
        "message TEXT, "
        "sender VARCHAR(50), "
        "timestamp TIMESTAMP, "
        "chat_context JSON, "
        "user_context TEXT, "
        "embedding BLOB NOT NULL, "
        "PRIMARY KEY (_id));"
    )
    with db_conn.connect() as conn:
        conn.execute(create_chat_table)


def query_wrapper(db_conn, query):
    with db_conn.connect() as conn:
        result = conn.execute(query)
        # Fetch while the connection is still open so callers get usable rows.
        return result.fetchall() if result.returns_rows else result
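
A short sketch of how the helpers above fit together, assuming the MYSQL_* environment variables are set; the SELECT is illustrative:

import db
from sqlalchemy import text

# init() creates the database and messages table if needed and returns an engine.
engine = db.init()

# Read back any stored chat messages through the shared query wrapper.
for row in db.query_wrapper(engine, text("SELECT _id, sender, message FROM messages")):
    print(row)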
@@ -0,0 +1,181 @@
# Create a FastAPI server and define the endpoints: /embedding, /chat as POST requests
# /embedding: takes a text input and returns the embedding from a remote api call using requests
# /chat: takes a text input and context and returns a response from a remote api call using requests
# Note: the remote api calls are defined in the config file

import os
import sys
import json
import pandas as pd
from typing import List, Optional, Dict
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langchain.prompts import (
    ChatPromptTemplate,
    MessagesPlaceholder,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.llms import SagemakerEndpoint
from langchain.chains import RetrievalQA
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from contextualize import Contextualizer
from langchain.chains import LLMChain
from langchain.memory import ConversationSummaryMemory
import db
from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory()

db_conn = db.init()

### System Prompt
system_prompt = """
You are a helpful customer service agent working for Kai Shoes. \n
You will be chatting with a customer. \n
Use context from their previous orders to help them make decisions.
"""

# Add the parent directory to the path to import the config file
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

SAGEMAKER_ENDPOINT = os.getenv("SAGEMAKER_ENDPOINT")
SAGEMAKER_ROLE = os.getenv("SAGEMAKER_ROLE")
SAGEMAKER_REGION = os.getenv("SAGEMAKER_REGION")
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
COMMITHASH = os.getenv("COMMITHASH")

print(COMMITHASH)

# Initialize the FastAPI server
app = FastAPI(
    title="LLM API",
    description="API for the LLM project",
    version="0.1.0",
    docs_url="/",
)

## FastAPI Routes
### /chat route

"""
Expected JSON:
{
    "text": "this is my message",
    "cust_id": "1234"
}
"""
@app.post("/chat")
async def chat(request: Request):
    """
    Takes a text input and customer context and returns a response from the LLM chain.
    """
    # Get the request body
    body = await request.json()
    print(body)
    # Get the text input
    question = body.get("text")

    # Get the context
    context = body.get("cust_id")

    # Get the response from the LLMChain
    response = llm_prompt_run(context, question)
    # Return the response
    return {"response": response}


@app.get("/test")
async def root():
    return {"message": "Hello World, I'm running on commit {}".format(COMMITHASH)}

# SageMaker Endpoint Handler
class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        # payload = {
        #     "inputs": [
        #         {
        #             "role": "system",
        #             "content": system_prompt,
        #         },
        #         {"role": "user", "content": prompt},
        #     ],
        #     "parameters": {"max_new_tokens": 1000, "top_p": 0.9, "temperature": 0.6},
        # }
        input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
        print(input_str)
        # The endpoint expects UTF-8 encoded bytes.
        return input_str.encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        response_json = json.loads(output.read().decode("utf-8"))
        content = response_json
        return content


content_handler = ContentHandler()

# # SageMaker Embeddings
# sagemaker_embeddings = SagemakerEndpointEmbeddings(
#     endpoint_name=SAGEMAKER_ENDPOINT,
#     region_name=SAGEMAKER_REGION,
#     content_handler=content_handler,
# )

# query_result = sagemaker_embeddings.embed_query("foo")

def llm_prompt_run(user_context, question):
    prompt = ChatPromptTemplate(
        messages=[
            SystemMessagePromptTemplate.from_template(
                "You are a friendly support rep at Kai Shoes. Use the following pieces of information to answer the user's question. If you don't know the answer, just say that you don't know; don't try to make up an answer."
            ),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanMessagePromptTemplate.from_template("{context}"),
            HumanMessagePromptTemplate.from_template("{question}"),
        ]
    )

    # SageMaker LLMChain
    llm = SagemakerEndpoint(
        endpoint_name=SAGEMAKER_ENDPOINT,
        region_name=SAGEMAKER_REGION,
        model_kwargs={"max_new_tokens": 700, "top_p": 0.9, "temperature": 0.6},
        endpoint_kwargs={"CustomAttributes": "accept_eula=true"},
        content_handler=content_handler,
    )

    chat_history = []
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

    chain = LLMChain(
        llm=llm,
        prompt=prompt,
        memory=memory,
    )

    llm_resp = chain.run({"context": user_context, "question": question, "chat_history": chat_history})

    print(llm_resp)
    return llm_resp
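
A minimal client sketch against the two routes above, assuming the server is running locally on port 4001 (the port published by the docker run command in the build script); the message and customer id values are illustrative:

import requests

BASE_URL = "http://localhost:4001"  # assumed local port from the run script

# Health check against the /test route.
print(requests.get(f"{BASE_URL}/test").json())

# Chat request using the JSON shape documented above the /chat route.
payload = {"text": "Which running shoes did I order last time?", "cust_id": "1234"}
resp = requests.post(f"{BASE_URL}/chat", json=payload)
print(resp.json()["response"])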
@@ -0,0 +1,13 @@
config==0.5.1
fastapi==0.103.2
loguru==0.7.2
pandas==1.5.3

Requests==2.31.0
langchain==0.0.313
PyMySQL==1.1.0
sqlalchemy
sqlalchemy_utils

uvicorn
boto3