Skip to content

Commit

Permalink
Refactor AI Assistant to use OpenAI new Assistants API
Browse files Browse the repository at this point in the history
  • Loading branch information
kflim committed Nov 7, 2023
1 parent b83e6f6 commit 5ed5a12
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 107 deletions.
5 changes: 1 addition & 4 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,4 @@ S3_BUCKET_NAME=s3

# Github OAuth Provider
GITHUB_ID=EXAMPLE_GITHUB_ID
GITHUB_SECRET=EXAMPLE_GITHUB_SECRET

# OpenAI
OPENAI_API_KEY=EXAMPLE_OPENAI_API_KEY
GITHUB_SECRET=EXAMPLE_GITHUB_SECRET
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"next": "^13.4.13",
"next-auth": "^4.22.4",
"npm-run-all": "^4.1.5",
"openai": "^4.14.1",
"openai": "^4.16.1",
"react": "18.2.0",
"react-codemirror-merge": "^4.21.15",
"react-dom": "18.2.0",
Expand Down
10 changes: 5 additions & 5 deletions prisma/postgres/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ model User {
matchRequest MatchRequest?
joinRequest JoinRequest[]
sessionUserAndUserMessages SessionUserAndUserMessage[]
sessionUserAndAIMessages SessionUserAndAIMessage[]
sessionAIThreads SessionAIThread[]
Submission Submission[]
}

Expand Down Expand Up @@ -195,15 +195,15 @@ model SessionUserAndUserMessage {
@@index([sessionId])
}

model SessionUserAndAIMessage {
id String @id @default(cuid())
model SessionAIThread {
id String @id @default(cuid())
sessionId String
userId String
message String
role String
threadId String
createdAt DateTime @default(now())
user User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@unique([sessionId, userId, threadId])
@@index([sessionId])
}

Expand Down
6 changes: 4 additions & 2 deletions src/components/AIBox.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ const AIBox = ({
} text-white p-2 my-2`}
>
<div className="flex justify-between">
<span>{message.role === "user" ? userName : "GPT-3.5"}</span>
<span>
{message.role === "user" ? userName : "Code Assistant"}
</span>
</div>
<p>{message.message}</p>
</div>
Expand All @@ -54,7 +56,7 @@ const AIBox = ({
{isAIResponding && (
<input
className="w-full rounded-md p-2"
value="GPT-3.5 is responding..."
value="Code Assistant is responding..."
disabled
/>
)}
Expand Down
7 changes: 4 additions & 3 deletions src/env.mjs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ export const env = createEnv({
// Add `.min(1) on ID and SECRET if you want to make sure they're not empty
GITHUB_ID: z.string().min(1),
GITHUB_SECRET: z.string().min(1),
OPENAI_API_KEY: z.string().min(1),
},

/**
Expand All @@ -55,12 +54,14 @@ export const env = createEnv({
MONGO_URL: process.env.MONGO_URL,
NODE_ENV: process.env.NODE_ENV,
NEXTAUTH_SECRET: process.env.NEXTAUTH_SECRET,
NEXT_PUBLIC_WS_PORT: process.env.NODE_ENV === "production" ? process.env.NEXT_PUBLIC_WS_PORT : "3002",
NEXT_PUBLIC_WS_PORT:
process.env.NODE_ENV === "production"
? process.env.NEXT_PUBLIC_WS_PORT
: "3002",
NEXTAUTH_URL: process.env.NEXTAUTH_URL,
S3_BUCKET_NAME: process.env.S3_BUCKET_NAME,
GITHUB_ID: process.env.GITHUB_ID,
GITHUB_SECRET: process.env.GITHUB_SECRET,
OPENAI_API_KEY: process.env.OPENAI_API_KEY,
},
/**
* Run `build` or `dev` with `SKIP_ENV_VALIDATION` to skip env validation.
Expand Down
7 changes: 1 addition & 6 deletions src/hooks/useAIComm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,8 @@ export default function useAIComm(
if (allSessionMessages)
allSessionMessages.push({
...data,
id: data.id!,
message: data.message,
createdAt: data.createdAt!,
role: data.role,
});

setChatState((state) => ({
Expand All @@ -73,12 +72,8 @@ export default function useAIComm(
if (chatState.currentMessage.trim().length === 0) return;

allSessionMessages?.push({
id: "",
sessionId,
userId,
message: chatState.currentMessage,
role: "user",
createdAt: new Date(),
});

addMessageMutation.mutate({
Expand Down
7 changes: 5 additions & 2 deletions src/pages/collab/rooms/[id].tsx
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,10 @@ const Toolbar = ({
</select>
</label>
<label className="flex flex-row col-span-2">
<QuestionToggleModal questionTitleList={modifyQuestionProps.questionTitleList} setQuestionId={modifyQuestionProps.setQuestionId} />
<QuestionToggleModal
questionTitleList={modifyQuestionProps.questionTitleList}
setQuestionId={modifyQuestionProps.setQuestionId}
/>
</label>
<div className="flex flex-row col-span-2">
<label>
Expand Down Expand Up @@ -255,7 +258,7 @@ const Room = () => {
<TabList>
<Tab>Output</Tab>
<Tab>Chat</Tab>
<Tab>GPT-3.5</Tab>
<Tab>Code Assistant</Tab>
{useQuestionObject.submissionStatus && <Tab>Submission</Tab>}
</TabList>
<TabPanel>
Expand Down
167 changes: 104 additions & 63 deletions src/server/api/routers/userAndAIComm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,32 @@ import { EventEmitter } from "events";
import { z } from "zod";
import { createTRPCRouter, protectedProcedure } from "~/server/api/trpc";
import OpenAI from "openai";
import type { ChatCompletionRole } from "openai/resources";
import type { MessageContentText } from "openai/resources/beta/threads/messages/messages";
import { TRPCError } from "@trpc/server";

const ee = new EventEmitter();
const openai = new OpenAI({
apiKey: process.env.OPENAI_API_KEY,
});

const OPENAI_ASSISTANT_ID = "asst_1o7xscIo6R2jPBxVU5Boqmf2";
const openai = new OpenAI();

// Initialization for assistant
/*
async function main() {
const assistant = await openai.beta.assistants.create({
instructions: "You are an assistant that helps to explain code.",
name: "Code Assistant",
tools: [{ type: "code_interpreter" }],
model: "gpt-3.5-turbo-1106",
});
console.log(assistant);
} */

export type UserAndAIMessage = {
id?: string;
sessionId: string;
userId: string;
message: string;
role: ChatCompletionRole;
createdAt?: Date;
role: "user" | "assistant";
};

export const userAndAIMessagesRouter = createTRPCRouter({
Expand All @@ -35,18 +46,38 @@ export const userAndAIMessagesRouter = createTRPCRouter({
.query(async ({ ctx, input }) => {
const { sessionId, userId } = input;

const messages =
await ctx.prismaPostgres.sessionUserAndAIMessage.findMany({
where: {
let sessionThread = await ctx.prismaPostgres.sessionAIThread.findFirst({
where: {
sessionId,
userId,
},
});

if (!sessionThread) {
const newThread = await openai.beta.threads.create({});
sessionThread = await ctx.prismaPostgres.sessionAIThread.create({
data: {
sessionId,
userId,
},
orderBy: {
createdAt: "asc",
threadId: newThread.id,
},
});
}

const messages = await openai.beta.threads.messages.list(
sessionThread.threadId,
);

return messages;
return messages.data
.map((message) => {
const role = message.role;
const text = (message.content[0] as MessageContentText).text.value;
return {
message: text,
role,
};
})
.reverse(); // Messages from OpenAI Assistant are in reverse chronological order
}),

addUserAndAIMessage: protectedProcedure
Expand All @@ -60,63 +91,73 @@ export const userAndAIMessagesRouter = createTRPCRouter({
.mutation(async ({ ctx, input }) => {
const { sessionId, userId, message } = input;

const messageObject =
await ctx.prismaPostgres.sessionUserAndAIMessage.create({
data: {
sessionId,
userId,
message,
role: "user",
},
});
const sessionThread = await ctx.prismaPostgres.sessionAIThread.findFirst({
where: {
sessionId,
userId,
},
});

const currentSessionMessages =
await ctx.prismaPostgres.sessionUserAndAIMessage.findMany({
where: {
sessionId,
userId,
},
orderBy: {
createdAt: "asc",
},
});
// Add Message to the session Thread
await openai.beta.threads.messages.create(sessionThread?.threadId ?? "", {
role: "user",
content: message,
});

const response = await openai.chat.completions
.create({
messages: currentSessionMessages.map((message) => {
return {
role: message.role as ChatCompletionRole,
content: message.message,
};
}),
model: "gpt-3.5-turbo",
})
.catch((errorJsonObj) => {
// Retrieve the assistant
const openaiAssistant =
await openai.beta.assistants.retrieve(OPENAI_ASSISTANT_ID);

// Create the run for the assistant
const run = await openai.beta.threads.runs.create(
sessionThread?.threadId ?? "",
{
assistant_id: openaiAssistant.id,
instructions: "Please answer clearly and concisely",
},
);

// Wait for the run to complete
while (true) {
const response = await openai.beta.threads.runs.retrieve(
sessionThread?.threadId ?? "",
run.id,
);

if (response.status === "completed") break;
else if (response.status === "failed")
throw new TRPCError({
code: "TOO_MANY_REQUESTS",
message: errorJsonObj.error.message,
message: response.last_error?.message,
});
else if (response.status === "expired")
throw new TRPCError({
code: "TIMEOUT",
message: "Request timed out",
});
});

if (response) {
const aiMessage = response.choices[0]?.message;

if (aiMessage) {
const aiMessageObject =
await ctx.prismaPostgres.sessionUserAndAIMessage.create({
data: {
sessionId,
userId,
message: aiMessage.content!,
role: aiMessage.role,
},
});

ee.emit("aiMessage", aiMessageObject);
}
}

return messageObject;
const messages = await openai.beta.threads.messages.list(
sessionThread?.threadId ?? "",
);

const aiResponse = messages.data[0];
const id = aiResponse?.id;
const role = aiResponse?.role;
const text = (aiResponse?.content[0] as MessageContentText).text.value;

ee.emit("aiMessage", {
id,
sessionId,
userId,
message: text,
role,
});

return {
message: text,
role,
};
}),

subscribeToSessionUserAndAIMessages: protectedProcedure
Expand Down
Loading

0 comments on commit 5ed5a12

Please sign in to comment.