From b7348446990ace4b86a3440c45e27832238c05ab Mon Sep 17 00:00:00 2001 From: Wes Kennedy Date: Tue, 17 Oct 2023 08:49:49 -0700 Subject: [PATCH] lets try this --- services/api/.gitignore | 1 + services/api/Dockerfile | 21 ++++ services/api/README.md | 2 + services/api/build_locally.sh | 7 ++ services/api/contextualize.py | 41 ++++++++ services/api/db.py | 54 ++++++++++ services/api/main.py | 181 ++++++++++++++++++++++++++++++++++ services/api/requirements.txt | 13 +++ 8 files changed, 320 insertions(+) create mode 100644 services/api/.gitignore create mode 100644 services/api/Dockerfile create mode 100644 services/api/README.md create mode 100755 services/api/build_locally.sh create mode 100644 services/api/contextualize.py create mode 100644 services/api/db.py create mode 100644 services/api/main.py create mode 100644 services/api/requirements.txt diff --git a/services/api/.gitignore b/services/api/.gitignore new file mode 100644 index 0000000..2eea525 --- /dev/null +++ b/services/api/.gitignore @@ -0,0 +1 @@ +.env \ No newline at end of file diff --git a/services/api/Dockerfile b/services/api/Dockerfile new file mode 100644 index 0000000..417afd5 --- /dev/null +++ b/services/api/Dockerfile @@ -0,0 +1,21 @@ +# a dockerfile for a fastapi server + +# Use the official Python image. +# https://hub.docker.com/_/python +FROM python:3.8-slim + +# Copy local code to the container image. +ENV APP_HOME /app +WORKDIR $APP_HOME +ARG COMMITHASH +ENV COMMITHASH=$COMMITHASH +COPY requirements.txt ./ + +# Install production dependencies. +RUN pip install --upgrade pip +RUN pip install -r requirements.txt + +COPY . ./ + +# Run the fastapi service +CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "4000"] \ No newline at end of file diff --git a/services/api/README.md b/services/api/README.md new file mode 100644 index 0000000..2afb634 --- /dev/null +++ b/services/api/README.md @@ -0,0 +1,2 @@ +# Private LLM API + diff --git a/services/api/build_locally.sh b/services/api/build_locally.sh new file mode 100755 index 0000000..3a7c6dc --- /dev/null +++ b/services/api/build_locally.sh @@ -0,0 +1,7 @@ +#!/bin/sh +source ./.env +REPOSITORY=private-llm-fastapi-server +IMAGE_TAG=latest +docker build --build-arg="COMMITHASH=localtest" -t $REPOSITORY:$IMAGE_TAG . + +docker run --rm -p 4001:4000 --env-file ./.env $REPOSITORY:$IMAGE_TAG \ No newline at end of file diff --git a/services/api/contextualize.py b/services/api/contextualize.py new file mode 100644 index 0000000..a69de4b --- /dev/null +++ b/services/api/contextualize.py @@ -0,0 +1,41 @@ +""" +contextualize.py + +Maintainer: Wes Kennedy +Description: The contextualize module allows us to write app specific queries to help build our database. +""" + +import db +from sqlalchemy import * + +class Contextualizer(): + def __init__(): + pass + + def customer_lookup_byid(customer_id): + """ + Takes a customer id and returns the customer's name + """ + pass + + def customer_lookup_byname(customer_name): + """ + Takes a customer name and returns the customer's id + """ + pass + + def customer_lookup_byemail(customer_email): + """ + Takes a customer email and returns the customer's id + """ + pass + + def customer_previous_orders(db_conn, customer_id): + """ + Takes a customer id and returns a list of previous orders + """ + orders = [] + query = text(f"SELECT * FROM orders WHERE customer_id = {customer_id}") + response = db.query_wrapper(db_conn, query) + + return orders \ No newline at end of file diff --git a/services/api/db.py b/services/api/db.py new file mode 100644 index 0000000..9eb405f --- /dev/null +++ b/services/api/db.py @@ -0,0 +1,54 @@ +import os +# add the ability to query a mysql server using sql_alchemy +from sqlalchemy import * +from sqlalchemy_utils import database_exists + +MYSQL_HOST = os.getenv("MYSQL_HOST") +MYSQL_PORT = os.getenv("MYSQL_PORT") +MYSQL_USER = os.getenv("MYSQL_USER") +MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD") +MYSQL_DATABASE = os.getenv("MYSQL_DATABASE") + + + +def init_connect(): + # Create the connection string + connection_string = f"mysql+pymysql://{MYSQL_USER}:{MYSQL_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}" + full_engine = f"{connection_string}/{MYSQL_DATABASE}" + engine = create_engine(connection_string) + if database_exists(full_engine): + return engine + else: + with engine.connect() as conn: + init_db(conn) + return engine + + # Create the engine +def simple_connect(): + pass + +def init(): + db_conn = init_connect() + use_db(db_conn) + init_chat_table(db_conn) + return db_conn + +def init_db(db_conn): + query_create_db = text(f"CREATE DATABASE {MYSQL_DATABASE}") #create db + with db_conn.connect() as conn: + conn.execute(query_create_db) + +def use_db(db_conn): + use_db = text(f"USE {MYSQL_DATABASE}") + with db_conn.connect() as conn: + conn.execute(use_db) + +def init_chat_table(db_conn): + create_chat_table = text(f"CREATE TABLE IF NOT EXISTS messages (_id INT AUTO_INCREMENT,conversation_id CHAR(255),message TEXT,sender VARCHAR(50),timestamp TIMESTAMP,chat_context JSON,user_context TEXT,embedding BLOB NOT NULL, PRIMARY KEY (_id));") + with db_conn.connect() as conn: + conn.execute(create_chat_table) + +def query_wrapper(db_conn, query): + with db_conn.connect() as conn: + result = conn.execute(query) + return result diff --git a/services/api/main.py b/services/api/main.py new file mode 100644 index 0000000..01498df --- /dev/null +++ b/services/api/main.py @@ -0,0 +1,181 @@ +# Create a FastAPI server and define the endpoints: /embedding, /chat as POST requests +# /embedding: takes a text input and returns the embedding from a remote api call using requests +# /chat: takes a text input and context and returns a response from a remote api call using requests +# Note: the remote api calls are defined in the config file + +import os +import sys +import json +import pandas as pd +from typing import List, Optional, Dict +from fastapi import FastAPI, HTTPException, Request, Response +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel +from langchain.prompts import ( + ChatPromptTemplate, + MessagesPlaceholder, + SystemMessagePromptTemplate, + HumanMessagePromptTemplate, +) +from langchain.embeddings import SagemakerEndpointEmbeddings +from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler +from langchain.llms import SagemakerEndpoint +from langchain.chains import RetrievalQA +from langchain.llms.sagemaker_endpoint import LLMContentHandler +from contextualize import Contextualizer +from langchain.chains import LLMChain +from langchain.memory import ConversationSummaryMemory +import db +from langchain.memory import ConversationBufferMemory + +memory = ConversationBufferMemory() + +db_conn = db.init() + +### System Prompt +system_prompt = """ + You are a helpful customer service agent working for Kai Shoes. \n + You will be chatting with a customer. \n + Use context from their previous orders to help them make decisions. + """ + +# Add the parent directory to the path to import the config file +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +SAGEMAKER_ENDPOINT = os.getenv("SAGEMAKER_ENDPOINT") +SAGEMAKER_ROLE = os.getenv("SAGEMAKER_ROLE") +SAGEMAKER_REGION = os.getenv("SAGEMAKER_REGION") +AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID") +AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY") +COMMITHASH = os.getenv("COMMITHASH") + +print(COMMITHASH) + +print(AWS_ACCESS_KEY_ID) +print(AWS_SECRET_ACCESS_KEY) + + +# Initialize the FastAPI server +app = FastAPI( + title="LLM API", + description="API for the LLM project", + version="0.1.0", + docs_url="/", +) + +## FastAPI Routes +### /chat route + +""" +Expected JSON: + +{ + "text": "this is my message", + "cust_id": "1234" +} +""" +@app.post("/chat") +async def chat(request: Request): + """ + Takes a text input and context and returns a response from a remote api call using requests + """ + # Get the request body + body = await request.json() + # Get the text input + print(body) + question = body.get("text") + + # Get the context + context = body.get("cust_id") + + + # Get the response from the LLMChain + response = llm_prompt_run(context, question) + # Return the response + return {"response": response} + +@app.get("/test") +async def root(): + return {"message": "Hello World, I'm runnin on commit {}".format(COMMITHASH)} + + + +# SageMaker Endpoint Handler +class ContentHandler(LLMContentHandler): + content_type = "application/json" + accepts = "application/json" + + def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: + # payload = { + # "inputs": [ + # { + # "role": "system", + # "content": system_prompt, + # }, + # {"role": "user", "content": prompt}, + + # ], + # "parameters": {"max_new_tokens": 1000, "top_p": 0.9, "temperature": 0.6}, + # } + input_str = ''.join(prompt) + input_str = json.dumps({"inputs": input_str, "parameters": model_kwargs}) + print(input_str) + # input_str = json.dumps( + # payload, + # ) + input_utf = input_str + print(input_utf) + return input_utf + + def transform_output(self, output: bytes) -> str: + response_json = json.loads(output.read().decode("utf-8")) + content = response_json + return content + +content_handler = ContentHandler() + +# # SageMaker Embeddings +# sagemaker_embeddings = SagemakerEndpointEmbeddings( +# endpoint_name=SAGEMAKER_ENDPOINT, +# region_name=SAGEMAKER_REGION, +# content_handler=content_handler, +# ) + +# query_result = sagemaker_embeddings.embed_query("foo") + + +def llm_prompt_run(user_context, question): + + prompt = ChatPromptTemplate( + messages=[ + SystemMessagePromptTemplate.from_template( + "You are a friendly support rep at Kai Shoes. Use the following pieces of information to answer the user's question. If you don't know the answer, just say that you don't know, don't try to make up an answer." + ), + MessagesPlaceholder(variable_name="chat_history"), + HumanMessagePromptTemplate.from_template("{context}"), + HumanMessagePromptTemplate.from_template("{question}") + ] + ) + + # SageMaker LLMChain + llm = SagemakerEndpoint( + endpoint_name=SAGEMAKER_ENDPOINT, + region_name="us-west-2", + model_kwargs={"max_new_tokens": 700, "top_p": 0.9, "temperature": 0.6}, + endpoint_kwargs={"CustomAttributes": 'accept_eula=true'}, + content_handler=content_handler, + ) + + chat_history = [] + memory = ConversationBufferMemory(memory_key="chat_history",return_messages=True) + + chain = LLMChain(llm=llm, + prompt=prompt, + memory=memory, + ) + + + llm_resp = chain.run({'context': user_context, 'question': question, 'chat_history': chat_history}) + + print(llm_resp) + return llm_resp \ No newline at end of file diff --git a/services/api/requirements.txt b/services/api/requirements.txt new file mode 100644 index 0000000..f98b00a --- /dev/null +++ b/services/api/requirements.txt @@ -0,0 +1,13 @@ +config==0.5.1 +fastapi==0.103.2 +loguru==0.7.2 +pandas==1.5.3 + +Requests==2.31.0 +langchain==0.0.313 +PyMySQL==1.1.0 +sqlalchemy +sqlalchemy_utils + +uvicorn +boto3 \ No newline at end of file