api: support multiple input images
AlpinDale committed Dec 24, 2024
1 parent d4e78a4 commit 0f72141
Showing 6 changed files with 487 additions and 161 deletions.
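
With this change, a single chat completion request may carry several image_url content parts, bounded by the per-modality limit in the model's multimodal config (limit_per_prompt). Below is a minimal client-side sketch of what the API now accepts, assuming an OpenAI-compatible Aphrodite server is already running with a model that handles multi-image prompts; the base URL, model name, and image URLs are placeholders, not values taken from this commit:

from openai import OpenAI

# Placeholders: point these at your own deployment and model.
client = OpenAI(base_url="http://localhost:2242/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="my-vision-model",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text",
             "text": "What is different between these two images?"},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/a.jpg"}},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/b.jpg"}},
        ],
    }],
)
print(response.choices[0].message.content)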
275 changes: 158 additions & 117 deletions aphrodite/endpoints/chat_utils.py
@@ -1,12 +1,11 @@
import asyncio
import codecs
import tempfile
from dataclasses import dataclass
from collections import defaultdict
from functools import lru_cache
from pathlib import Path
from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple,
Union, cast)
from typing import (Any, Awaitable, Dict, Iterable, List, Literal, Mapping,
Optional, Tuple, Union)

import requests
from loguru import logger
# yapf conflicts with isort for this block
# yapf: disable
@@ -18,9 +17,8 @@
ChatCompletionMessageParam as OpenAIChatCompletionMessageParam)
# yapf: enable
# pydantic needs the TypedDict from typing_extensions
from pydantic import ConfigDict
from transformers import PreTrainedTokenizer
from typing_extensions import Required, TypedDict
from pydantic import ConfigDict, TypeAdapter
from typing_extensions import Required, TypeAlias, TypedDict

from aphrodite.common.config import ModelConfig
from aphrodite.multimodal import MultiModalDataDict
@@ -50,9 +48,9 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False):
"""The type of the content part."""


ChatCompletionContentPartParam = Union[OpenAIChatCompletionContentPartParam,
ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam]
ChatCompletionContentPartParam: TypeAlias = Union[
OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam,
CustomChatCompletionContentPartParam, ]


class CustomChatCompletionMessageParam(TypedDict, total=False):
@@ -65,6 +63,7 @@ class CustomChatCompletionMessageParam(TypedDict, total=False):

name: str
"""An optional name for the participant.
Provides the model information to differentiate between participants of the
same role.
"""
Expand All @@ -80,184 +79,227 @@ class ConversationMessage(TypedDict):
content: str


@dataclass(frozen=True)
class ChatMessageParseResult:
messages: List[ConversationMessage]
mm_futures: List[Awaitable[MultiModalDataDict]]
class MultiModalItemTracker:
"""
Tracks multi-modal items in a given request and ensures that the number
of items per modality does not exceed the configured maximum per prompt.
"""

def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer):
self._model_config = model_config
self._tokenizer = tokenizer
self._allowed_items = (model_config.multimodal_config.limit_per_prompt
if model_config.multimodal_config else {})
self._consumed_items = {k: 0 for k in self._allowed_items}
self._futures: List[Awaitable[MultiModalDataDict]] = []

@staticmethod
@lru_cache(maxsize=None)
def _cached_token_str(tokenizer: AnyTokenizer, token_index: int):
return tokenizer.decode(token_index)

def add(self, modality: Literal["image", "audio"],
mm_future: Awaitable[MultiModalDataDict]) -> Optional[str]:
"""
Adds the multi-modal item to the current prompt and returns the
placeholder string to use, if any.
"""
allowed_count = self._allowed_items.get(modality, 1)
current_count = self._consumed_items.get(modality, 0) + 1
if current_count > allowed_count:
raise ValueError(
f"At most {allowed_count} {modality}(s) may be provided in "
"one request.")

self._consumed_items[modality] = current_count
self._futures.append(mm_future)

# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = self._model_config.hf_config.model_type
if modality == "image":
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return f"<|image_{current_count}|>"
if model_type == "minicpmv":
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return MultiModalItemTracker._cached_token_str(
self._tokenizer,
self._model_config.hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"

raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")

@staticmethod
async def _combine(futures: List[Awaitable[MultiModalDataDict]]):
mm_lists: Mapping[str, List[object]] = defaultdict(list)

# Merge all the multi-modal items
for single_mm_data in (await asyncio.gather(*futures)):
for mm_key, mm_item in single_mm_data.items():
if isinstance(mm_item, list):
mm_lists[mm_key].extend(mm_item)
else:
mm_lists[mm_key].append(mm_item)

# Unpack any single item lists for models that don't expect multiple.
return {
mm_key: mm_list[0] if len(mm_list) == 1 else mm_list
for mm_key, mm_list in mm_lists.items()
}

def all_mm_data(self) -> Optional[Awaitable[MultiModalDataDict]]:
return MultiModalItemTracker._combine(
self._futures) if self._futures else None


def load_chat_template(
chat_template: Optional[Union[Path, str]]) -> Optional[str]:
if chat_template is None:
return None
try:
chat_template_str = str(chat_template)
if chat_template_str.startswith(('http')):
response = requests.get(chat_template_str)
temp = tempfile.NamedTemporaryFile(delete=False)
temp.write(response.content)
temp.close()
chat_template = temp.name

with open(chat_template, "r") as f:
resolved_chat_template = f.read()
except OSError as e:
if isinstance(chat_template, Path):
raise

JINJA_CHARS = "{}\n"
if not any(c in chat_template for c in JINJA_CHARS):
msg = (f"The supplied chat template ({chat_template}) "
"looks like a file path, but it failed to be "
f"looks like a file path, but it failed to be "
f"opened. Reason: {e}")
raise ValueError(msg) from e

# If opening the file fails, treat the supplied value as the template
# string itself and decode it so escape sequences are interpreted correctly
resolved_chat_template = codecs.decode(chat_template, "unicode_escape")

logger.info(f"Using supplied chat template:\n{resolved_chat_template}")
logger.info("Using supplied chat template:\n%s", resolved_chat_template)
return resolved_chat_template


@lru_cache(maxsize=None)
def _mm_token_str(model_config: ModelConfig, tokenizer: PreTrainedTokenizer,
modality: Literal["image", "audio"]) -> Optional[str]:
# TODO: Let user specify how to insert image tokens into prompt
# (similar to chat template)
model_type = model_config.hf_config.model_type
if modality == "image":
if model_type == "phi3_v":
# Workaround since this token is not defined in the tokenizer
return "<|image_1|>"
if model_type == "minicpmv":
return "(<image>./</image>)"
if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"):
# These models do not use image tokens in the prompt
return None
if model_type.startswith("llava"):
return tokenizer.decode(model_config.hf_config.image_token_index)
if model_type in ("chameleon", "internvl_chat"):
return "<image>"

raise TypeError(f"Unknown model type: {model_type}")
elif modality == "audio":
if model_type == "ultravox":
return "<|reserved_special_token_0|>"
raise TypeError(f"Unknown model type: {model_type}")
else:
raise TypeError(f"Unknown modality: {modality}")


# TODO: Let user specify how to insert multimodal tokens into prompt
# (similar to chat template)
def _get_full_multimodal_text_prompt(placeholder_token_str: str,
def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int],
text_prompt: str) -> str:
"""Combine multimodal prompts for a multimodal language model"""

# NOTE: For now we assume all model architectures use the same
# placeholder + text prompt format. This may change in the future.
return f"{placeholder_token_str}\n{text_prompt}"
# Look through the text prompt to check for missing placeholders
missing_placeholders = []
for placeholder in placeholder_counts:

# For any existing placeholder in the text prompt, we leave it as is
placeholder_counts[placeholder] -= text_prompt.count(placeholder)

if placeholder_counts[placeholder] < 0:
raise ValueError(
f"Found more '{placeholder}' placeholders in input prompt than "
"actual multimodal data items.")

missing_placeholders.extend([placeholder] *
placeholder_counts[placeholder])

# NOTE: For now we always add missing placeholders at the front of
# the prompt. This may change to be customizable in the future.
return "\n".join(missing_placeholders + [text_prompt])


_TextParser = TypeAdapter(ChatCompletionContentPartTextParam)
_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam)
_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam)


def _parse_chat_message_content_parts(
role: str,
parts: Iterable[ChatCompletionContentPartParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
mm_tracker: MultiModalItemTracker,
) -> List[ConversationMessage]:
texts: List[str] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
modality: Literal["image", "audio"] = "image"

# multimodal placeholder_string : count
mm_placeholder_counts: Dict[str, int] = {}

for part in parts:
part_type = part["type"]
if part_type == "text":
text = cast(ChatCompletionContentPartTextParam, part)["text"]
text = _TextParser.validate_python(part)["text"]
texts.append(text)
elif part_type == "image_url":
modality = "image"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")

image_url = cast(ChatCompletionContentPartImageParam,
part)["image_url"]
image_url = _ImageParser.validate_python(part)["image_url"]

if image_url.get("detail", "auto") != "auto":
logger.warning(
"'image_url.detail' is currently not supported and "
"will be ignored.")

image_future = async_get_and_parse_image(image_url["url"])
mm_futures.append(image_future)
image_coro = async_get_and_parse_image(image_url["url"])
placeholder = mm_tracker.add("image", image_coro)
if placeholder:
mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
placeholder, 0) + 1
elif part_type == "audio_url":
modality = "audio"
if len(mm_futures) > 0:
raise NotImplementedError(
"Multiple multimodal inputs is currently not supported.")

audio_url = cast(ChatCompletionContentPartAudioParam,
part)["audio_url"]
audio_future = async_get_and_parse_audio(audio_url["url"])
mm_futures.append(audio_future)
audio_url = _AudioParser.validate_python(part)["audio_url"]
audio_coro = async_get_and_parse_audio(audio_url["url"])
placeholder = mm_tracker.add("audio", audio_coro)
if placeholder:
mm_placeholder_counts[placeholder] = mm_placeholder_counts.get(
placeholder, 0) + 1
else:
raise NotImplementedError(f"Unknown part type: {part_type}")

text_prompt = "\n".join(texts)
if mm_placeholder_counts:
text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts,
text_prompt)

if mm_futures:
placeholder_token_str = _mm_token_str(model_config, tokenizer,
modality)
if placeholder_token_str is not None:
if placeholder_token_str in text_prompt:
logger.warning(
"Detected multi-modal token string in the text prompt. "
"Skipping prompt formatting.")
else:
text_prompt = _get_full_multimodal_text_prompt(
placeholder_token_str=placeholder_token_str,
text_prompt=text_prompt,
)

messages = [ConversationMessage(role=role, content=text_prompt)]

return ChatMessageParseResult(messages=messages, mm_futures=mm_futures)
return [ConversationMessage(role=role, content=text_prompt)]


def _parse_chat_message_content(
message: ChatCompletionMessageParam,
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
) -> ChatMessageParseResult:
message: ChatCompletionMessageParam,
mm_tracker: MultiModalItemTracker) -> List[ConversationMessage]:
role = message["role"]
content = message.get("content")

if content is None:
return ChatMessageParseResult(messages=[], mm_futures=[])
return []
if isinstance(content, str):
messages = [ConversationMessage(role=role, content=content)]
return ChatMessageParseResult(messages=messages, mm_futures=[])
return [ConversationMessage(role=role, content=content)]

return _parse_chat_message_content_parts(role, content, model_config,
tokenizer)
return _parse_chat_message_content_parts(
role,
content, # type: ignore
mm_tracker,
)


def parse_chat_messages(
messages: List[ChatCompletionMessageParam],
model_config: ModelConfig,
tokenizer: PreTrainedTokenizer,
) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]:
tokenizer: AnyTokenizer,
) -> Tuple[List[ConversationMessage], Optional[Awaitable[MultiModalDataDict]]]:
conversation: List[ConversationMessage] = []
mm_futures: List[Awaitable[MultiModalDataDict]] = []
mm_tracker = MultiModalItemTracker(model_config, tokenizer)

for msg in messages:
parse_result = _parse_chat_message_content(msg, model_config,
tokenizer)
sub_messages = _parse_chat_message_content(msg, mm_tracker)

conversation.extend(parse_result.messages)
mm_futures.extend(parse_result.mm_futures)
conversation.extend(sub_messages)

return conversation, mm_futures
return conversation, mm_tracker.all_mm_data()


def apply_chat_template(
Expand All @@ -280,5 +322,4 @@ def apply_chat_template(
tokenize=tokenize,
**kwargs,
)

return prompt
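
On the server side, parse_chat_messages now returns the conversation together with a single optional awaitable of the combined multi-modal data, rather than a list of per-item futures. A rough usage sketch, assuming model_config and tokenizer are already constructed by the serving layer (the helper name below is illustrative and not part of this diff):

from aphrodite.endpoints.chat_utils import parse_chat_messages

async def build_engine_inputs(messages, model_config, tokenizer):
    # Hypothetical helper: the arguments are assumed to come from the
    # OpenAI-compatible serving layer, as in the existing chat endpoint.
    conversation, mm_data_future = parse_chat_messages(messages, model_config,
                                                       tokenizer)
    # The tracker hands back one awaitable covering every image/audio item
    # (or None when the request had no multi-modal parts). Awaiting it fetches
    # all items and merges them per modality, keeping a single item unwrapped
    # and multiple items as a list.
    mm_data = await mm_data_future if mm_data_future is not None else None
    return conversation, mm_data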