Commit

Squash 4645

prashantgupta24 committed Jul 1, 2024
1 parent c544ecf commit 0558bcc
Showing 43 changed files with 1,840 additions and 522 deletions.
1 change: 1 addition & 0 deletions format.sh
@@ -111,6 +111,7 @@ mypy vllm/spec_decode --config-file pyproject.toml
 mypy vllm/model_executor --config-file pyproject.toml
 mypy vllm/lora --config-file pyproject.toml
 mypy vllm/logging --config-file pyproject.toml
+mypy vllm/prompt_adapter --config-file pyproject.toml
 mypy tests --config-file pyproject.toml


9 changes: 4 additions & 5 deletions tests/lora/test_long_context.py
@@ -92,11 +92,10 @@ def batched_generate(
     for input in inputs:
         prompt, sampling_param, lora_req = input
         # Add requests to the engine and run the engine
-        llm._validate_and_add_requests(
-            prompt,
-            sampling_param,
-            lora_request=lora_req,
-        )
+        llm._validate_and_add_requests(prompt,
+                                       sampling_param,
+                                       lora_request=lora_req,
+                                       prompt_adapter_request=None)
 
     outputs = llm._run_engine(use_tqdm=True)
     return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))]
326 changes: 164 additions & 162 deletions tests/lora/test_lora_manager.py

Large diffs are not rendered by default.

39 changes: 39 additions & 0 deletions tests/prompt_adapter/test_bloom.py
@@ -0,0 +1,39 @@
import vllm
from vllm.prompt_adapter.request import PromptAdapterRequest

MODEL_PATH = "bigscience/bloomz-560m"
PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'


def do_sample(llm, pa_name: str, pa_id: int):

    prompts = [
        "Tweet text : @nationalgridus I have no water and the bill is \
        current and paid. Can you do something about this? Label : ",
        "Tweet text : @nationalgridus Looks good thanks! Label : "
    ]
    sampling_params = vllm.SamplingParams(temperature=0.0,
                                          max_tokens=3,
                                          stop_token_ids=[3])

    outputs = llm.generate(prompts,
                           sampling_params,
                           prompt_adapter_request=PromptAdapterRequest(
                               pa_name, pa_id, PA_PATH, 8) if pa_id else None)

    # Print the outputs.
    generated_texts = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    return generated_texts


def test_twitter_prompt_adapter():
    llm = vllm.LLM(MODEL_PATH, enable_prompt_adapter=True)

    expected_output = ['complaint', 'no complaint']

    assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output
52 changes: 52 additions & 0 deletions tests/prompt_adapter/test_multi_adapter_inference.py
@@ -0,0 +1,52 @@
from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest

MODEL_PATH = "bigscience/bloomz-560m"
pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM'
pa_path2 = 'swapnilbp/angry_tweet_ptune'


def do_sample(engine):

    prompts = [
        ("Tweet text: I have complaints! Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("hate_speech", 1, pa_path2, 8)),
        ("Tweet text: I have no problems Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)),
        ("Tweet text: I have complaints! Label: ",
         SamplingParams(temperature=0.0, max_tokens=3), None),
        ("Tweet text: I have no problems Label: ",
         SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]),
         PromptAdapterRequest("complain", 3, pa_path, 8)),
    ]

    request_id = 0
    results = set()
    while prompts or engine.has_unfinished_requests():
        if prompts:
            prompt, sampling_params, pa_request = prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               prompt_adapter_request=pa_request)
            request_id += 1

        request_outputs = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                results.add(request_output.outputs[0].text)
    return results


def test_multi_prompt_adapters():
    engine_args = EngineArgs(model=MODEL_PATH,
                             max_prompt_adapters=3,
                             enable_prompt_adapter=True)
    engine = LLMEngine.from_engine_args(engine_args)
    expected_output = {
        ' quot;I', 'hate speech', 'no complaint', 'not hate speech'
    }
    assert do_sample(engine) == expected_output
60 changes: 60 additions & 0 deletions tests/prompt_adapter/test_pa_lora.py
@@ -0,0 +1,60 @@
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest

MODEL_PATH = "meta-llama/Llama-2-7b-hf"
pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune")
lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")


def do_sample(engine):

    prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501

    # first prompt with a prompt adapter and second without adapter
    prompts = [
        (prompt_text,
         SamplingParams(temperature=0.0, max_tokens=100,
                        stop=["[/assistant]"]),
         PromptAdapterRequest("hate_speech", 1, pa_path,
                              8), LoRARequest("sql_test", 1, lora_path)),
        (prompt_text,
         SamplingParams(temperature=0.0, max_tokens=100,
                        stop=["[/assistant]"]), None,
         LoRARequest("sql_test", 1, lora_path)),
    ]

    request_id = 0
    results = set()
    while prompts or engine.has_unfinished_requests():
        if prompts:
            prompt, sampling_params, pa_request, lora_request = prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               prompt_adapter_request=pa_request,
                               lora_request=lora_request)
            request_id += 1

        request_outputs = engine.step()

        for request_output in request_outputs:
            if request_output.finished:
                results.add(request_output.outputs[0].text)
    return results


def test_lora_prompt_adapter():
    engine_args = EngineArgs(model=MODEL_PATH,
                             enable_prompt_adapter=True,
                             enable_lora=True,
                             max_num_seqs=60)
    engine = LLMEngine.from_engine_args(engine_args)
    result = do_sample(engine)

    expected_output = {
        " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' "  # noqa: E501
    }
    assert result == expected_output
2 changes: 2 additions & 0 deletions tests/spec_decode/e2e/conftest.py
@@ -13,6 +13,7 @@
 from vllm.model_executor.utils import set_random_seed
 from vllm.multimodal import MultiModalData
 from vllm.outputs import RequestOutput
+from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import Logprob
 from vllm.usage.usage_lib import UsageContext
@@ -92,6 +93,7 @@ def generate(
         use_tqdm: bool = True,
         lora_request: Optional[LoRARequest] = None,
         multi_modal_data: Optional[MultiModalData] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> List[RequestOutput]:
 
         if prompts is None:
1 change: 1 addition & 0 deletions tests/worker/test_model_runner.py
@@ -23,6 +23,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
         cache_config=engine_config.cache_config,
         load_config=engine_config.load_config,
         lora_config=engine_config.lora_config,
+        prompt_adapter_config=engine_config.prompt_adapter_config,
         is_driver_worker=True,
     )
     return model_runner
Empty file.
14 changes: 14 additions & 0 deletions vllm/adapter_commons/layers.py
@@ -0,0 +1,14 @@
from dataclasses import dataclass
from typing import Tuple


@dataclass
class AdapterMapping:
    # Per every token in input_ids:
    index_mapping: Tuple[int, ...]
    # Per sampled token:
    prompt_mapping: Tuple[int, ...]

    def __post_init__(self):
        self.index_mapping = tuple(self.index_mapping)
        self.prompt_mapping = tuple(self.prompt_mapping)
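
For illustration, a minimal sketch of how an AdapterMapping could be built for a batch; the token counts and adapter ids below are hypothetical, not taken from this commit:

from vllm.adapter_commons.layers import AdapterMapping

# Hypothetical batch: the request using adapter 1 contributes three input
# tokens and one sampled token; the request using adapter 2 contributes two
# input tokens and one sampled token.
mapping = AdapterMapping(
    index_mapping=[1, 1, 1, 2, 2],  # one entry per token in input_ids
    prompt_mapping=[1, 2],          # one entry per sampled token
)
assert mapping.index_mapping == (1, 1, 1, 2, 2)  # lists are coerced to tuples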
104 changes: 104 additions & 0 deletions vllm/adapter_commons/models.py
@@ -0,0 +1,104 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Hashable, Optional, TypeVar

from torch import nn

from vllm.logger import init_logger
from vllm.utils import LRUCache

logger = init_logger(__name__)


class AdapterModel(ABC):

    def __init__(self, model_id=None):
        self.id = model_id

    @abstractmethod
    def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs):
        # Common initialization code
        # Load weights or embeddings from local checkpoint
        raise NotImplementedError("Subclasses must implement this method.")


T = TypeVar('T')


class AdapterLRUCache(LRUCache[T]):

    def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable],
                                                               None]):
        super().__init__(capacity)
        self.deactivate_fn = deactivate_fn

    def _on_remove(self, key: Hashable, value: T):
        logger.debug("Removing adapter int id: %d", key)
        self.deactivate_fn(key)
        return super()._on_remove(key, value)


class AdapterModelManager(ABC):

    def __init__(
        self,
        model: nn.Module,
    ):
        """Create an AdapterModelManager and adapter for a given model.

        Args:
            model: the model to be adapted.
        """
        self.model: nn.Module = model
        self._registered_adapters: Dict[int, Any] = {}
        # Dict instead of a Set for compatibility with LRUCache.
        self._active_adapters: Dict[int, None] = {}
        self.adapter_type = 'Adapter'
        self._last_mapping = None

    def __len__(self) -> int:
        return len(self._registered_adapters)

    @property
    @abstractmethod
    def adapter_slots(self):
        ...

    @property
    @abstractmethod
    def capacity(self):
        ...

    @abstractmethod
    def activate_adapter(self, adapter_id: int) -> bool:
        ...

    @abstractmethod
    def deactivate_adapter(self, adapter_id: int) -> bool:
        ...

    @abstractmethod
    def add_adapter(self, adapter: Any) -> bool:
        ...

    @abstractmethod
    def set_adapter_mapping(self, mapping: Any) -> None:
        ...

    @abstractmethod
    def remove_adapter(self, adapter_id: int) -> bool:
        ...

    @abstractmethod
    def remove_all_adapters(self):
        ...

    @abstractmethod
    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        ...

    @abstractmethod
    def list_adapters(self) -> Dict[int, Any]:
        ...

    @abstractmethod
    def pin_adapter(self, adapter_id: int) -> bool:
        ...
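
To make the contract above concrete, here is a rough, hypothetical sketch of a subclass that fills in every abstract hook with trivial bookkeeping; it is illustrative only and is not the prompt-adapter manager added elsewhere in this commit:

from typing import Any, Dict, Optional

from torch import nn

from vllm.adapter_commons.models import AdapterModelManager


class ToyAdapterManager(AdapterModelManager):
    """Hypothetical manager: registers/activates adapters, applies nothing."""

    def __init__(self, model: nn.Module, slots: int = 4):
        super().__init__(model)
        self._slots = slots

    @property
    def adapter_slots(self) -> int:
        return self._slots

    @property
    def capacity(self) -> int:
        return self._slots

    def activate_adapter(self, adapter_id: int) -> bool:
        if adapter_id not in self._registered_adapters:
            return False
        self._active_adapters[adapter_id] = None
        return True

    def deactivate_adapter(self, adapter_id: int) -> bool:
        if adapter_id in self._active_adapters:
            del self._active_adapters[adapter_id]
            return True
        return False

    def add_adapter(self, adapter: Any) -> bool:
        # Assumes the adapter exposes an integer `.id`, as AdapterModel does.
        self._registered_adapters[adapter.id] = adapter
        return True

    def set_adapter_mapping(self, mapping: Any) -> None:
        self._last_mapping = mapping

    def remove_adapter(self, adapter_id: int) -> bool:
        return self._registered_adapters.pop(adapter_id, None) is not None

    def remove_all_adapters(self):
        self._registered_adapters.clear()
        self._active_adapters.clear()

    def get_adapter(self, adapter_id: int) -> Optional[Any]:
        return self._registered_adapters.get(adapter_id)

    def list_adapters(self) -> Dict[int, Any]:
        return dict(self._registered_adapters)

    def pin_adapter(self, adapter_id: int) -> bool:
        return adapter_id in self._registered_adapters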
25 changes: 25 additions & 0 deletions vllm/adapter_commons/request.py
@@ -0,0 +1,25 @@
from abc import abstractmethod
from dataclasses import dataclass


@dataclass
class AdapterRequest:
"""
Base class for adapter requests.
"""

@property
@abstractmethod
def adapter_id(self):
...

def __post_init__(self):
if self.adapter_id < 1:
raise ValueError(f"id must be > 0, got {self.adapter_id}")

def __eq__(self, value: object) -> bool:
return isinstance(
value, self.__class__) and self.adapter_id == value.adapter_id

def __hash__(self) -> int:
return hash(self.adapter_id)
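
As a usage sketch (the class name and fields below are hypothetical, not the PromptAdapterRequest defined in this commit), a concrete request only needs to expose adapter_id; equality and hashing then key on that id:

from dataclasses import dataclass

from vllm.adapter_commons.request import AdapterRequest


@dataclass(eq=False)  # keep the base-class __eq__ / __hash__
class ToyAdapterRequest(AdapterRequest):
    toy_name: str
    toy_int_id: int

    @property
    def adapter_id(self) -> int:
        return self.toy_int_id


req_a = ToyAdapterRequest("demo", 1)
req_b = ToyAdapterRequest("another-name", 1)
assert req_a == req_b  # compared by adapter_id only
assert hash(req_a) == hash(req_b)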