Skip to content

Commit

Permalink
Support formatting image input in Conversation.getPromptArray
Browse files Browse the repository at this point in the history
- getPromptArray element can now be an array itself
- Conversation.messages message can either be a string or an array of content parts
- Update compareConversationObject
- Next step is to update llm_chat.ts getInputTokens
  • Loading branch information
CharlieFRuan committed Sep 20, 2024
1 parent fad3df9 commit 19c9991
Show file tree
Hide file tree
Showing 8 changed files with 648 additions and 176 deletions.
239 changes: 185 additions & 54 deletions src/conversation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import {
Role,
} from "./config";
import {
ChatCompletionContentPart,
ChatCompletionContentPartImage,
ChatCompletionMessageParam,
ChatCompletionRequest,
} from "./openai_api_protocols/index";
Expand All @@ -13,6 +15,7 @@ import {
FunctionNotFoundError,
InvalidToolChoiceError,
MessageOrderError,
MultipleTextContentError,
SystemMessageOrderError,
TextCompletionConversationError,
TextCompletionConversationExpectsPrompt,
Expand All @@ -21,12 +24,19 @@ import {
UnsupportedToolTypeError,
} from "./error";

type ImageURL = ChatCompletionContentPartImage.ImageURL;

/**
* Helper to keep track of history conversations.
*/
export class Conversation {
// NOTE: Update `compareConversationObject()` whenever a new state is introduced.
public messages: Array<[Role, string, string | undefined]> = [];
/** Each message is a tuple of (Role, role_name_str, message), where message can be either a
* string or an array of contentPart for possible image input.
*/
public messages: Array<
[Role, string, string | Array<ChatCompletionContentPart> | undefined]
> = [];
readonly config: ConvTemplateConfig;

/** Whether the Conversation object is for text completion with no conversation-style formatting */
Expand All @@ -46,13 +56,17 @@ export class Conversation {
this.isTextCompletion = isTextCompletion;
}

private getPromptArrayInternal(addSystem: boolean, startPos: number) {
// TODO: Consider rewriting this method, a bit messy.
private getPromptArrayInternal(
addSystem: boolean,
startPos: number,
): Array<string | Array<string | ImageURL>> {
if (this.config.seps.length == 0) {
throw Error("Need seps to work");
}

// Prepare system message
// Get overrided system message if exists, else use default one in config
// Get overridden system message if exists, else use default one in config
let system_message = this.config.system_message;
if (this.override_system_message !== undefined) {
system_message = this.override_system_message;
Expand All @@ -61,57 +75,113 @@ export class Conversation {
MessagePlaceholders.system,
system_message,
);
const ret = addSystem ? [system_prompt] : [];
const ret: Array<string | Array<string | ImageURL>> = addSystem
? [system_prompt]
: [];

// Process each message in this.messages
for (let i = startPos; i < this.messages.length; ++i) {
const item = this.messages[i];
const role = item[0];
const role_str = item[1];
const message = item[2];

if (message !== undefined && message != "") {
let message_str;
if (this.config.role_templates !== undefined) {
message_str = this.config.role_templates[role]?.replace(
MessagePlaceholders[Role[role] as keyof typeof MessagePlaceholders],
message,
const messageContent = item[2];

// 1. Message from `appendReplyHeader()`, message is empty; not much processing is needed.
if (messageContent === undefined) {
if (i !== this.messages.length - 1) {
throw new Error(
"InternalError: Only expect message to be undefined for last " +
"message for a reply header.",
);
if (this.use_function_calling && this.function_string !== "") {
message_str = message_str?.replace(
MessagePlaceholders.function,
this.function_string,
);
}
message_str = message_str?.replace(MessagePlaceholders.function, "");
}
const empty_sep = this.config.role_empty_sep
? this.config.role_empty_sep
: ": ";
ret.push(role_str + empty_sep);
continue;
}

if (message_str == undefined) {
message_str = message;
// 2. Each messageContent consists of one textPart and >= 0 imageParts, regardless of whether
// it is an Array<ChatCompletionContentPart> or a plain text message. So we extract each out.
let textContentPart = ""; // if no textPart, use an empty string
const imageContentParts: ImageURL[] = [];
if (Array.isArray(messageContent)) {
// 2.1 content is Array<ChatCompletionContentPart>
// Iterate through the contentParts, get the text and list of images. There should
// be only a single text. TODO: is it always the case that the number of textContentPart <= 1?
let seenText = false;
for (let i = 0; i < messageContent.length; i++) {
const curContentPart = messageContent[i];
if (curContentPart.type === "text") {
if (seenText) {
throw new MultipleTextContentError();
}
textContentPart = curContentPart.text;
seenText = true;
} else {
imageContentParts.push(curContentPart.image_url);
}
}
let role_prefix;
if (
this.config.add_role_after_system_message === false &&
system_prompt != "" &&
i == 0
) {
role_prefix = "";
} else {
const content_sep = this.config.role_content_sep
? this.config.role_content_sep
: ": ";
role_prefix = role_str + content_sep;
} else {
// 2.2 content is just a string
textContentPart = messageContent;
}

// 3. Format textContentPart with role and sep to get message_str and role_prefix
let message_str;
let role_prefix;
if (this.config.role_templates !== undefined) {
message_str = this.config.role_templates[role]?.replace(
MessagePlaceholders[Role[role] as keyof typeof MessagePlaceholders],
textContentPart,
);
if (this.use_function_calling && this.function_string !== "") {
message_str = message_str?.replace(
MessagePlaceholders.function,
this.function_string,
);
}
message_str = message_str?.replace(MessagePlaceholders.function, "");
}

if (message_str == undefined) {
message_str = textContentPart;
}
if (
this.config.add_role_after_system_message === false &&
system_prompt != "" &&
i == 0
) {
role_prefix = "";
} else {
const content_sep = this.config.role_content_sep
? this.config.role_content_sep
: ": ";
role_prefix = role_str + content_sep;
}

// 4. Combine everything together
if (imageContentParts.length === 0) {
// If no image, just a single string to represent this message
ret.push(
role_prefix +
message_str +
this.config.seps[i % this.config.seps.length],
);
} else {
const empty_sep = this.config.role_empty_sep
? this.config.role_empty_sep
: ": ";
ret.push(role_str + empty_sep);
// If has image input, currently we hard code it to Phi3.5-vision's format:
// `<|user|>\n<|image_1|>\n<|image_2|>\n{prompt}<|end|>\n`
// So we will return a list for this:
// [`<|user|>\n`, imageUrl1, `\n`, imageUrl2, `\n`, `{prompt}<|end|>\n`]
const curMessageList: Array<string | ImageURL> = [role_prefix];
imageContentParts.forEach((curImage: ImageURL) => {
curMessageList.push(curImage);
curMessageList.push("\n");
});
curMessageList.push(
message_str + this.config.seps[i % this.config.seps.length],
);
ret.push(curMessageList);
}
}
return ret;
Expand All @@ -120,9 +190,26 @@ export class Conversation {
/**
* Get prompt arrays with the first one as system.
*
* It is returned as an array of `string | Array<string | ImageURL>`, where each element of
* the array represents the formatted message of a role/turn. If the message only contains text,
* it will be a string that concatenates the role string, message, and separators. If the
* message contains image(s), it will be an array of strings and ImageURLs in the order in which
* they will be prefilled into the model. For example, it can be something like:
* [
* "<|system|>\nSome system prompt\n",
* [
* "<|user|>\n",
* imageURL1,
* "\n",
* imageURL2,
* "\n",
* "Some user input<|end|>\n"
* ],
* ]
*
* @returns The prompt array.
*/
getPromptArray(): Array<string> {
getPromptArray(): Array<string | Array<string | ImageURL>> {
if (this.isTextCompletion) {
throw new TextCompletionConversationError("getPromptArray");
}
Expand Down Expand Up @@ -181,7 +268,11 @@ export class Conversation {
return this.config.stop_token_ids;
}

appendMessage(role: Role, message: string, role_name?: string) {
appendMessage(
role: Role,
message: string | Array<ChatCompletionContentPart>,
role_name?: string,
) {
if (this.isTextCompletion) {
throw new TextCompletionConversationError("appendMessage");
}
Expand Down Expand Up @@ -263,12 +354,57 @@ export function compareConversationObject(
// both are empty
return true;
}
if (convA.messages.length !== convB.messages.length) {
// different number of messages
return false;
}

const msgLen = convA.messages.length;
const msgEntryLen = convA.messages[0].length;
const msgEntryLen = convA.messages[0].length; // always 3 for now
for (let i = 0; i < msgLen; i++) {
for (let j = 0; j < msgEntryLen; j++) {
if (convA.messages[i][j] !== convB.messages[i][j]) {
const entryA = convA.messages[i][j];
const entryB = convB.messages[i][j];
if (typeof entryA === "string" && typeof entryB === "string") {
// Case 1: both are strings
if (convA.messages[i][j] !== convB.messages[i][j]) {
return false;
}
} else if (entryA === undefined && entryB === undefined) {
// Case 2: both undefined
continue;
} else if (Array.isArray(entryA) && Array.isArray(entryB)) {
// Case 3: both are ChatCompletionContentPart[]
if (entryA.length !== entryB.length) {
return false;
}
const numContentParts = entryA.length;
for (let k = 0; k < numContentParts; k++) {
const entryA_k = entryA[k];
const entryB_k = entryB[k];
if (entryA_k.type === "text" && entryB_k.type === "text") {
// Case 3.1: both are text
if (entryA_k.text !== entryB_k.text) {
return false;
}
} else if (
entryA_k.type === "image_url" &&
entryB_k.type === "image_url"
) {
// Case 3.2: both are image_url
if (
entryA_k.image_url.url !== entryB_k.image_url.url ||
entryA_k.image_url.detail !== entryB_k.image_url.detail
) {
return false;
}
} else {
// Case 3.3: of different type
return false;
}
}
} else {
// Case 4: two entries are of different types
return false;
}
}
Expand All @@ -280,12 +416,14 @@ export function compareConversationObject(
* Get a new Conversation object based on the chat completion request.
*
* @param request The incoming ChatCompletionRequest
* @note `request.messages[-1]` is not included as it would be treated as a normal input to
* `prefill()`.
* @param includeLastMsg Include last message, by default is false. Set to true for testing only.
* @note By default, `request.messages[-1]` is not included as it would be treated as a normal
* input to `prefill()`.
*/
export function getConversationFromChatCompletionRequest(
request: ChatCompletionRequest,
config: ChatConfig,
includeLastMsg = false,
): Conversation {
// 0. Instantiate a new Conversation object
const conversation = getConversation(
Expand All @@ -304,27 +442,20 @@ export function getConversationFromChatCompletionRequest(
// 2. Populate conversation.messages
const input = request.messages;
const lastId = input.length - 1;
if (
(input[lastId].role !== "user" && input[lastId].role !== "tool") ||
typeof input[lastId].content !== "string"
) {
// TODO(Charlie): modify condition after we support multimodal inputs
if (input[lastId].role !== "user" && input[lastId].role !== "tool") {
throw new MessageOrderError(
"The last message should be a string from the `user` or `tool`.",
"The last message should be from the `user` or `tool`.",
);
}
for (let i = 0; i < input.length - 1; i++) {
const iterEnd = includeLastMsg ? input.length : input.length - 1;
for (let i = 0; i < iterEnd; i++) {
const message: ChatCompletionMessageParam = input[i];
if (message.role === "system") {
if (i !== 0) {
throw new SystemMessageOrderError();
}
conversation.override_system_message = message.content;
} else if (message.role === "user") {
if (typeof message.content !== "string") {
// TODO(Charlie): modify condition after we support multimodal inputs
throw new ContentTypeError(message.role + "'s message");
}
conversation.appendMessage(Role.user, message.content, message.name);
} else if (message.role === "assistant") {
if (typeof message.content !== "string") {
Expand Down
4 changes: 2 additions & 2 deletions src/error.ts
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,9 @@ export class UnsupportedDetailError extends Error {
}

export class MultipleTextContentError extends Error {
constructor(numTextContent: number) {
constructor() {
super(
`Each message can have at most one text contentPart, but received: ${numTextContent}`,
`Each message can have at most one text contentPart, but received more than 1.`,
);
this.name = "MultipleTextContentError";
}
Expand Down
Loading

0 comments on commit 19c9991

Please sign in to comment.