From 0558bcc94babb63b8af509709e6dc36ee6da76f3 Mon Sep 17 00:00:00 2001 From: Prashant Gupta Date: Mon, 1 Jul 2024 12:07:41 -0700 Subject: [PATCH] Squash 4645 --- format.sh | 1 + tests/lora/test_long_context.py | 9 +- tests/lora/test_lora_manager.py | 326 ++++++++--------- tests/prompt_adapter/test_bloom.py | 39 ++ .../test_multi_adapter_inference.py | 52 +++ tests/prompt_adapter/test_pa_lora.py | 60 ++++ tests/spec_decode/e2e/conftest.py | 2 + tests/worker/test_model_runner.py | 1 + vllm/adapter_commons/__init__.py | 0 vllm/adapter_commons/layers.py | 14 + vllm/adapter_commons/models.py | 104 ++++++ vllm/adapter_commons/request.py | 25 ++ vllm/adapter_commons/utils.py | 90 +++++ vllm/adapter_commons/worker_manager.py | 36 ++ vllm/config.py | 35 ++ vllm/core/scheduler.py | 12 + vllm/engine/arg_utils.py | 47 ++- vllm/engine/async_llm_engine.py | 38 +- vllm/engine/llm_engine.py | 65 +++- vllm/entrypoints/llm.py | 29 +- vllm/executor/cpu_executor.py | 15 + vllm/executor/executor_base.py | 44 ++- vllm/executor/gpu_executor.py | 21 ++ vllm/executor/ray_xpu_executor.py | 7 +- vllm/executor/xpu_executor.py | 7 +- vllm/lora/layers.py | 12 +- vllm/lora/models.py | 175 +++++---- vllm/lora/request.py | 25 +- vllm/lora/worker_manager.py | 215 ++++------- vllm/prompt_adapter/__init__.py | 0 vllm/prompt_adapter/layers.py | 80 +++++ vllm/prompt_adapter/models.py | 340 ++++++++++++++++++ vllm/prompt_adapter/request.py | 30 ++ vllm/prompt_adapter/worker_manager.py | 173 +++++++++ vllm/sequence.py | 32 ++ vllm/spec_decode/draft_model_runner.py | 13 +- vllm/worker/cpu_model_runner.py | 9 +- vllm/worker/cpu_worker.py | 7 +- vllm/worker/embedding_model_runner.py | 15 +- vllm/worker/model_runner.py | 122 ++++++- vllm/worker/worker.py | 22 +- vllm/worker/xpu_model_runner.py | 6 +- vllm/worker/xpu_worker.py | 7 +- 43 files changed, 1840 insertions(+), 522 deletions(-) create mode 100644 tests/prompt_adapter/test_bloom.py create mode 100644 tests/prompt_adapter/test_multi_adapter_inference.py create mode 100644 tests/prompt_adapter/test_pa_lora.py create mode 100644 vllm/adapter_commons/__init__.py create mode 100644 vllm/adapter_commons/layers.py create mode 100644 vllm/adapter_commons/models.py create mode 100644 vllm/adapter_commons/request.py create mode 100644 vllm/adapter_commons/utils.py create mode 100644 vllm/adapter_commons/worker_manager.py create mode 100644 vllm/prompt_adapter/__init__.py create mode 100644 vllm/prompt_adapter/layers.py create mode 100644 vllm/prompt_adapter/models.py create mode 100644 vllm/prompt_adapter/request.py create mode 100644 vllm/prompt_adapter/worker_manager.py diff --git a/format.sh b/format.sh index 8c54b56302d5..5edc868f9f70 100755 --- a/format.sh +++ b/format.sh @@ -111,6 +111,7 @@ mypy vllm/spec_decode --config-file pyproject.toml mypy vllm/model_executor --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml mypy vllm/logging --config-file pyproject.toml +mypy vllm/prompt_adapter --config-file pyproject.toml mypy tests --config-file pyproject.toml diff --git a/tests/lora/test_long_context.py b/tests/lora/test_long_context.py index b50784a205af..853fd9fb3ce7 100644 --- a/tests/lora/test_long_context.py +++ b/tests/lora/test_long_context.py @@ -92,11 +92,10 @@ def batched_generate( for input in inputs: prompt, sampling_param, lora_req = input # Add requests to the engine and run the engine - llm._validate_and_add_requests( - prompt, - sampling_param, - lora_request=lora_req, - ) + llm._validate_and_add_requests(prompt, + sampling_param, + 
lora_request=lora_req, + prompt_adapter_request=None) outputs = llm._run_engine(use_tqdm=True) return [outputs[i].outputs[0].text.strip() for i in range(len(outputs))] diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 2133bce14957..7bff9e1fbcdc 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -127,37 +127,37 @@ def test_lora_model_manager(dist_init, dummy_model): model, 2, 2, 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)) assert all(x is None for x in manager.lora_index_to_id) - assert manager.add_lora(model_lora1) - assert manager.activate_lora(1) + assert manager.add_adapter(model_lora1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 1 - assert not manager.add_lora(model_lora1) - assert not manager.activate_lora(1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert not manager.add_adapter(model_lora1) + assert not manager.activate_adapter(1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert not manager.add_lora(model_lora2) - assert not manager.activate_lora(2) - assert manager.add_lora(model_lora3) + assert not manager.add_adapter(model_lora2) + assert not manager.activate_adapter(2) + assert manager.add_adapter(model_lora3) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 with pytest.raises(ValueError): - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert manager.remove_lora(model_lora2.id) + assert manager.remove_adapter(model_lora2.id) assert manager.lora_index_to_id[1] is None - assert not manager.remove_lora(model_lora2.id) - assert manager.remove_lora(model_lora1.id) - assert not manager.remove_lora(model_lora1.id) - assert manager.add_lora(model_lora1) + assert not manager.remove_adapter(model_lora2.id) + assert manager.remove_adapter(model_lora1.id) + assert not manager.remove_adapter(model_lora1.id) + assert manager.add_adapter(model_lora1) assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] is None - assert manager.add_lora(model_lora2) - assert manager.activate_lora(3) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] is None - assert manager.activate_lora(2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 2 @@ -173,70 +173,70 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model): model, 2, 2, 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2)) assert all(x is None for x in manager.lora_index_to_id) - assert manager.add_lora(model_lora1) - assert manager.activate_lora(1) + assert manager.add_adapter(model_lora1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 1 - assert not manager.add_lora(model_lora1) - assert not manager.activate_lora(1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert not manager.add_adapter(model_lora1) + assert not manager.activate_adapter(1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert not manager.add_lora(model_lora2) - assert not manager.activate_lora(2) - assert 
manager.add_lora(model_lora3) + assert not manager.add_adapter(model_lora2) + assert not manager.activate_adapter(2) + assert manager.add_adapter(model_lora3) assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 2 - assert manager.remove_lora(model_lora2.id) + assert manager.remove_adapter(model_lora2.id) assert manager.lora_index_to_id[1] is None - assert not manager.remove_lora(model_lora2.id) - assert manager.remove_lora(model_lora1.id) - assert not manager.remove_lora(model_lora1.id) - assert manager.add_lora(model_lora1) - assert manager.activate_lora(1) + assert not manager.remove_adapter(model_lora2.id) + assert manager.remove_adapter(model_lora1.id) + assert not manager.remove_adapter(model_lora1.id) + assert manager.add_adapter(model_lora1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 - assert manager.add_lora(model_lora2) - assert manager.deactivate_lora(3) + assert manager.add_adapter(model_lora2) + assert manager.deactivate_adapter(3) assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 1 - assert manager.activate_lora(2) + assert manager.activate_adapter(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 1 - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 3 - assert manager.pin_lora(2) + assert manager.pin_adapter(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 3 - assert manager.activate_lora(1) + assert manager.activate_adapter(1) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 1 - assert manager.deactivate_lora(2) + assert manager.deactivate_adapter(2) assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 1 - assert manager.activate_lora(3) + assert manager.activate_adapter(3) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 - assert manager.pin_lora(3) - assert manager.pin_lora(1) + assert manager.pin_adapter(3) + assert manager.pin_adapter(1) with pytest.raises(RuntimeError): - assert manager.pin_lora(2) + assert manager.pin_adapter(2) assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 1 with pytest.raises(RuntimeError): - assert manager.activate_lora(2) + assert manager.activate_adapter(2) - assert manager.deactivate_lora(3) - assert manager.pin_lora(2) + assert manager.deactivate_adapter(3) + assert manager.pin_adapter(2) assert manager.lora_index_to_id[0] == 2 assert manager.lora_index_to_id[1] == 1 - assert manager.remove_lora(3) + assert manager.remove_adapter(3) with pytest.raises(ValueError): - assert manager.pin_lora(3) + assert manager.pin_adapter(3) def test_lru_lora_model_manager(dist_init, dummy_model): @@ -256,168 +256,169 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert all(x is None for x in manager.lora_index_to_id) # Add up to capacity - assert manager.add_lora(model_lora1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(1) - assert manager.activate_lora(2) + assert manager.add_adapter(model_lora1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(1) + assert manager.activate_adapter(2) - assert set(manager.list_loras()) == {1, 2} + assert 
set(manager.list_adapters()) == {1, 2} assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 # Add over capacity - assert manager.add_lora(model_lora3) - assert manager.add_lora(model_lora4) - assert manager.activate_lora(3) - assert manager.activate_lora(4) + assert manager.add_adapter(model_lora3) + assert manager.add_adapter(model_lora4) + assert manager.activate_adapter(3) + assert manager.activate_adapter(4) - assert set(manager.list_loras()) == {3, 4} + assert set(manager.list_adapters()) == {3, 4} assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 4 # Add 3 again to move it to the top and then add 2 # should return false since it's in already - assert not manager.add_lora(model_lora3) - assert not manager.activate_lora(3) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert not manager.add_adapter(model_lora3) + assert not manager.activate_adapter(3) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) - assert set(manager.list_loras()) == {3, 2} + assert set(manager.list_adapters()) == {3, 2} assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 2 # Remove manually - assert manager.remove_lora(3) - assert not manager.remove_lora(3) + assert manager.remove_adapter(3) + assert not manager.remove_adapter(3) - assert set(manager.list_loras()) == {2} + assert set(manager.list_adapters()) == {2} assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 2 - assert manager.add_lora(model_lora3) - assert manager.activate_lora(3) - assert manager.add_lora(model_lora4) - assert manager.activate_lora(4) + assert manager.add_adapter(model_lora3) + assert manager.activate_adapter(3) + assert manager.add_adapter(model_lora4) + assert manager.activate_adapter(4) - assert set(manager.list_loras()) == {3, 4} + assert set(manager.list_adapters()) == {3, 4} assert manager.lora_index_to_id[0] == 3 assert manager.lora_index_to_id[1] == 4 - assert manager.remove_oldest_lora() - assert set(manager.list_loras()) == {4} + assert manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == {4} assert manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 4 - assert manager.remove_oldest_lora() - assert set(manager.list_loras()) == set() + assert manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == set() assert all(x is None for x in manager.lora_index_to_id) - assert not manager.remove_oldest_lora() - assert set(manager.list_loras()) == set() + assert not manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == set() assert all(x is None for x in manager.lora_index_to_id) # pinning - assert manager.add_lora(model_lora3) - assert manager.activate_lora(3) - assert manager.add_lora(model_lora4) - assert manager.activate_lora(4) - assert set(manager.list_loras()) == {3, 4} + assert manager.add_adapter(model_lora3) + assert manager.activate_adapter(3) + assert manager.add_adapter(model_lora4) + assert manager.activate_adapter(4) + assert set(manager.list_adapters()) == {3, 4} with pytest.raises(ValueError): - assert manager.pin_lora(1) - assert manager.pin_lora(3) + assert manager.pin_adapter(1) + assert manager.pin_adapter(3) # Remove manually - assert manager.remove_lora(3) - assert not manager.remove_lora(3) + assert manager.remove_adapter(3) + assert not manager.remove_adapter(3) - assert set(manager.list_loras()) == {4} + assert set(manager.list_adapters()) == {4} assert 
manager.lora_index_to_id[0] is None assert manager.lora_index_to_id[1] == 4 - assert manager.add_lora(model_lora1) - assert manager.pin_lora(1) - assert manager.add_lora(model_lora2) - assert manager.activate_lora(2) + assert manager.add_adapter(model_lora1) + assert manager.pin_adapter(1) + assert manager.add_adapter(model_lora2) + assert manager.activate_adapter(2) - assert set(manager.list_loras()) == {1, 2} + assert set(manager.list_adapters()) == {1, 2} assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] == 2 - assert manager.remove_oldest_lora() - assert set(manager.list_loras()) == {1} + assert manager.remove_oldest_adapter() + assert set(manager.list_adapters()) == {1} assert manager.lora_index_to_id[0] == 1 assert manager.lora_index_to_id[1] is None with pytest.raises(RuntimeError): - assert manager.remove_oldest_lora() + assert manager.remove_oldest_adapter() - assert set(manager.list_loras()) == {1} + assert set(manager.list_adapters()) == {1} -def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files): +def test_lru_cache_worker_adapter_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) - worker_lora_manager = LRUCacheWorkerLoRAManager( + worker_adapter_manager = LRUCacheWorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"), EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) + worker_adapter_manager.create_lora_manager( + llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager.list_adapters() == {1, 2} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("3", 3, sql_lora_files), LoRARequest("4", 4, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 3, 4} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + assert worker_adapter_manager.list_adapters() == {1, 2, 3, 4} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 3 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files), LoRARequest("5", 5, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 4, 5} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert 
worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 4, 5} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + assert worker_adapter_manager.list_adapters() == {1, 2, 4, 5} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("6", 6, sql_lora_files), LoRARequest("7", 7, sql_lora_files), LoRARequest("8", 8, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 6, 7, 8} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8 - assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6 + assert worker_adapter_manager.list_adapters() == {1, 6, 7, 8} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 7 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 8 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[3] == 6 # Over capacity with pytest.raises(RuntimeError): - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("10", 10, sql_lora_files), LoRARequest("11", 11, sql_lora_files), LoRARequest("12", 12, sql_lora_files), @@ -426,68 +427,69 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, ], mapping) -def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, - sql_lora_files): +def test_worker_adapter_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): # Should remove every LoRA not specified in the request. 
lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) - worker_lora_manager = WorkerLoRAManager( + worker_adapter_manager = WorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.unpadded_vocab_size - lora_config.lora_extra_vocab_size, lora_config, torch.device("cuda"), EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES) - worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) + worker_adapter_manager.create_lora_manager( + llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager.list_adapters() == {1, 2} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("3", 3, sql_lora_files), LoRARequest("4", 4, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 3, 4} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4 + assert worker_adapter_manager.list_adapters() == {1, 3, 4} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 3 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 4 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files), LoRARequest("5", 5, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1, 2, 5} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + assert worker_adapter_manager.list_adapters() == {1, 2, 5} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 2 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 5 - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files) ], mapping) - assert worker_lora_manager.list_loras() == {1} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None - assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None + assert worker_adapter_manager.list_adapters() == {1} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 1 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] is None + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] is None - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("6", 6, sql_lora_files), LoRARequest("7", 7, sql_lora_files), LoRARequest("8", 8, sql_lora_files) ], 
mapping) - assert worker_lora_manager.list_loras() == {6, 7, 8} - assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8 - assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6 - assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7 + assert worker_adapter_manager.list_adapters() == {6, 7, 8} + assert worker_adapter_manager._adapter_manager.lora_index_to_id[0] == 8 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[1] == 6 + assert worker_adapter_manager._adapter_manager.lora_index_to_id[2] == 7 # Over capacity with pytest.raises(RuntimeError): - worker_lora_manager.set_active_loras([ + worker_adapter_manager.set_active_adapters([ LoRARequest("10", 10, sql_lora_files), LoRARequest("11", 11, sql_lora_files), LoRARequest("12", 12, sql_lora_files), @@ -525,8 +527,8 @@ def test_packed_loras(dist_init, dummy_model_gate_up): assert isinstance(model.get_submodule("gate_up_proj"), MergedColumnParallelLinearWithLoRA) - assert manager.add_lora(model_lora) - assert manager.add_lora(model_lora1) + assert manager.add_adapter(model_lora) + assert manager.add_adapter(model_lora1) packed_lora = model_lora.get_lora("gate_up_proj") assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) diff --git a/tests/prompt_adapter/test_bloom.py b/tests/prompt_adapter/test_bloom.py new file mode 100644 index 000000000000..7c13a81b6f2c --- /dev/null +++ b/tests/prompt_adapter/test_bloom.py @@ -0,0 +1,39 @@ +import vllm +from vllm.prompt_adapter.request import PromptAdapterRequest + +MODEL_PATH = "bigscience/bloomz-560m" +PA_PATH = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM' + + +def do_sample(llm, pa_name: str, pa_id: int): + + prompts = [ + "Tweet text : @nationalgridus I have no water and the bill is \ + current and paid. Can you do something about this? Label : ", + "Tweet text : @nationalgridus Looks good thanks! Label : " + ] + sampling_params = vllm.SamplingParams(temperature=0.0, + max_tokens=3, + stop_token_ids=[3]) + + outputs = llm.generate(prompts, + sampling_params, + prompt_adapter_request=PromptAdapterRequest( + pa_name, pa_id, PA_PATH, 8) if pa_id else None) + + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def test_twitter_prompt_adapter(): + llm = vllm.LLM(MODEL_PATH, enable_prompt_adapter=True) + + expected_output = ['complaint', 'no complaint'] + + assert do_sample(llm, "twitter_pa", pa_id=1) == expected_output diff --git a/tests/prompt_adapter/test_multi_adapter_inference.py b/tests/prompt_adapter/test_multi_adapter_inference.py new file mode 100644 index 000000000000..0cc8c8bc50fd --- /dev/null +++ b/tests/prompt_adapter/test_multi_adapter_inference.py @@ -0,0 +1,52 @@ +from vllm import EngineArgs, LLMEngine, SamplingParams +from vllm.prompt_adapter.request import PromptAdapterRequest + +MODEL_PATH = "bigscience/bloomz-560m" +pa_path = 'stevhliu/bloomz-560m_PROMPT_TUNING_CAUSAL_LM' +pa_path2 = 'swapnilbp/angry_tweet_ptune' + + +def do_sample(engine): + + prompts = [ + ("Tweet text: I have complaints! 
Label: ", + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), + PromptAdapterRequest("hate_speech", 1, pa_path2, 8)), + ("Tweet text: I have no problems Label: ", + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), + PromptAdapterRequest("hate_speech2", 2, pa_path2, 8)), + ("Tweet text: I have complaints! Label: ", + SamplingParams(temperature=0.0, max_tokens=3), None), + ("Tweet text: I have no problems Label: ", + SamplingParams(temperature=0.0, max_tokens=3, stop_token_ids=[3]), + PromptAdapterRequest("complain", 3, pa_path, 8)), + ] + + request_id = 0 + results = set() + while prompts or engine.has_unfinished_requests(): + if prompts: + prompt, sampling_params, pa_request = prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + prompt_adapter_request=pa_request) + request_id += 1 + + request_outputs = engine.step() + + for request_output in request_outputs: + if request_output.finished: + results.add(request_output.outputs[0].text) + return results + + +def test_multi_prompt_adapters(): + engine_args = EngineArgs(model=MODEL_PATH, + max_prompt_adapters=3, + enable_prompt_adapter=True) + engine = LLMEngine.from_engine_args(engine_args) + expected_output = { + ' quot;I', 'hate speech', 'no complaint', 'not hate speech' + } + assert do_sample(engine) == expected_output diff --git a/tests/prompt_adapter/test_pa_lora.py b/tests/prompt_adapter/test_pa_lora.py new file mode 100644 index 000000000000..89f349fec633 --- /dev/null +++ b/tests/prompt_adapter/test_pa_lora.py @@ -0,0 +1,60 @@ +from huggingface_hub import snapshot_download + +from vllm import EngineArgs, LLMEngine, SamplingParams +from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" +pa_path = snapshot_download(repo_id="swapnilbp/llama_tweet_ptune") +lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + + +def do_sample(engine): + + prompt_text = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]" # noqa: E501 + + # first prompt with a prompt adapter and second without adapter + prompts = [ + (prompt_text, + SamplingParams(temperature=0.0, max_tokens=100, + stop=["[/assistant]"]), + PromptAdapterRequest("hate_speech", 1, pa_path, + 8), LoRARequest("sql_test", 1, lora_path)), + (prompt_text, + SamplingParams(temperature=0.0, max_tokens=100, + stop=["[/assistant]"]), None, + LoRARequest("sql_test", 1, lora_path)), + ] + + request_id = 0 + results = set() + while prompts or engine.has_unfinished_requests(): + if prompts: + prompt, sampling_params, pa_request, lora_request = prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + prompt_adapter_request=pa_request, + lora_request=lora_request) + request_id += 1 + + request_outputs = engine.step() + + for request_output in request_outputs: + if request_output.finished: + results.add(request_output.outputs[0].text) + return results + + +def test_lora_prompt_adapter(): + engine_args = EngineArgs(model=MODEL_PATH, + enable_prompt_adapter=True, + enable_lora=True, + max_num_seqs=60) + engine = LLMEngine.from_engine_args(engine_args) + result = do_sample(engine) + + expected_output = { + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' " # noqa: E501 + } + assert result == 
expected_output diff --git a/tests/spec_decode/e2e/conftest.py b/tests/spec_decode/e2e/conftest.py index 60dfe33f2918..a77c8246f830 100644 --- a/tests/spec_decode/e2e/conftest.py +++ b/tests/spec_decode/e2e/conftest.py @@ -13,6 +13,7 @@ from vllm.model_executor.utils import set_random_seed from vllm.multimodal import MultiModalData from vllm.outputs import RequestOutput +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import Logprob from vllm.usage.usage_lib import UsageContext @@ -92,6 +93,7 @@ def generate( use_tqdm: bool = True, lora_request: Optional[LoRARequest] = None, multi_modal_data: Optional[MultiModalData] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> List[RequestOutput]: if prompts is None: diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index e1775790c0a0..b5742c433861 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -23,6 +23,7 @@ def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: cache_config=engine_config.cache_config, load_config=engine_config.load_config, lora_config=engine_config.lora_config, + prompt_adapter_config=engine_config.prompt_adapter_config, is_driver_worker=True, ) return model_runner diff --git a/vllm/adapter_commons/__init__.py b/vllm/adapter_commons/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/adapter_commons/layers.py b/vllm/adapter_commons/layers.py new file mode 100644 index 000000000000..3ed60678b52f --- /dev/null +++ b/vllm/adapter_commons/layers.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass +from typing import Tuple + + +@dataclass +class AdapterMapping: + # Per every token in input_ids: + index_mapping: Tuple[int, ...] + # Per sampled token: + prompt_mapping: Tuple[int, ...] + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) \ No newline at end of file diff --git a/vllm/adapter_commons/models.py b/vllm/adapter_commons/models.py new file mode 100644 index 000000000000..6939b1405f3e --- /dev/null +++ b/vllm/adapter_commons/models.py @@ -0,0 +1,104 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, Hashable, Optional, TypeVar + +from torch import nn + +from vllm.logger import init_logger +from vllm.utils import LRUCache + +logger = init_logger(__name__) + + +class AdapterModel(ABC): + + def __init__(self, model_id=None): + self.id = model_id + + @abstractmethod + def from_local_checkpoint(cls, model_dir, model_id=None, **kwargs): + # Common initialization code + # Load weights or embeddings from local checkpoint + raise NotImplementedError("Subclasses must implement this method.") + + +T = TypeVar('T') + + +class AdapterLRUCache(LRUCache[T]): + + def __init__(self, capacity: int, deactivate_fn: Callable[[Hashable], + None]): + super().__init__(capacity) + self.deactivate_fn = deactivate_fn + + def _on_remove(self, key: Hashable, value: T): + logger.debug("Removing adapter int id: %d", key) + self.deactivate_fn(key) + return super()._on_remove(key, value) + + +class AdapterModelManager(ABC): + + def __init__( + self, + model: nn.Module, + ): + """Create a AdapterModelManager and adapter for a given model. + Args: + model: the model to be adapted. + """ + self.model: nn.Module = model + self._registered_adapters: Dict[int, Any] = {} + # Dict instead of a Set for compatibility with LRUCache. 
+ self._active_adapters: Dict[int, None] = {} + self.adapter_type = 'Adapter' + self._last_mapping = None + + def __len__(self) -> int: + return len(self._registered_adapters) + + @property + @abstractmethod + def adapter_slots(self): + ... + + @property + @abstractmethod + def capacity(self): + ... + + @abstractmethod + def activate_adapter(self, adapter_id: int) -> bool: + ... + + @abstractmethod + def deactivate_adapter(self, adapter_id: int) -> bool: + ... + + @abstractmethod + def add_adapter(self, adapter: Any) -> bool: + ... + + @abstractmethod + def set_adapter_mapping(self, mapping: Any) -> None: + ... + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + ... + + @abstractmethod + def remove_all_adapters(self): + ... + + @abstractmethod + def get_adapter(self, adapter_id: int) -> Optional[Any]: + ... + + @abstractmethod + def list_adapters(self) -> Dict[int, Any]: + ... + + @abstractmethod + def pin_adapter(self, adapter_id: int) -> bool: + ... diff --git a/vllm/adapter_commons/request.py b/vllm/adapter_commons/request.py new file mode 100644 index 000000000000..69775ab7d454 --- /dev/null +++ b/vllm/adapter_commons/request.py @@ -0,0 +1,25 @@ +from abc import abstractmethod +from dataclasses import dataclass + + +@dataclass +class AdapterRequest: + """ + Base class for adapter requests. + """ + + @property + @abstractmethod + def adapter_id(self): + ... + + def __post_init__(self): + if self.adapter_id < 1: + raise ValueError(f"id must be > 0, got {self.adapter_id}") + + def __eq__(self, value: object) -> bool: + return isinstance( + value, self.__class__) and self.adapter_id == value.adapter_id + + def __hash__(self) -> int: + return hash(self.adapter_id) diff --git a/vllm/adapter_commons/utils.py b/vllm/adapter_commons/utils.py new file mode 100644 index 000000000000..6c5411f7d3d5 --- /dev/null +++ b/vllm/adapter_commons/utils.py @@ -0,0 +1,90 @@ +from typing import Any, Callable, Dict, Optional, Set + + +## model functions +def deactivate_adapter(adapter_id: int, active_adapters: Dict[int, None], + deactivate_func: Callable) -> bool: + if adapter_id in active_adapters: + deactivate_func(adapter_id) + active_adapters.pop(adapter_id) + return True + return False + + +def add_adapter(adapter: Any, registered_adapters: Dict[int, Any], + capacity: int, add_func: Callable) -> bool: + if adapter.id not in registered_adapters: + if len(registered_adapters) >= capacity: + raise RuntimeError('No free adapter slots.') + add_func(adapter) + registered_adapters[adapter.id] = adapter + return True + return False + + +def set_adapter_mapping(mapping: Any, last_mapping: Any, + set_mapping_func: Callable) -> Any: + if last_mapping != mapping: + set_mapping_func(mapping) + return mapping + return last_mapping + + +def remove_adapter(adapter_id: int, registered_adapters: Dict[int, Any], + deactivate_func: Callable) -> bool: + deactivate_func(adapter_id) + return bool(registered_adapters.pop(adapter_id, None)) + + +def list_adapters(registered_adapters: Dict[int, Any]) -> Dict[int, Any]: + return dict(registered_adapters) + + +def get_adapter(adapter_id: int, + registered_adapters: Dict[int, Any]) -> Optional[Any]: + return registered_adapters.get(adapter_id, None) + + +## worker functions +def set_active_adapters_worker(requests: Set[Any], mapping: Optional[Any], + apply_adapters_func, + set_adapter_mapping_func) -> None: + apply_adapters_func(requests) + set_adapter_mapping_func(mapping) + + +def add_adapter_worker(adapter_request: Any, list_adapters_func, + 
load_adapter_func, add_adapter_func, + activate_adapter_func) -> bool: + if adapter_request.adapter_id in list_adapters_func(): + return False + loaded_adapter = load_adapter_func(adapter_request) + loaded = add_adapter_func(loaded_adapter) + activate_adapter_func(loaded_adapter.id) + return loaded + + +def apply_adapters_worker(adapter_requests: Set[Any], list_adapters_func, + adapter_slots: int, remove_adapter_func, + add_adapter_func) -> None: + models_that_exist = list_adapters_func() + models_map = { + adapter_request.adapter_id: adapter_request + for adapter_request in adapter_requests if adapter_request + } + if len(models_map) > adapter_slots: + raise RuntimeError( + f"Number of requested models ({len(models_map)}) is greater " + f"than the number of GPU model slots " + f"({adapter_slots}).") + new_models = set(models_map) + models_to_add = new_models - models_that_exist + models_to_remove = models_that_exist - new_models + for adapter_id in models_to_remove: + remove_adapter_func(adapter_id) + for adapter_id in models_to_add: + add_adapter_func(models_map[adapter_id]) + + +def list_adapters_worker(adapter_manager_list_adapters_func) -> Set[int]: + return set(adapter_manager_list_adapters_func()) diff --git a/vllm/adapter_commons/worker_manager.py b/vllm/adapter_commons/worker_manager.py new file mode 100644 index 000000000000..acf18993af6d --- /dev/null +++ b/vllm/adapter_commons/worker_manager.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional, Set + +import torch + + +class AbstractWorkerManager(ABC): + + def __init__(self, device: torch.device): + self.device = device + + @property + @abstractmethod + def is_enabled(self) -> bool: + ... + + @abstractmethod + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + ... + + @abstractmethod + def add_adapter(self, adapter_request: Any) -> bool: + ... + + @abstractmethod + def remove_adapter(self, adapter_id: int) -> bool: + ... + + @abstractmethod + def remove_all_adapters(self): + ... + + @abstractmethod + def list_adapters(self) -> Set[int]: + ... diff --git a/vllm/config.py b/vllm/config.py index 9854f175065a..b0cb99f40858 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1250,6 +1250,37 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): raise ValueError("LoRA is not supported with chunked prefill yet.") +@dataclass +class PromptAdapterConfig: + max_prompt_adapters: int + max_prompt_adapter_token: int = 10 + max_cpu_prompt_adapters: Optional[int] = None + prompt_adapter_dtype: Optional[torch.dtype] = None + + def __post_init__(self): + library_name = 'peft' + try: + __import__(library_name) + except ImportError as e: + raise ImportError( + f"'{library_name}' is not installed for prompt adapter support." + f"Please install it using 'pip install {library_name}'." 
+ ) from e + + if self.max_prompt_adapters < 1: + raise ValueError(f"max_prompt_adapters " + f"({self.max_prompt_adapters}) must be >= 1.") + if self.max_cpu_prompt_adapters is None: + self.max_cpu_prompt_adapters = self.max_prompt_adapters + + def verify_with_model_config(self, model_config: ModelConfig): + if self.prompt_adapter_dtype in (None, "auto"): + self.prompt_adapter_dtype = model_config.dtype + elif isinstance(self.prompt_adapter_dtype, str): + self.prompt_adapter_dtype = getattr(torch, + self.prompt_adapter_dtype) + + @dataclass class VisionLanguageConfig: """Configs the input data format and how models should run for @@ -1548,6 +1579,7 @@ class EngineConfig: speculative_config: Optional[SpeculativeConfig] decoding_config: Optional[DecodingConfig] observability_config: Optional[ObservabilityConfig] + prompt_adapter_config: Optional[PromptAdapterConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. @@ -1559,6 +1591,9 @@ def __post_init__(self): self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) def to_dict(self): """Return the configs as a dictionary, for use in **kwargs. diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 48c34625c08a..26fe602d8441 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -11,6 +11,7 @@ from vllm.core.policy import Policy, PolicyFactory from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) @@ -139,6 +140,8 @@ def __post_init__(self): if self.num_loras > 0: self._sort_by_lora_ids() + self.num_prompt_adapters: int = len(self.prompt_adapter_requests) + def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in @@ -157,6 +160,14 @@ def lora_requests(self) -> Set[LoRARequest]: if g.seq_group.lora_request is not None } + @property + def prompt_adapter_requests(self) -> Set[PromptAdapterRequest]: + return { + g.seq_group.prompt_adapter_request + for g in self.scheduled_seq_groups + if g.seq_group.prompt_adapter_request is not None + } + @dataclass class SchedulerRunningOutputs: @@ -1006,6 +1017,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: # `multi_modal_data` will be None. 
multi_modal_data=seq_group.multi_modal_data if scheduler_outputs.num_prefill_groups > 0 else None, + prompt_adapter_request=seq_group.prompt_adapter_request, ) seq_group_metadata_list.append(seq_group_metadata) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d4044adfce61..a2791869b848 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -7,7 +7,8 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, EngineConfig, LoadConfig, LoRAConfig, ModelConfig, - ObservabilityConfig, ParallelConfig, SchedulerConfig, + ObservabilityConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, SpeculativeConfig, TokenizerPoolConfig, VisionLanguageConfig) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -67,6 +68,9 @@ class EngineArgs: enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 + enable_prompt_adapter: bool = False + max_prompt_adapters: int = 1 + max_prompt_adapter_token: int = 10 fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 long_lora_scaling_factors: Optional[Tuple[float]] = None @@ -506,6 +510,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'Enabling this will use the fully sharded layers. ' 'At high sequence length, max rank or ' 'tensor parallel size, this is likely faster.')) + parser.add_argument('--enable-prompt-adapter', + action='store_true', + help='If True, enable handling of PromptAdapters.') + parser.add_argument('--max-prompt-adapters', + type=int, + default=EngineArgs.max_prompt_adapters, + help='Max number of PromptAdapters in a batch.') + parser.add_argument('--max-prompt-adapter-token', + type=int, + default=EngineArgs.max_prompt_adapter_token, + help='Max number of PromptAdapters tokens') parser.add_argument("--device", type=str, default=EngineArgs.device, @@ -815,6 +830,11 @@ def create_engine_config(self, ) -> EngineConfig: model_loader_extra_config=self.model_loader_extra_config, ) + prompt_adapter_config = PromptAdapterConfig( + max_prompt_adapters=self.max_prompt_adapters, + max_prompt_adapter_token=self.max_prompt_adapter_token) \ + if self.enable_prompt_adapter else None + decoding_config = DecodingConfig( guided_decoding_backend=self.guided_decoding_backend) @@ -828,19 +848,18 @@ def create_engine_config(self, ) -> EngineConfig: "Chunked prefill is not supported with sliding window. 
" "Set --disable-sliding-window to disable sliding window.") - return EngineConfig( - model_config=model_config, - cache_config=cache_config, - parallel_config=parallel_config, - scheduler_config=scheduler_config, - device_config=device_config, - lora_config=lora_config, - vision_language_config=vision_language_config, - speculative_config=speculative_config, - load_config=load_config, - decoding_config=decoding_config, - observability_config=observability_config, - ) + return EngineConfig(model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + lora_config=lora_config, + vision_language_config=vision_language_config, + speculative_config=speculative_config, + load_config=load_config, + decoding_config=decoding_config, + observability_config=observability_config, + prompt_adapter_config=prompt_adapter_config) @dataclass diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7db3bb28c6ee..2645c6009c18 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -18,6 +18,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.usage.usage_lib import UsageContext @@ -263,6 +264,7 @@ async def process_model_inputs_async( request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: if isinstance(inputs, str): inputs = {"prompt": inputs} @@ -278,6 +280,12 @@ async def process_model_inputs_async( else: prompt_token_ids = inputs["prompt_token_ids"] + if prompt_adapter_request: + prompt_token_ids = [ + 0 + ] * prompt_adapter_request.prompt_adapter_num_virtual_tokens + \ + prompt_token_ids + llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), multi_modal_data=inputs.get("multi_modal_data")) @@ -285,13 +293,14 @@ async def process_model_inputs_async( return self.input_processor(llm_inputs) async def add_request_async( - self, - request_id: str, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - arrival_time: Optional[float] = None, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Dict[str, str]] = None, + self, + request_id: str, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " @@ -300,7 +309,10 @@ async def add_request_async( arrival_time = time.time() processed_inputs = await self.process_model_inputs_async( - request_id=request_id, inputs=inputs, lora_request=lora_request) + request_id=request_id, + inputs=inputs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) self._add_processed_request( request_id=request_id, @@ -308,6 +320,7 @@ async def add_request_async( params=params, arrival_time=arrival_time, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, ) @@ 
-573,6 +586,7 @@ async def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncStream: if self.log_requests: if isinstance(inputs, str): @@ -615,7 +629,7 @@ async def add_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, - ) + prompt_adapter_request=prompt_adapter_request) return stream @@ -626,6 +640,7 @@ async def generate( request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -641,6 +656,8 @@ async def generate( request_id: The unique id of the request. lora_request: LoRA request to use for generation, if any. trace_headers: OpenTelemetry trace headers. + prompt_adapter_request: Prompt Adapter request to use + for generation, if any. Yields: The output `RequestOutput` objects from the LLMEngine @@ -695,6 +712,7 @@ async def generate( sampling_params, lora_request=lora_request, trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, ): yield LLMEngine.validate_output(output, RequestOutput) @@ -783,6 +801,7 @@ async def _process_request( *, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: """Common logic to process requests with SamplingParams or PoolingParams.""" @@ -795,6 +814,7 @@ async def _process_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request, ) try: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f7e38c0e6b94..6ce3b14bacd6 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -8,8 +8,8 @@ from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, - ParallelConfig, SchedulerConfig, SpeculativeConfig, - VisionLanguageConfig) + ParallelConfig, PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, VisionLanguageConfig) from vllm.core.scheduler import (ScheduledSequenceGroup, Scheduler, SchedulerOutputs) from vllm.engine.arg_utils import EngineArgs @@ -27,6 +27,7 @@ from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, PoolerOutput, SamplerOutput, Sequence, @@ -161,6 +162,7 @@ def __init__( speculative_config: Optional[SpeculativeConfig], decoding_config: Optional[DecodingConfig], observability_config: Optional[ObservabilityConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], executor_class: Type[ExecutorBase], log_stats: bool, usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, @@ -217,6 +219,7 @@ def __init__( self.speculative_config = speculative_config self.load_config = load_config self.decoding_config = decoding_config or DecodingConfig() + self.prompt_adapter_config = prompt_adapter_config self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats @@ -245,6 +248,7 @@ 
def __init__( vision_language_config=vision_language_config, speculative_config=speculative_config, load_config=load_config, + prompt_adapter_config=prompt_adapter_config, ) if not self.model_config.embedding_mode: @@ -277,6 +281,8 @@ def __init__( # Feature flags "enable_lora": bool(lora_config), + "enable_prompt_adapter": + bool(prompt_adapter_config), "enable_prefix_caching": cache_config.enable_prefix_caching, "enforce_eager": @@ -367,7 +373,6 @@ def from_engine_args( engine_config = engine_args.create_engine_config() distributed_executor_backend = ( engine_config.parallel_config.distributed_executor_backend) - # Initialize the cluster and specify the executor class. if engine_config.device_config.device_type == "neuron": from vllm.executor.neuron_executor import NeuronExecutor @@ -400,7 +405,6 @@ def from_engine_args( else: from vllm.executor.gpu_executor import GPUExecutor executor_class = GPUExecutor - # Create the LLM engine. engine = cls( **engine_config.to_dict(), @@ -461,6 +465,9 @@ def _verify_args(self) -> None: self.lora_config.verify_with_model_config(self.model_config) self.lora_config.verify_with_scheduler_config( self.scheduler_config) + if self.prompt_adapter_config: + self.prompt_adapter_config.verify_with_model_config( + self.model_config) def _get_eos_token_id( self, lora_request: Optional[LoRARequest]) -> Optional[int]: @@ -478,6 +485,7 @@ def _add_processed_request( params: Union[SamplingParams, PoolingParams], arrival_time: float, lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], trace_headers: Optional[Dict[str, str]] = None, ) -> None: # Create the sequences. @@ -486,7 +494,7 @@ def _add_processed_request( eos_token_id = self._get_eos_token_id(lora_request) seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, - lora_request) + lora_request, prompt_adapter_request) # Create a SequenceGroup based on SamplingParams or PoolingParams if isinstance(params, SamplingParams): @@ -497,7 +505,7 @@ def _add_processed_request( arrival_time=arrival_time, lora_request=lora_request, trace_headers=trace_headers, - ) + prompt_adapter_request=prompt_adapter_request) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( request_id, @@ -505,7 +513,7 @@ def _add_processed_request( params, arrival_time=arrival_time, lora_request=lora_request, - ) + prompt_adapter_request=prompt_adapter_request) else: raise ValueError( "Either SamplingParams or PoolingParams must be provided.") @@ -518,6 +526,7 @@ def process_model_inputs( request_id: str, inputs: PromptInputs, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> LLMInputs: if isinstance(inputs, str): inputs = {"prompt": inputs} @@ -532,12 +541,18 @@ def process_model_inputs( else: prompt_token_ids = inputs["prompt_token_ids"] + if prompt_adapter_request: + prompt_token_ids = \ + [0] * prompt_adapter_request.prompt_adapter_num_virtual_tokens\ + + prompt_token_ids + llm_inputs = LLMInputs(prompt_token_ids=prompt_token_ids, prompt=inputs.get("prompt"), multi_modal_data=inputs.get("multi_modal_data")) return self.input_processor(llm_inputs) + def add_request( self, request_id: str, @@ -546,6 +561,7 @@ def add_request( arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: """Add a request to the engine's request pool. 
@@ -595,9 +611,11 @@ def add_request( if arrival_time is None: arrival_time = time.time() - processed_inputs = self.process_model_inputs(request_id=request_id, - inputs=inputs, - lora_request=lora_request) + processed_inputs = self.process_model_inputs( + request_id=request_id, + inputs=inputs, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) self._add_processed_request( request_id=request_id, @@ -605,6 +623,7 @@ def add_request( params=params, arrival_time=arrival_time, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, trace_headers=trace_headers, ) @@ -616,6 +635,7 @@ def _create_sequence_group_with_sampling( arrival_time: float, lora_request: Optional[LoRARequest], trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -641,7 +661,7 @@ def _create_sequence_group_with_sampling( sampling_params=sampling_params, lora_request=lora_request, trace_headers=trace_headers, - ) + prompt_adapter_request=prompt_adapter_request) return seq_group @@ -652,16 +672,19 @@ def _create_sequence_group_with_pooling( pooling_params: PoolingParams, arrival_time: float, lora_request: Optional[LoRARequest], + prompt_adapter_request: Optional[PromptAdapterRequest], ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler pooling_params = pooling_params.clone() # Create the sequence group. - seq_group = SequenceGroup(request_id=request_id, - seqs=[seq], - arrival_time=arrival_time, - lora_request=lora_request, - pooling_params=pooling_params) + seq_group = SequenceGroup( + request_id=request_id, + seqs=[seq], + arrival_time=arrival_time, + lora_request=lora_request, + pooling_params=pooling_params, + prompt_adapter_request=prompt_adapter_request) return seq_group def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: @@ -1042,6 +1065,16 @@ def list_loras(self) -> Set[int]: def pin_lora(self, lora_id: int) -> bool: return self.model_executor.pin_lora(lora_id) + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.model_executor.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.model_executor.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> List[int]: + return self.model_executor.list_prompt_adapters() + def check_health(self) -> None: if self.tokenizer: self.tokenizer.check_health() diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 9e923493160e..8e684a214ab0 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -13,6 +13,7 @@ from vllm.lora.request import LoRARequest from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.transformers_utils.tokenizer import get_cached_tokenizer from vllm.usage.usage_lib import UsageContext @@ -250,6 +251,7 @@ def generate( prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[RequestOutput]: """Generates the completions 
for the input prompts. @@ -299,7 +301,7 @@ def generate( inputs=inputs, params=sampling_params, lora_request=lora_request, - ) + prompt_adapter_request=prompt_adapter_request) outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, RequestOutput) @@ -392,6 +394,7 @@ def encode( prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None, use_tqdm: bool = True, lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> List[EmbeddingRequestOutput]: """Generates the completions for the input prompts. @@ -440,6 +443,7 @@ def encode( inputs=inputs, params=pooling_params, lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, ) outputs = self._run_engine(use_tqdm=use_tqdm) @@ -499,6 +503,7 @@ def _validate_and_add_requests( params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams, Sequence[PoolingParams]], lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]], + prompt_adapter_request: Optional[PromptAdapterRequest], ) -> None: if isinstance(inputs, (str, dict)): # Convert a single prompt to a list. @@ -521,19 +526,23 @@ def _validate_and_add_requests( params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, - ) + prompt_adapter_request=prompt_adapter_request) def _add_request( - self, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + self, + inputs: PromptInputs, + params: Union[SamplingParams, PoolingParams], + lora_request: Optional[Union[List[LoRARequest], + LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, - inputs, - params, - lora_request=lora_request) + self.llm_engine.add_request( + request_id, + inputs, + params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) def _run_engine( self, *, use_tqdm: bool diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 6137cecd881d..0b2507480be9 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -7,6 +7,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) @@ -48,6 +49,7 @@ def _init_worker(self): lora_config=self.lora_config, vision_language_config=self.vision_language_config, kv_cache_dtype=self.cache_config.cache_dtype, + prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=True, ) self.driver_worker.init_device() @@ -90,6 +92,19 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return self.driver_worker.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.driver_worker.list_prompt_adapters() + + def pin_prompt_adapter(self, 
prompt_adapter_id: int) -> bool: + return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) + def check_health(self) -> None: # CPUExecutor will always be healthy as long as # it's running. diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index d7c19622e270..17f6beb146a9 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -2,9 +2,11 @@ from typing import List, Optional, Set, Tuple from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, SamplerOutput @@ -16,18 +18,14 @@ class ExecutorBase(ABC): that can execute the model on multiple devices. """ - def __init__( - self, - model_config: ModelConfig, - cache_config: CacheConfig, - parallel_config: ParallelConfig, - scheduler_config: SchedulerConfig, - device_config: DeviceConfig, - load_config: LoadConfig, - lora_config: Optional[LoRAConfig], - vision_language_config: Optional[VisionLanguageConfig], - speculative_config: Optional[SpeculativeConfig], - ) -> None: + def __init__(self, model_config: ModelConfig, cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, load_config: LoadConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + speculative_config: Optional[SpeculativeConfig], + prompt_adapter_config: Optional[PromptAdapterConfig]) -> None: self.model_config = model_config self.cache_config = cache_config self.lora_config = lora_config @@ -37,6 +35,7 @@ def __init__( self.device_config = device_config self.vision_language_config = vision_language_config self.speculative_config = speculative_config + self.prompt_adapter_config = prompt_adapter_config self._init_executor() @@ -94,6 +93,23 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: raise NotImplementedError + @abstractmethod + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + raise NotImplementedError + + @abstractmethod + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError + + @abstractmethod + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + raise NotImplementedError # type: ignore + + @abstractmethod + def list_prompt_adapters(self) -> Set[int]: + raise NotImplementedError + @abstractmethod def check_health(self) -> None: """Checks if the executor is healthy. 
If not, it should raise an diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 5522b5322e66..bb73a2f92d64 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -3,6 +3,7 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) @@ -45,6 +46,7 @@ def _get_worker_kwargs( lora_config=self.lora_config, vision_language_config=self.vision_language_config, speculative_config=self.speculative_config, + prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=rank == 0, ) @@ -106,6 +108,25 @@ def pin_lora(self, lora_id: int) -> bool: def list_loras(self) -> Set[int]: return self.driver_worker.list_loras() + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + assert prompt_adapter_request.prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.add_prompt_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.remove_prompt_adapter(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + assert prompt_adapter_id > 0, \ + "prompt_adapter_id must be greater than 0." + return self.driver_worker.pin_prompt_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + return self.driver_worker.list_prompt_adapters() + def check_health(self) -> None: # GPUExecutor will always be healthy as long as # it's running. 
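
With the entrypoint and executor plumbing above in place, a prompt adapter can be attached to an individual request the same way a LoRARequest is. The snippet below is a minimal sketch of that flow, assuming an engine built with prompt-adapter support enabled; the engine flags, model name, adapter id and checkpoint path are illustrative placeholders, not values taken from this patch.

    from vllm import LLM, SamplingParams
    from vllm.prompt_adapter.request import PromptAdapterRequest

    # Placeholder id and checkpoint path; any PEFT prompt-tuning checkpoint
    # whose virtual-token count matches would work here.
    pa_request = PromptAdapterRequest(
        prompt_adapter_name="demo_adapter",
        prompt_adapter_id=1,
        prompt_adapter_local_path="/path/to/prompt_tuning_ckpt",
        prompt_adapter_num_virtual_tokens=8)

    # Engine-construction flags are assumed here; this section only shows the
    # per-request plumbing, not the engine arguments.
    llm = LLM(model="bigscience/bloomz-560m",
              enable_prompt_adapter=True,
              max_prompt_adapter_token=8)

    outputs = llm.generate(
        "The capital of France is",
        SamplingParams(temperature=0.0, max_tokens=8),
        prompt_adapter_request=pa_request)
    print(outputs[0].outputs[0].text)

The request object only carries metadata (name, id, local path, virtual-token count); actual loading happens behind the executor methods added above, which delegate to the driver worker.
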
diff --git a/vllm/executor/ray_xpu_executor.py b/vllm/executor/ray_xpu_executor.py index dd7c82289341..ebd38300cd1e 100644 --- a/vllm/executor/ray_xpu_executor.py +++ b/vllm/executor/ray_xpu_executor.py @@ -7,8 +7,9 @@ Tuple, Union) from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.executor.distributed_gpu_executor import ( # yapf: disable DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray @@ -44,6 +45,7 @@ def __init__( load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], speculative_config: Optional[SpeculativeConfig], ) -> None: assert device_config.device_type == "xpu" @@ -58,6 +60,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.prompt_adapter_config = prompt_adapter_config placement_group = self.parallel_config.placement_group diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index d37200bd02de..9fc6033510bc 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -3,8 +3,9 @@ import torch from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger @@ -27,6 +28,7 @@ def __init__( load_config: LoadConfig, lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], + prompt_adapter_config: Optional[PromptAdapterConfig], speculative_config: Optional[SpeculativeConfig], ) -> None: assert device_config.device_type == "xpu" @@ -43,6 +45,7 @@ def __init__( self.scheduler_config = scheduler_config self.device_config = device_config self.vision_language_config = vision_language_config + self.prompt_adapter_config = prompt_adapter_config, self.speculative_config = None # Instantiate the worker and load the model to GPU. diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 2fddfccaf1e4..44b5e6fd5576 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -8,6 +8,7 @@ import torch.nn.functional as F from transformers import PretrainedConfig +from vllm.adapter_commons.layers import AdapterMapping from vllm.config import LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -134,15 +135,8 @@ def _apply_lora_packed_nslice( @dataclass -class LoRAMapping: - # Per every token in input_ids: - index_mapping: Tuple[int, ...] - # Per sampled token: - prompt_mapping: Tuple[int, ...] 
- - def __post_init__(self): - self.index_mapping = tuple(self.index_mapping) - self.prompt_mapping = tuple(self.prompt_mapping) +class LoRAMapping(AdapterMapping): + pass class BaseLayerWithLoRA(nn.Module): diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 689835def83d..e1ede7d4d710 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -4,12 +4,17 @@ import os import re from dataclasses import dataclass, field -from typing import Callable, Dict, List, Optional, Tuple, Type, Union +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union import safetensors.torch import torch from torch import nn +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + AdapterModelManager) +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) from vllm.config import LoRAConfig from vllm.logger import init_logger from vllm.lora.layers import (BaseLayerWithLoRA, @@ -19,7 +24,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) from vllm.model_executor.models.interfaces import SupportsLoRA -from vllm.utils import LRUCache, is_pin_memory_available +from vllm.utils import is_pin_memory_available logger = init_logger(__name__) @@ -153,7 +158,7 @@ def get_lora_id(): return _GLOBAL_LORA_ID -class LoRAModel: +class LoRAModel(AdapterModel): """A LoRA fine-tuned model.""" def __init__( @@ -388,7 +393,7 @@ def from_local_checkpoint( ) -class LoRAModelManager: +class LoRAModelManager(AdapterModelManager): """A manager that manages multiple LoRA-fine-tuned models.""" def __init__( @@ -440,8 +445,7 @@ def __init__( # base_indices, sampler_indices, sampler_indices_padded, # embeddings_indices self.indices_len: List[Optional[int]] = [None] * 4 - - self.model = model + super().__init__(model) if hasattr(self.model, "supported_lora_modules"): self.supported_lora_modules = copy.deepcopy( self.model.supported_lora_modules) @@ -453,11 +457,11 @@ def __init__( self.model.packed_modules_mapping) self.packed_modules: Dict[str, List[str]] = {} self.modules: Dict[str, "BaseLayerWithLoRA"] = {} - self._registered_loras: Dict[int, LoRAModel] = {} # Dict instead of a Set for compatibility with LRUCache. - self._active_loras: Dict[int, None] = {} self._last_mapping: Optional[LoRAMapping] = None self._create_lora_modules() + self.model.lora_manager = self + self.adapter_type = 'LoRa' @property def capacity(self) -> int: @@ -467,15 +471,16 @@ def capacity(self) -> int: def lora_slots(self) -> int: return self.lora_config.max_loras - def __len__(self) -> int: - return len(self._registered_loras) + @property + def adapter_slots(self) -> int: + return self.lora_slots - def activate_lora( + def activate_adapter( self, lora_id: int, ) -> bool: """Move LoRA into a GPU buffer to be used in the forward pass.""" - if lora_id in self._active_loras: + if lora_id in self._active_adapters: return False first_free_slot = next( ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) @@ -483,8 +488,8 @@ def activate_lora( if first_free_slot is None: raise ValueError("No free lora slots") index, _ = first_free_slot - self._active_loras[lora_id] = None - lora_model = self._registered_loras[lora_id] + self._active_adapters[lora_id] = None + lora_model = self._registered_adapters[lora_id] logger.debug("Activating LoRA. 
int id: %d, slot index: %d", lora_model.id, index) self.lora_index_to_id[index] = lora_model.id @@ -498,21 +503,13 @@ def activate_lora( module.reset_lora(index) return True - def _deactivate_lora(self, lora_id: int): + def _deactivate_adapter(self, lora_id: int): try: index = self.lora_index_to_id.index(lora_id) self.lora_index_to_id[index] = None except ValueError: pass - def deactivate_lora(self, lora_id: int) -> bool: - """Remove a LoRA from a GPU buffer.""" - if lora_id in self._active_loras: - self._deactivate_lora(lora_id) - self._active_loras.pop(lora_id) - return True - return False - def _set_long_lora_context(self, lora: LoRAModel): if self.long_lora_context is None: return @@ -528,40 +525,19 @@ def _set_long_lora_context(self, lora: LoRAModel): if offsets: self.long_lora_context.offsets_by_lora_id[lora.id] = offsets - def _add_lora(self, lora: LoRAModel): + def _add_adapter(self, lora: LoRAModel): self._create_merged_loras_inplace(lora) - self._registered_loras[lora.id] = lora + self._registered_adapters[lora.id] = lora self._set_long_lora_context(lora) - def add_lora(self, lora: LoRAModel) -> bool: - """Add a LoRAModel to the manager CPU cache.""" - logger.debug( - "Adding lora. Model id: %d, " - "int id: %d, " - "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) - if lora.id not in self._registered_loras: - if len(self._registered_loras) >= self.capacity: - raise RuntimeError("No free LoRA slots.") - self._add_lora(lora) - return True - return False - - def remove_lora(self, lora_id: int) -> bool: - """Remove a LoRAModel from the manager CPU cache.""" - # TODO: should we check active lora? - self.deactivate_lora(lora_id) - if self.long_lora_context: - self.long_lora_context.offsets_by_lora_id.pop(lora_id, None) - return bool(self._registered_loras.pop(lora_id, None)) - - def pin_lora(self, lora_id: int) -> bool: + def pin_adapter(self, lora_id: int) -> bool: """Pin a LoRAModel in the manager cache.""" raise NotImplementedError( "Pinning is not supported in LoRAModelManager." 
"Use LRUCacheLoRAModelManager for pinning") # type: ignore # TODO see if this can be vectorized - def _set_lora_mapping(self, mapping: LoRAMapping) -> None: + def _set_adapter_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, long_lora_offsets_tensor, indices_len) = convert_mapping(mapping, self.lora_index_to_id, @@ -583,23 +559,11 @@ def _set_lora_mapping(self, mapping: LoRAMapping) -> None: # Maintain the reference self.indices_len[:] = indices_len - def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None: - if self._last_mapping != lora_mapping: - self._set_lora_mapping(lora_mapping) - self._last_mapping = lora_mapping - - def list_loras(self) -> Dict[int, LoRAModel]: - """List all registered LoRAModels.""" - return dict(self._registered_loras) - - def get_lora(self, lora_id: int) -> Optional[LoRAModel]: - return self._registered_loras.get(lora_id, None) - - def remove_all_loras(self): + def remove_all_adapters(self): """Remove all LoRAModels from the manager.""" - self._registered_loras.clear() + self._registered_adapters.clear() self.lora_index_to_id = [None] * self.lora_slots - self._active_loras.clear() + self._active_adapters.clear() def _create_lora_modules(self): for module_name, module in self.model.named_modules( @@ -743,18 +707,39 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: lora_model.loras[module_name] = PackedLoRALayerWeights.pack( replacement_loras) + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) + + def add_adapter(self, adapter: LoRAModel) -> bool: + logger.debug( + "Adding lora. Model id: %d, " + "int id: %d, " + "scaling factor: %s", adapter.id, adapter.id, + adapter.scaling_factor) + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) -class LoRALRUCache(LRUCache[LoRAModel]): + def set_adapter_mapping(self, mapping: LoRAMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) + + def list_adapters(self) -> Dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) + + +class LoRALRUCache(AdapterLRUCache[LoRAModel]): def __init__(self, capacity: int, deactivate_lora_fn: Callable[[int], bool]): - super().__init__(capacity) - self.deactivate_lora_fn = deactivate_lora_fn - - def _on_remove(self, key: int, value: LoRAModel): - logger.debug("Removing LoRA. 
int id: %d", key) - self.deactivate_lora_fn(key) - return super()._on_remove(key, value) + super().__init__(capacity, deactivate_lora_fn) class LRUCacheLoRAModelManager(LoRAModelManager): @@ -770,49 +755,49 @@ def __init__( ): super().__init__(model, max_num_seqs, max_num_batched_tokens, vocab_size, lora_config) - self._registered_loras: LoRALRUCache = LoRALRUCache( - self.capacity, self.deactivate_lora) - self._active_loras: LoRALRUCache = LoRALRUCache( - self.lora_slots, self._deactivate_lora) + self._registered_adapters: LoRALRUCache = LoRALRUCache( + self.capacity, self.deactivate_adapter) + self._active_adapters: LoRALRUCache = LoRALRUCache( + self.lora_slots, self._deactivate_adapter) - def list_loras(self) -> Dict[int, LoRAModel]: + def list_adapters(self) -> Dict[int, LoRAModel]: """List all registered LoRAModels.""" - return dict(self._registered_loras.cache) + return dict(self._registered_adapters.cache) - def add_lora(self, lora: LoRAModel) -> bool: + def add_adapter(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" logger.debug( "Adding lora. Model id: %d, " "int id: %d, " "scaling factor: %s", lora.id, lora.id, lora.scaling_factor) - if lora.id not in self._registered_loras: - self._add_lora(lora) + if lora.id not in self._registered_adapters: + self._add_adapter(lora) was_added = True else: # We always touch to update the LRU cache order - self._registered_loras.touch(lora.id) + self._registered_adapters.touch(lora.id) was_added = False return was_added - def activate_lora( + def activate_adapter( self, lora_id: int, ) -> bool: - if lora_id not in self._active_loras and len( - self._active_loras) >= self.lora_slots: - self._active_loras.remove_oldest() - result = super().activate_lora(lora_id) + if lora_id not in self._active_adapters and len( + self._active_adapters) >= self.lora_slots: + self._active_adapters.remove_oldest() + result = super().activate_adapter(lora_id) # We always touch to update the LRU cache order - self._active_loras.touch(lora_id) + self._active_adapters.touch(lora_id) return result - def remove_oldest_lora(self) -> bool: - if len(self._registered_loras) > 0: - self._registered_loras.remove_oldest() + def remove_oldest_adapter(self) -> bool: + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() return True return False - def pin_lora(self, lora_id: int) -> bool: + def pin_adapter(self, lora_id: int) -> bool: """Pin a LoRAModel in the manager cache.""" self._pin_lora_in_cpu_cache(lora_id) self._pin_lora_in_gpu_cache(lora_id) @@ -820,17 +805,17 @@ def pin_lora(self, lora_id: int) -> bool: def _pin_lora_in_cpu_cache(self, lora_id: int): try: - self._registered_loras.pin(lora_id) + self._registered_adapters.pin(lora_id) except ValueError as err: raise ValueError("Pinning failed. 
" f"LoRA {lora_id} is not registered.") from err def _pin_lora_in_gpu_cache(self, lora_id: int): - if lora_id not in self._active_loras: + if lora_id not in self._active_adapters: # move lora to gpu if not already active - self.activate_lora(lora_id) + self.activate_adapter(lora_id) - self._active_loras.pin(lora_id) + self._active_adapters.pin(lora_id) def create_lora_manager( diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 662774ffe09a..2d10d037760e 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -1,13 +1,15 @@ from dataclasses import dataclass from typing import Optional +from vllm.adapter_commons.request import AdapterRequest + @dataclass -class LoRARequest: +class LoRARequest(AdapterRequest): """ Request for a LoRA adapter. - Note that this class should be be used internally. For online + Note that this class should be used internally. For online serving, it is recommended to not allow users to use this class but instead provide another layer of abstraction to prevent users from accessing unauthorized LoRA adapters. @@ -20,15 +22,16 @@ class LoRARequest: lora_int_id: int lora_local_path: str long_lora_max_len: Optional[int] = None + __hash__ = AdapterRequest.__hash__ - def __post_init__(self): - if self.lora_int_id < 1: - raise ValueError( - f"lora_int_id must be > 0, got {self.lora_int_id}") + @property + def adapter_id(self): + return self.lora_int_id - def __eq__(self, value: object) -> bool: - return isinstance( - value, LoRARequest) and self.lora_int_id == value.lora_int_id + @property + def name(self): + return self.lora_name - def __hash__(self) -> int: - return self.lora_int_id + @property + def local_path(self): + return self.lora_local_path diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index ca4903c23bca..3d0ef4252b02 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -1,12 +1,15 @@ -from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Dict, List, Literal, Optional, Set, Type, Union import torch +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) +from vllm.adapter_commons.worker_manager import AbstractWorkerManager from vllm.config import LoRAConfig from vllm.logger import init_logger -from vllm.lora.layers import LoRAMapping from vllm.lora.models import (LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, create_lora_manager) from vllm.lora.request import LoRARequest @@ -14,79 +17,13 @@ logger = init_logger(__name__) -class AbstractWorkerLoRAManager(ABC): - """Abstract class for managing LoRA models on the worker side.""" - - def __init__(self, - max_num_seqs: int, - max_num_batched_tokens: int, - vocab_size: int, - lora_config: LoRAConfig, - device: torch.device, - max_position_embeddings: Optional[int] = None): - self.max_num_seqs = max_num_seqs - self.max_num_batched_tokens = max_num_batched_tokens - self.max_position_embeddings = max_position_embeddings - self.vocab_size = vocab_size - self.device = device - self.lora_config = lora_config - - # If False, do not cache. If None, cache is empty. - self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False - - @contextmanager - def dummy_lora_cache(self): - """Use this context manager to reuse the dummy lora model - to avoid creating it repeatedly.""" - self._cached_dummy_lora = None - yield - self._cached_dummy_lora = False - - @property - @abstractmethod - def is_enabled(self) -> bool: - ... 
- - @abstractmethod - def create_lora_manager( - self, - model: torch.nn.Module, - ) -> Any: - ... - - @abstractmethod - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - ... - - @abstractmethod - def add_lora(self, lora_request: LoRARequest) -> bool: - ... - - @abstractmethod - def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: - ... - - @abstractmethod - def remove_lora(self, lora_id: int) -> bool: - ... - - @abstractmethod - def remove_all_loras(self): - ... - - @abstractmethod - def list_loras(self) -> Set[int]: - ... - - -class WorkerLoRAManager(AbstractWorkerLoRAManager): +class WorkerLoRAManager(AbstractWorkerManager): """WorkerLoRAManager that manages LoRA models on the worker side. Every request, the requested LoRAs will be loaded (unless they are already loaded), and every other LoRA will be unloaded.""" - _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager + _manager_cls: Type[LoRAModelManager] = LoRAModelManager def __init__( self, @@ -103,16 +40,23 @@ def __init__( self._lora_model_cls = lora_model_cls self.embedding_modules = embedding_modules self.embedding_padding_modules = embedding_padding_modules + self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.lora_config = lora_config + self.max_position_embeddings = max_position_embeddings + super().__init__(device) # Lazily initialized by create_lora_manager. - self._lora_manager: LoRAModelManager - super().__init__( - max_num_seqs, - max_num_batched_tokens, - vocab_size, - lora_config, - device, - max_position_embeddings=max_position_embeddings, - ) + self._adapter_manager: LoRAModelManager + + @contextmanager + def dummy_lora_cache(self): + """Use this context manager to reuse the dummy lora model + to avoid creating it repeatedly.""" + self._cached_dummy_lora = None + yield + self._cached_dummy_lora = False @property def is_enabled(self) -> bool: @@ -128,41 +72,14 @@ def create_lora_manager( max_num_batched_tokens=self.max_num_batched_tokens, vocab_size=self.vocab_size, lora_config=self.lora_config, - lora_manager_cls=self._lora_manager_cls, + lora_manager_cls=self._manager_cls, ) - self._lora_manager = lora_manager + self._adapter_manager = lora_manager return lora_manager.model - def set_active_loras(self, lora_requests: Set[LoRARequest], - lora_mapping: LoRAMapping) -> None: - self._apply_loras(lora_requests) - self._lora_manager.set_lora_mapping(lora_mapping) - - def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: - loras_that_exist = self.list_loras() - loras_map = { - lora_request.lora_int_id: lora_request - for lora_request in lora_requests if lora_request - } - if len(loras_map) > self._lora_manager.lora_slots: - raise RuntimeError( - f"Number of requested LoRAs ({len(loras_map)}) is greater " - "than the number of GPU LoRA slots " - f"({self._lora_manager.lora_slots}).") - - new_loras = set(loras_map) - loras_to_add = new_loras - loras_that_exist - loras_to_remove = loras_that_exist - new_loras - - for lora_id in loras_to_remove: - self.remove_lora(lora_id) - - for lora_id in loras_to_add: - self.add_lora(loras_map[lora_id]) - - def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: + def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: try: - model = self._lora_manager.model + model = self._adapter_manager.model supported_lora_modules = 
model.supported_lora_modules packed_modules_mapping = model.packed_modules_mapping expected_lora_modules: List[str] = [] @@ -198,37 +115,45 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: return lora def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: - if lora_request.lora_int_id in self.list_loras(): + if lora_request.lora_int_id in self.list_adapters(): return False if isinstance(self._cached_dummy_lora, LoRAModel): dummy_lora = self._cached_dummy_lora.clone( lora_request.lora_int_id) else: - dummy_lora = self._lora_manager.create_dummy_lora( + dummy_lora = self._adapter_manager.create_dummy_lora( lora_request.lora_int_id, rank, 1, self.embedding_modules) if self._cached_dummy_lora is None: self._cached_dummy_lora = dummy_lora - return self._lora_manager.add_lora(dummy_lora) + return self._adapter_manager.add_adapter(dummy_lora) - def add_lora(self, lora_request: LoRARequest) -> bool: - if lora_request.lora_int_id in self.list_loras(): - return False - lora = self._load_lora(lora_request) - loaded = self._lora_manager.add_lora(lora) - self._lora_manager.activate_lora(lora.id) - return loaded + def pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) - def remove_lora(self, lora_id: int) -> bool: - return self._lora_manager.remove_lora(lora_id) + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) - def pin_lora(self, lora_id: int) -> bool: - return self._lora_manager.pin_lora(lora_id) + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) - def remove_all_loras(self): - self._lora_manager.remove_all_loras() + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() - def list_loras(self) -> Set[int]: - return set(self._lora_manager.list_loras()) + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) class LRUCacheWorkerLoRAManager(WorkerLoRAManager): @@ -238,8 +163,7 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): (unless they are already loaded) and least recently used LoRAs will be unloaded if the cache is above capacity.""" - _lora_manager_cls: Type[ - LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + _manager_cls: Type[LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager def create_lora_manager( self, @@ -247,40 +171,41 @@ def create_lora_manager( ) -> Any: lora_manager = create_lora_manager( model, - lora_manager_cls=self._lora_manager_cls, + lora_manager_cls=self._manager_cls, max_num_seqs=self.max_num_seqs, vocab_size=self.vocab_size, lora_config=self.lora_config, max_num_batched_tokens=self.max_num_batched_tokens, ) - self._lora_manager = lora_manager + self._adapter_manager = lora_manager return lora_manager.model - def _apply_loras(self, lora_requests: Set[LoRARequest]) -> None: + def _apply_adapters(self, lora_requests: Set[LoRARequest]) -> None: loras_map = { lora_request.lora_int_id: 
lora_request for lora_request in lora_requests if lora_request } - if len(loras_map) > self._lora_manager.lora_slots: + if len(loras_map) > self._adapter_manager.lora_slots: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " - f"({self._lora_manager.lora_slots}).") + f"({self._adapter_manager.lora_slots}).") for lora in loras_map.values(): - self.add_lora(lora) + self.add_adapter(lora) - def add_lora(self, lora_request: LoRARequest) -> bool: - if lora_request.lora_int_id not in self.list_loras(): + def add_adapter(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id not in self.list_adapters(): # Remove before we load the new lora to save memory - if len(self._lora_manager) + 1 > self._lora_manager.capacity: - assert isinstance(self._lora_manager, LRUCacheLoRAModelManager) - self._lora_manager.remove_oldest_lora() - lora = self._load_lora(lora_request) - loaded = self._lora_manager.add_lora(lora) + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + assert isinstance(self._adapter_manager, + LRUCacheLoRAModelManager) + self._adapter_manager.remove_oldest_adapter() + lora = self._load_adapter(lora_request) + loaded = self._adapter_manager.add_adapter(lora) else: # If the lora is already loaded, just touch it to # update its position in the caches - loaded = self._lora_manager.get_lora( + loaded = self._adapter_manager.get_adapter( lora_request.lora_int_id) is not None - self._lora_manager.activate_lora(lora_request.lora_int_id) + self._adapter_manager.activate_adapter(lora_request.lora_int_id) return loaded diff --git a/vllm/prompt_adapter/__init__.py b/vllm/prompt_adapter/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/vllm/prompt_adapter/layers.py b/vllm/prompt_adapter/layers.py new file mode 100644 index 000000000000..07aa015d8257 --- /dev/null +++ b/vllm/prompt_adapter/layers.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass +from typing import Optional + +import torch +from torch import nn + +from vllm.adapter_commons.layers import AdapterMapping +from vllm.config import PromptAdapterConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) + + +@dataclass +class PromptAdapterMapping(AdapterMapping): + pass + + +class VocabParallelEmbeddingWithPromptAdapter(nn.Module): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + self.emb_layer = self.base_layer + if 'LoRA' in base_layer.__class__.__name__: + self.emb_layer = self.base_layer.base_layer + + def create_prompt_adapter_weights( + self, prompt_adapter_config: PromptAdapterConfig): + self.embeddings_tensors = torch.zeros( + ( + prompt_adapter_config.max_prompt_adapters, + prompt_adapter_config.max_prompt_adapter_token, + self.emb_layer.embedding_dim, + ), + dtype=self.emb_layer.weight.dtype, + device=self.emb_layer.weight.device, + ) + self.adapter_lengths = torch.zeros( + prompt_adapter_config.max_prompt_adapters, + dtype=torch.long, + device=self.emb_layer.weight.device) + + self.indices_gpu: torch.Tensor + self.embedding_indices_gpu: torch.Tensor + + def reset_prompt_adapter(self, index: int): + self.embeddings_tensors[index] = 0 + + def set_prompt_adapter( + self, + index: int, + adapter_model: Optional[torch.Tensor], + ): + self.reset_prompt_adapter(index) + if adapter_model is not None: + length = adapter_model.shape[0] + self.embeddings_tensors[index, :length] = adapter_model + 
self.adapter_lengths[index] = length + + def set_mapping( + self, + prompt_indices: torch.Tensor, + prompt_embedding_indices: torch.Tensor, + ): + self.indices_gpu = prompt_indices.to( + device=self.emb_layer.weight.device) + self.embedding_indices_gpu = prompt_embedding_indices.to( + device=self.emb_layer.weight.device) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + hidden_states = self.base_layer(x) + if self.embedding_indices_gpu.numel(): + valid_mask = self.indices_gpu != -1 + gathered_embeddings = self.embeddings_tensors[ + self.embedding_indices_gpu[:, 0], + self.embedding_indices_gpu[:, 1]] + + # Update hidden states + hidden_states[valid_mask] = gathered_embeddings + return hidden_states diff --git a/vllm/prompt_adapter/models.py b/vllm/prompt_adapter/models.py new file mode 100644 index 000000000000..acd878dc9b9a --- /dev/null +++ b/vllm/prompt_adapter/models.py @@ -0,0 +1,340 @@ +import logging +import math +from typing import Any, Callable, Dict, List, Optional, Type + +import torch +from torch import nn + +from vllm.adapter_commons.models import (AdapterLRUCache, AdapterModel, + AdapterModelManager) +from vllm.adapter_commons.utils import (add_adapter, deactivate_adapter, + get_adapter, list_adapters, + remove_adapter, set_adapter_mapping) +from vllm.config import PromptAdapterConfig +from vllm.prompt_adapter.layers import ( + PromptAdapterMapping, VocabParallelEmbeddingWithPromptAdapter) + +logger = logging.getLogger(__name__) + +_GLOBAL_PROMPT_ADAPTER_ID = 0 + + +def get_prompt_adapter_id(): + global _GLOBAL_PROMPT_ADAPTER_ID + _GLOBAL_PROMPT_ADAPTER_ID += 1 + return _GLOBAL_PROMPT_ADAPTER_ID + + +def convert_to_embedding_indices(indices): + embedding_indices = [] + count = 0 + + for value in indices: + if value == -1: + count = 0 + else: + embedding_indices.append([value, count]) + count += 1 + + return torch.tensor(embedding_indices) + + +def convert_mapping( + mapping: PromptAdapterMapping, + prompt_adapter_index_to_id: List[Optional[int]], +) -> torch.Tensor: + """Converts PromptAdapterMapping to index tensors. + + Args: + mapping: PromptAdapterMapping mapping rows in a + batch to PromptAdapter ids. + prompt_adapter_index_to_id: List mapping PromptAdapter + ids to PromptAdapter indices. + + Returns: + pa_indices: Tensor of shape [batch_size] mapping batch rows to + PromptAdapter indices. 
+ """ + id_to_index = { + id_: idx + for idx, id_ in enumerate(prompt_adapter_index_to_id) + if id_ is not None + } + pa_indices = torch.tensor([ + id_to_index.get(id_, -1) if id_ > 0 else -1 + for id_ in mapping.index_mapping + ]) + + pa_embedding_mapping = convert_to_embedding_indices(pa_indices) + return pa_indices, pa_embedding_mapping + + +class PromptAdapterModel(AdapterModel): + + def __init__(self, + prompt_adapter_id=None, + num_virtual_tokens=None, + prompt_embedding=None) -> None: + self.id = prompt_adapter_id + self.prompt_embedding = prompt_embedding + self.num_virtual_tokens = num_virtual_tokens + + @classmethod + def from_local_checkpoint( + cls, + adapter_model_path: str, + prompt_adapter_id: int, + device: str = "cuda", + dtype: Optional[torch.dtype] = None) -> "PromptAdapterModel": + from peft.utils import load_peft_weights + + adapters_weights = load_peft_weights(adapter_model_path, device) + prompt_embedding = adapters_weights["prompt_embeddings"].to(dtype) + num_virtual_tokens = prompt_embedding.shape[0] + return cls(prompt_adapter_id, num_virtual_tokens, prompt_embedding) + + +class PromptAdapterModelManager(AdapterModelManager): + """A manager that manages multiple Prompt Adapter models.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + ): + """Create a PromptAdapterModel and adapter for a given model. + + Args: + model: the model to be adapted. + """ + self.model: nn.Module = model + # Dict instead of a Set for compatibility with LRUCache. + self.prompt_adapter_index_to_id: List[ + Optional[int]] = [None] * self.prompt_adapter_slots + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.prompt_adapter_config = prompt_adapter_config + self.model.prompt_adapter_manager = self + self.adapter_type = 'PromptAdapter' + + self.base_indices = torch.tensor([-1]) + self.base_embedding_indices = torch.tensor([-1]) + + self.modules: Dict[str, nn.Module] = {} + self._create_prompt_adapter_modules() + self._last_mapping: Optional[PromptAdapterMapping] = None + + @property + def prompt_adapter_slots(self) -> int: + return self.prompt_adapter_config.max_prompt_adapters + + @property + def adapter_slots(self) -> int: + return self.prompt_adapter_slots + + @property + def capacity(self) -> int: + return self.prompt_adapter_config.max_cpu_prompt_adapters + + def activate_adapter( + self, + prompt_adapter_id: int, + ) -> bool: + """Move PromptAdapter into a GPU buffer + to be used in the forward pass.""" + if prompt_adapter_id in self._active_adapters: + return False + first_free_slot = next( + ((i, prompt_adapter_id) for i, prompt_adapter_id in enumerate( + self.prompt_adapter_index_to_id) if prompt_adapter_id is None), + None) + if first_free_slot is None: + raise ValueError("No free prompt_adapter slots") + index, _ = first_free_slot + self._active_adapters[prompt_adapter_id] = None + prompt_adapter_model = (self._registered_adapters[prompt_adapter_id]) + logger.debug("Activating prompt_adapter. 
int id: %d, slot index: %d", + prompt_adapter_model.id, index) + self.prompt_adapter_index_to_id[index] = prompt_adapter_model.id + for _, v in self.modules.items(): + v.set_prompt_adapter(index, prompt_adapter_model.prompt_embedding) + return True + + def _deactivate_adapter(self, prompt_adapter_id: int): + try: + index = self.prompt_adapter_index_to_id.index(prompt_adapter_id) + self.prompt_adapter_index_to_id[index] = None + for _, v in self.modules.items(): + v.reset_prompt_adapter(index) + except ValueError: + pass + + def _add_adapter(self, prompt_adapter: PromptAdapterModel): + self._registered_adapters[prompt_adapter.id] = prompt_adapter + + def _set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + base_indices, base_embedding_indices = convert_mapping( + mapping, self.prompt_adapter_index_to_id) + for k, v in self.modules.items(): + v.set_mapping(base_indices, base_embedding_indices) + + def _create_prompt_adapter_modules(self): + for module_name, module in self.model.named_modules( + remove_duplicate=False): + if "VocabParallel" in module.__class__.__name__: + new_module = VocabParallelEmbeddingWithPromptAdapter(module) + new_module.create_prompt_adapter_weights( + self.prompt_adapter_config) + replaced_module = self.replace_submodule( + self.model, module_name, new_module) + self.register_module(module.__class__.__name__, + replaced_module) + replaced_module.set_mapping(self.base_indices, + self.base_embedding_indices) + break + + def replace_submodule(self, model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + def register_module(self, module_name: str, module: nn.Module): + self.modules[module_name] = module + + def pin_adapter(self, prompt_adapter_id: int) -> bool: + """Pin a PromptAdapterModel in the manager cache.""" + raise NotImplementedError( + "Pinning is not supported in PromptAdapterModelManager." 
+ "Use LRUCachePromptAdapterModelManager for pinning" + ) # type: ignore + + def remove_all_adapters(self): + """Remove all PromptAdapterModel from the manager.""" + self._registered_adapters.clear() + self.prompt_adapter_index_to_id = [None] * self.prompt_adapter_slots + self._active_adapters.clear() + + def deactivate_adapter(self, adapter_id: int) -> bool: + return deactivate_adapter(adapter_id, self._active_adapters, + self._deactivate_adapter) + + def add_adapter(self, adapter: PromptAdapterModel) -> bool: + return add_adapter(adapter, self._registered_adapters, self.capacity, + self._add_adapter) + + def set_adapter_mapping(self, mapping: PromptAdapterMapping) -> None: + self._last_mapping = set_adapter_mapping(mapping, self._last_mapping, + self._set_adapter_mapping) + + def remove_adapter(self, adapter_id: int) -> bool: + return remove_adapter(adapter_id, self._registered_adapters, + self.deactivate_adapter) + + def list_adapters(self) -> Dict[int, Any]: + return list_adapters(self._registered_adapters) + + def get_adapter(self, adapter_id: int) -> Optional[Any]: + return get_adapter(adapter_id, self._registered_adapters) + + +class PromptAdapterLRUCache(AdapterLRUCache[PromptAdapterModel]): + + def __init__(self, capacity: int, + deactivate_prompt_adapter_fn: Callable[[int], bool]): + super().__init__(capacity, deactivate_prompt_adapter_fn) + + +class LRUCachePromptAdapterModelManager(PromptAdapterModelManager): + """A model manager that manages multiple prompt_adapters with LRU cache.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + ): + self.prompt_adapter_config = prompt_adapter_config + super().__init__(model, max_num_seqs, max_num_batched_tokens, + prompt_adapter_config) + self._registered_adapters = PromptAdapterLRUCache( + self.capacity, self.deactivate_adapter) + self._active_adapters = PromptAdapterLRUCache( + self.prompt_adapter_slots, self._deactivate_adapter) + + def list_adapters(self) -> Dict[int, PromptAdapterModel]: + """List all registered PromptAdapterModel.""" + return dict(self._registered_adapters.cache) + + def add_adapter(self, prompt_adapter: PromptAdapterModel) -> bool: + """Add a PromptAdapterModel to the manager.""" + if prompt_adapter.id not in self._registered_adapters: + self._add_adapter(prompt_adapter) + was_added = True + else: + # We always touch to update the LRU cache order + self._registered_adapters.touch(prompt_adapter.id) + was_added = False + return was_added + + def activate_adapter( + self, + prompt_adapter_id: int, + ) -> bool: + if prompt_adapter_id not in self._active_adapters and len( + self._active_adapters) >= self.prompt_adapter_slots: + self._active_adapters.remove_oldest() + result = super().activate_adapter(prompt_adapter_id) + # We always touch to update the LRU cache order + self._active_adapters.touch(prompt_adapter_id) + return result + + def remove_oldest_adapter(self) -> bool: + if len(self._registered_adapters) > 0: + self._registered_adapters.remove_oldest() + return True + return False + + def pin_adapter(self, prompt_adapter_id: int) -> bool: + """Pin a PromptAdapterModel in the manager cache.""" + self._pin_prompt_adapter_in_cpu_cache(prompt_adapter_id) + self._pin_prompt_adapter_in_gpu_cache(prompt_adapter_id) + return True + + def _pin_prompt_adapter_in_cpu_cache(self, prompt_adapter_id: int): + try: + self._registered_adapters.pin(prompt_adapter_id) + except ValueError as err: + raise ValueError( + "Pinning failed. 
" + f"Prompt Adapter {prompt_adapter_id} is not registered." + ) from err + + def _pin_prompt_adapter_in_gpu_cache(self, prompt_adapter_id: int): + if prompt_adapter_id not in self._active_adapters: + # move adapter to gpu if not already active + self.activate_adapter(prompt_adapter_id) + self._active_adapters.pin(prompt_adapter_id) + + +def create_prompt_adapter_manager( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + prompt_adapter_config: PromptAdapterConfig, + prompt_adapter_manager_cls: Type[ + PromptAdapterModelManager] = PromptAdapterModelManager, + **kwargs) -> PromptAdapterModelManager: + """Create a PromptAdapterModel for a given model.""" + prompt_adapter_manager = prompt_adapter_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + prompt_adapter_config=prompt_adapter_config, + **kwargs) + return prompt_adapter_manager diff --git a/vllm/prompt_adapter/request.py b/vllm/prompt_adapter/request.py new file mode 100644 index 000000000000..c0c98cf72bba --- /dev/null +++ b/vllm/prompt_adapter/request.py @@ -0,0 +1,30 @@ +from dataclasses import dataclass + +from vllm.adapter_commons.request import AdapterRequest + + +@dataclass +class PromptAdapterRequest(AdapterRequest): + """ + Request for a Prompt adapter. + """ + + prompt_adapter_name: str + prompt_adapter_id: int + prompt_adapter_local_path: str + prompt_adapter_num_virtual_tokens: int + + def __hash__(self): + return super().__hash__() + + @property + def adapter_id(self): + return self.prompt_adapter_id + + @property + def name(self): + return self.prompt_adapter_name + + @property + def local_path(self): + return self.prompt_adapter_local_path diff --git a/vllm/prompt_adapter/worker_manager.py b/vllm/prompt_adapter/worker_manager.py new file mode 100644 index 000000000000..ab72e2ba8316 --- /dev/null +++ b/vllm/prompt_adapter/worker_manager.py @@ -0,0 +1,173 @@ +import logging +from typing import Any, Optional, Set, Type + +import torch + +from vllm.adapter_commons.utils import (add_adapter_worker, + apply_adapters_worker, + list_adapters_worker, + set_active_adapters_worker) +from vllm.adapter_commons.worker_manager import AbstractWorkerManager +from vllm.config import PromptAdapterConfig +from vllm.prompt_adapter.models import (LRUCachePromptAdapterModelManager, + PromptAdapterModel, + PromptAdapterModelManager, + create_prompt_adapter_manager) +from vllm.prompt_adapter.request import PromptAdapterRequest + +logger = logging.getLogger(__name__) + + +class WorkerPromptAdapterManager(AbstractWorkerManager): + """WorkerPromptAdapterManager that manages + prompt_adapter models on the worker side. 
+ + Every request, the requested prompt_adapters will be + loaded (unless they are already loaded), + and every other prompt_adapter will be unloaded.""" + + _manager_cls: Type[PromptAdapterModelManager] = PromptAdapterModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + device: torch.device, + prompt_adapter_config: PromptAdapterConfig, + prompt_adapter_model_cls: Type[PromptAdapterModel] = PromptAdapterModel + ): + self._adapter_manager: PromptAdapterModelManager + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self._prompt_adapter_model_cls = prompt_adapter_model_cls + self.prompt_adapter_config = prompt_adapter_config + super().__init__(device) + + @property + def is_enabled(self) -> bool: + return True + + def create_prompt_adapter_manager( + self, + model: torch.nn.Module, + ) -> Any: + prompt_adapter_manager = create_prompt_adapter_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + prompt_adapter_config=self.prompt_adapter_config, + prompt_adapter_manager_cls=self._manager_cls, + ) + self._adapter_manager = prompt_adapter_manager + return prompt_adapter_manager.model + + def _load_adapter( + self, prompt_adapter_request: PromptAdapterRequest + ) -> PromptAdapterModel: + try: + prompt_adapter = ( + self._prompt_adapter_model_cls.from_local_checkpoint( + prompt_adapter_request.prompt_adapter_local_path, + prompt_adapter_id=prompt_adapter_request.prompt_adapter_id, + device=str(self.device), + dtype=self.prompt_adapter_config.prompt_adapter_dtype)) + except Exception as e: + raise RuntimeError( + f"Loading prompt_adapter " + f"{prompt_adapter_request.prompt_adapter_local_path}" + f" failed") from e + return prompt_adapter + + def add_dummy_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + return True + + def pin_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.pin_adapter(adapter_id) + + def set_active_adapters(self, requests: Set[Any], + mapping: Optional[Any]) -> None: + set_active_adapters_worker(requests, mapping, self._apply_adapters, + self._adapter_manager.set_adapter_mapping) + + def add_adapter(self, adapter_request: Any) -> bool: + return add_adapter_worker(adapter_request, self.list_adapters, + self._load_adapter, + self._adapter_manager.add_adapter, + self._adapter_manager.activate_adapter) + + def _apply_adapters(self, adapter_requests: Set[Any]) -> None: + apply_adapters_worker(adapter_requests, self.list_adapters, + self._adapter_manager.adapter_slots, + self.remove_adapter, self.add_adapter) + + def remove_adapter(self, adapter_id: int) -> bool: + return self._adapter_manager.remove_adapter(adapter_id) + + def remove_all_adapters(self): + self._adapter_manager.remove_all_adapters() + + def list_adapters(self) -> Set[int]: + return list_adapters_worker(self._adapter_manager.list_adapters) + + +class LRUCacheWorkerPromptAdapterManager(WorkerPromptAdapterManager): + """WorkerPromptAdapterManager that manages + prompt_adapter models on the worker side. + + Uses an LRU Cache. 
Every request, the requested + prompt_adapters will be loaded (unless they are already loaded) + and least recently used prompt_adapters will + be unloaded if the cache is above capacity.""" + + _prompt_adapter_manager_cls: Type[ + LRUCachePromptAdapterModelManager] = LRUCachePromptAdapterModelManager + + def create_prompt_adapter_manager( + self, + model: torch.nn.Module, + ) -> Any: + prompt_adapter_manager = create_prompt_adapter_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + prompt_adapter_config=self.prompt_adapter_config, + prompt_adapter_manager_cls=self._prompt_adapter_manager_cls) + self._adapter_manager: LRUCachePromptAdapterModelManager = ( + prompt_adapter_manager) + return prompt_adapter_manager.model + + def _apply_adapters( + self, prompt_adapter_requests: Set[PromptAdapterRequest]) -> None: + prompt_adapters_map = { + prompt_adapter_request.prompt_adapter_id: prompt_adapter_request + for prompt_adapter_request in prompt_adapter_requests + if prompt_adapter_request + } + if len(prompt_adapters_map + ) > self._adapter_manager.prompt_adapter_slots: + raise RuntimeError( + f"Number of requested prompt_adapters " + f"({len(prompt_adapters_map)}) is greater " + "than the number of GPU prompt_adapter slots " + f"({self._adapter_manager.prompt_adapter_slots}).") + for prompt_adapter in prompt_adapters_map.values(): + self.add_adapter(prompt_adapter) + + def add_adapter(self, + prompt_adapter_request: PromptAdapterRequest) -> bool: + if prompt_adapter_request.prompt_adapter_id not in self.list_adapters( + ): + # Remove before we load the new prompt_adapter to save memory + if len(self._adapter_manager) + 1 > self._adapter_manager.capacity: + self._adapter_manager.remove_oldest_adapter() + prompt_adapter = self._load_adapter(prompt_adapter_request) + loaded = self._adapter_manager.add_adapter(prompt_adapter) + else: + # If the prompt_adapter is already loaded, just touch it to + # update its position in the caches + loaded = self._adapter_manager.get_adapter( + prompt_adapter_request.prompt_adapter_id) is not None + self._adapter_manager.activate_adapter( + prompt_adapter_request.prompt_adapter_id) + return loaded diff --git a/vllm/sequence.py b/vllm/sequence.py index 22cb26dc08ef..bfcc9d59302f 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -10,6 +10,7 @@ from vllm.lora.request import LoRARequest from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams if TYPE_CHECKING: @@ -222,12 +223,14 @@ def __init__( block_size: int, eos_token_id: Optional[int] = None, lora_request: Optional[LoRARequest] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None ) -> None: self.seq_id = seq_id self.inputs = inputs self.block_size = block_size self.eos_token_id = eos_token_id self.lora_request = lora_request + self.prompt_adapter_request = prompt_adapter_request self.data = SequenceData(self.prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -262,6 +265,11 @@ def multi_modal_data(self) -> Optional["MultiModalData"]: def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + def get_output_text_to_return(self, buffer_length: int): # We return the full output text if the sequence is finished. 
truncate = buffer_length and not self.is_finished() @@ -402,6 +410,7 @@ def __init__( pooling_params: Optional[PoolingParams] = None, encoder_seq: Optional[Sequence] = None, trace_headers: Optional[Dict[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} @@ -416,6 +425,7 @@ def __init__( self.state = SequenceGroupState() self.embeddings = embeddings self.pooling_params = pooling_params + self.prompt_adapter_request = prompt_adapter_request self.encoder_seq = encoder_seq self.trace_headers = trace_headers @@ -441,6 +451,16 @@ def multi_modal_data(self) -> Optional["MultiModalData"]: def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + + @property + def prompt_adapter_num_virtual_tokens(self) -> int: + return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens\ + if self.prompt_adapter_request else 0 + def get_last_latency(self, now: float) -> Optional[float]: """Sets the last token time for Request level timings.""" # If still in prefill phase, raise Error. @@ -617,6 +637,7 @@ def __init__( multi_modal_data: Optional["MultiModalData"] = None, encoder_seq_data: Optional[SequenceData] = None, cross_block_table: Optional[List[int]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt @@ -625,6 +646,7 @@ def __init__( self.block_tables = block_tables self.pooling_params = pooling_params self.lora_request = lora_request + self.prompt_adapter_request = prompt_adapter_request self.computed_block_nums = computed_block_nums self.multi_modal_data = multi_modal_data self.state = SequenceGroupState() if state is None else state @@ -649,6 +671,16 @@ def __init__( def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 + @property + def prompt_adapter_id(self) -> int: + return self.prompt_adapter_request.prompt_adapter_id \ + if self.prompt_adapter_request else 0 + + @property + def prompt_adapter_num_virtual_tokens(self) -> int: + return self.prompt_adapter_request.prompt_adapter_num_virtual_tokens \ + if self.prompt_adapter_request else 0 + @property def token_chunk_size(self) -> int: """Return the number of tokens to be processed (chunk size).""" diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index f30d29376121..b65e1288b1c3 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -3,8 +3,8 @@ import torch from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, @@ -47,6 +47,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, return_hidden_states: bool = False, ): if return_hidden_states: @@ -65,6 +66,7 @@ def __init__( kv_cache_dtype=kv_cache_dtype, 
is_driver_worker=is_driver_worker, vision_language_config=vision_language_config, + prompt_adapter_config=prompt_adapter_config, return_hidden_states=return_hidden_states, ) @@ -130,6 +132,13 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + outputs: List[SamplerOutput] = [] for step in range(num_steps): # Currently cuda graph is only supported by the decode phase. diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index b83cc6f095bf..63a7eda2d8c9 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -7,8 +7,8 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.model_loader import get_model @@ -79,6 +79,7 @@ def __init__( lora_config: Optional[LoRAConfig], vision_language_config: Optional[VisionLanguageConfig], kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, *args, **kwargs, @@ -91,6 +92,7 @@ def __init__( self.device_config = device_config self.cache_config = cache_config self.lora_config = lora_config + self.prompt_adapter_config = prompt_adapter_config self.vision_language_config = vision_language_config self.load_config = load_config self.is_driver_worker = is_driver_worker @@ -126,7 +128,8 @@ def load_model(self) -> None: lora_config=self.lora_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, - cache_config=self.cache_config) + cache_config=self.cache_config, + ) def _prepare_prompt( self, diff --git a/vllm/worker/cpu_worker.py b/vllm/worker/cpu_worker.py index 30ee262c7a8b..df3175797efe 100644 --- a/vllm/worker/cpu_worker.py +++ b/vllm/worker/cpu_worker.py @@ -6,8 +6,8 @@ from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.logger import init_logger @@ -133,6 +133,7 @@ def __init__( lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, kv_cache_dtype: Optional[str] = "auto", + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -146,6 +147,7 @@ def __init__( self.distributed_init_method = distributed_init_method self.lora_config = lora_config self.vision_language_config = vision_language_config + self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." 
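The cache discipline used by LRUCacheWorkerPromptAdapterManager.add_adapter earlier in this patch boils down to: touch the entry on a hit, otherwise evict the least recently used adapter before loading the new one so its memory is freed first. Below is a minimal sketch of that pattern using a plain OrderedDict; the names (TinyLRUAdapterCache, load_fn) are illustrative stand-ins, not part of this patch, and the return value is simplified relative to the real method.

from collections import OrderedDict
from typing import Callable, List


class TinyLRUAdapterCache:
    """Illustrative stand-in for the evict-then-load behavior above."""

    def __init__(self, capacity: int, load_fn: Callable[[int], object]):
        self.capacity = capacity
        self.load_fn = load_fn  # loads an adapter by id (e.g. from disk)
        self._cache: "OrderedDict[int, object]" = OrderedDict()

    def add(self, adapter_id: int) -> bool:
        """Return True if the adapter had to be loaded, False on a cache hit."""
        if adapter_id in self._cache:
            # Already resident: touch it so it becomes most recently used.
            self._cache.move_to_end(adapter_id)
            return False
        # Evict before loading so the new adapter's memory is available.
        if len(self._cache) + 1 > self.capacity:
            self._cache.popitem(last=False)  # drop the least recently used
        self._cache[adapter_id] = self.load_fn(adapter_id)
        return True

    def ids(self) -> List[int]:
        return list(self._cache)


cache = TinyLRUAdapterCache(capacity=2, load_fn=lambda i: f"adapter-{i}")
assert cache.add(1) and cache.add(2)  # both newly loaded
assert not cache.add(1)               # hit: adapter 1 is touched, nothing loaded
assert cache.add(3)                   # evicts adapter 2, the least recently used
assert cache.ids() == [1, 3]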
@@ -164,6 +166,7 @@ def __init__( lora_config=self.lora_config, vision_language_config=self.vision_language_config, kv_cache_dtype=kv_cache_dtype, + prompt_adapter_config=self.prompt_adapter_config, is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # initialize_cache. diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 272917c7272d..e68a3577137d 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -4,8 +4,8 @@ import torch from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.pooling_params import PoolingParams @@ -40,6 +40,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, ): super().__init__(model_config, parallel_config, @@ -50,7 +51,8 @@ def __init__( lora_config=lora_config, kv_cache_dtype=kv_cache_dtype, is_driver_worker=is_driver_worker, - vision_language_config=vision_language_config) + vision_language_config=vision_language_config, + prompt_adapter_config=prompt_adapter_config) @torch.inference_mode() def execute_model( @@ -69,6 +71,13 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + # Currently cuda graph is only supported by the decode phase. 
assert model_input.attn_metadata is not None prefill_meta = model_input.attn_metadata.prefill_metadata diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 942063677a42..df5c6acd4731 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -23,8 +23,8 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - VisionLanguageConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, VisionLanguageConfig) from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger @@ -36,6 +36,10 @@ from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models.interfaces import supports_lora from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.prompt_adapter.layers import PromptAdapterMapping +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.prompt_adapter.worker_manager import ( + LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, get_kv_cache_torch_dtype, is_hip, @@ -81,6 +85,8 @@ class ModelInputForGPU(ModelRunnerInputBase): lora_requests: Optional[Set[LoRARequest]] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[Dict[str, torch.Tensor]] = None + prompt_adapter_mapping: Optional[PromptAdapterMapping] = None + prompt_adapter_requests: Optional[Set[PromptAdapterRequest]] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -89,6 +95,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "prompt_adapter_mapping": self.prompt_adapter_mapping, + "prompt_adapter_requests": self.prompt_adapter_requests, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) return tensor_dict @@ -122,6 +130,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "lora_requests": self.lora_requests, "lora_mapping": self.lora_mapping, "multi_modal_kwargs": self.multi_modal_kwargs, + "prompt_adapter_mapping": self.prompt_adapter_mapping, + "prompt_adapter_requests": self.prompt_adapter_requests, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -159,6 +169,7 @@ def __init__( kv_cache_dtype: Optional[str] = "auto", is_driver_worker: bool = False, vision_language_config: Optional[VisionLanguageConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, return_hidden_states: bool = False, ): self.model_config = model_config @@ -170,6 +181,7 @@ def __init__( self.load_config = load_config self.is_driver_worker = is_driver_worker self.vision_language_config = vision_language_config + self.prompt_adapter_config = prompt_adapter_config self.return_hidden_states = return_hidden_states self.device = self.device_config.device @@ -211,6 +223,7 @@ def __init__( self.model: nn.Module # Set after load_model # Set after load_model. 
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None + self.prompt_adapter_manager: LRUCacheWorkerPromptAdapterManager = None self.flashinfer_decode_workspace_buffer = None self.flashinfer_decode_wrapper = None @@ -227,8 +240,7 @@ def load_model(self) -> None: vision_language_config=self.vision_language_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, - cache_config=self.cache_config, - ) + cache_config=self.cache_config) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -250,6 +262,15 @@ def load_model(self) -> None: ) self.model = self.lora_manager.create_lora_manager(self.model) + if self.prompt_adapter_config: + self.prompt_adapter_manager = LRUCacheWorkerPromptAdapterManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, self.device, + self.prompt_adapter_config) + self.model = ( + self.prompt_adapter_manager.create_prompt_adapter_manager( + self.model)) + if self.kv_cache_dtype == "fp8" and is_hip(): # Currently only ROCm accepts kv-cache scaling factors # via quantization_param_path and this will be deprecated @@ -329,6 +350,9 @@ def _prepare_model_input_tensors( lora_index_mapping: List[int] = [] lora_prompt_mapping: List[int] = [] lora_requests: Set[LoRARequest] = set() + prompt_adapter_index_mapping: List[int] = [] + prompt_adapter_prompt_mapping: List[int] = [] + prompt_adapter_requests: Set[PromptAdapterRequest] = set() seq_lens: List[int] = [] prefill_seq_lens: List[int] = [] @@ -479,6 +503,7 @@ def _prepare_model_input_tensors( input_tokens.extend(tokens) input_positions.extend(list(range(context_len, seq_len))) lora_id = seq_group_metadata.lora_int_id + prompt_adapter_id = seq_group_metadata.prompt_adapter_id if is_prompt: assert len(seq_ids) == 1 @@ -510,6 +535,21 @@ def _prepare_model_input_tensors( for k, v in mm_kwargs.items(): multi_modal_kwargs_list[k].append(v) + if prompt_adapter_id > 0: + prompt_adapter_requests.add( + seq_group_metadata.prompt_adapter_request) + + num_tokens = seq_group_metadata.\ + prompt_adapter_num_virtual_tokens + pm = [prompt_adapter_id + ] * num_tokens + [0] * (query_len - num_tokens) + prompt_adapter_index_mapping += pm + prompt_adapter_prompt_mapping.extend( + [prompt_adapter_id] * + (query_len if seq_group_metadata.sampling_params + and seq_group_metadata.sampling_params.prompt_logprobs + else 1)) + is_profile_run = _is_block_tables_empty( seq_group_metadata.block_tables) if is_profile_run: @@ -594,12 +634,11 @@ def _prepare_model_input_tensors( seq_lens.append(1) block_tables.append([]) lora_index_mapping.append(0) - + prompt_adapter_index_mapping.append(0) if self.attn_backend.get_name() == "flashinfer": last_paged_kv_indptr = paged_kv_indptr[-1] paged_kv_indptr.append(last_paged_kv_indptr) paged_kv_last_page_len.append(0) - batch_size = graph_batch_size num_decode_tokens = batch_size @@ -725,6 +764,14 @@ def _prepare_model_input_tensors( else: lora_mapping = None + if self.prompt_adapter_config: + prompt_adapter_mapping = PromptAdapterMapping( + prompt_adapter_index_mapping, + prompt_adapter_prompt_mapping, + ) + else: + prompt_adapter_mapping = None + multi_modal_kwargs = { k: torch.cat(v, dim=0).to(self.device) for k, v in multi_modal_kwargs_list.items() @@ -739,6 +786,8 @@ def _prepare_model_input_tensors( lora_mapping=lora_mapping, lora_requests=lora_requests, multi_modal_kwargs=multi_modal_kwargs, + prompt_adapter_mapping=prompt_adapter_mapping, + prompt_adapter_requests=prompt_adapter_requests, ) 
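The per-sequence mapping logic added to _prepare_model_input_tensors above can be summarized as: the adapter id covers the virtual-token prefix of the query (zeros for the remaining positions), and the prompt mapping has one entry per query token only when prompt logprobs are requested, otherwise a single entry. A simplified, self-contained sketch of that computation follows; the helper name and signature are illustrative, not part of this patch.

from typing import List, Tuple


def build_prompt_adapter_mappings(
        prompt_adapter_id: int,
        num_virtual_tokens: int,
        query_len: int,
        wants_prompt_logprobs: bool) -> Tuple[List[int], List[int]]:
    """Mirror of the per-sequence mapping construction above (simplified)."""
    # Index mapping: the adapter id marks the virtual-token prefix of the
    # query; every remaining token position maps to 0 (no adapter).
    index_mapping = ([prompt_adapter_id] * num_virtual_tokens +
                     [0] * (query_len - num_virtual_tokens))
    # Prompt mapping: one entry per sampled position -- the whole query when
    # prompt logprobs are requested, otherwise a single entry.
    prompt_mapping = [prompt_adapter_id] * (
        query_len if wants_prompt_logprobs else 1)
    return index_mapping, prompt_mapping


idx, prompt = build_prompt_adapter_mappings(prompt_adapter_id=7,
                                            num_virtual_tokens=4,
                                            query_len=6,
                                            wants_prompt_logprobs=False)
assert idx == [7, 7, 7, 7, 0, 0]
assert prompt == [7]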
@torch.inference_mode() @@ -818,33 +867,67 @@ def profile_run(self) -> None: def remove_all_loras(self): if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - self.lora_manager.remove_all_loras() + self.lora_manager.remove_all_adapters() def set_active_loras(self, lora_requests: Set[LoRARequest], lora_mapping: LoRAMapping) -> None: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - self.lora_manager.set_active_loras(lora_requests, lora_mapping) + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.add_lora(lora_request) + return self.lora_manager.add_adapter(lora_request) def remove_lora(self, lora_id: int) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.remove_lora(lora_id) + return self.lora_manager.remove_adapter(lora_id) def pin_lora(self, lora_id: int) -> bool: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.pin_lora(lora_id) + return self.lora_manager.pin_adapter(lora_id) def list_loras(self) -> Set[int]: if not self.lora_manager: raise RuntimeError("LoRA is not enabled.") - return self.lora_manager.list_loras() + return self.lora_manager.list_adapters() + + def remove_all_prompt_adapters(self): + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + self.prompt_adapter_manager.remove_all_adapters() + + def set_active_prompt_adapters( + self, prompt_adapter_requests: Set[PromptAdapterRequest], + prompt_adapter_mapping: PromptAdapterMapping) -> None: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + self.prompt_adapter_manager.set_active_adapters( + prompt_adapter_requests, prompt_adapter_mapping) + + def add_prompt_adapter( + self, prompt_adapter_request: PromptAdapterRequest) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.add_adapter(prompt_adapter_request) + + def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.remove_adapter(prompt_adapter_id) + + def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.pin_adapter(prompt_adapter_id) + + def list_prompt_adapters(self) -> Set[int]: + if not self.prompt_adapter_manager: + raise RuntimeError("PromptAdapter is not enabled.") + return self.prompt_adapter_manager.list_adapters() @torch.inference_mode() def capture_model(self, kv_caches: List[torch.Tensor]) -> None: @@ -990,6 +1073,14 @@ def capture_model(self, kv_caches: List[torch.Tensor]) -> None: ) self.set_active_loras(set(), lora_mapping) + if self.prompt_adapter_config: + prompt_adapter_mapping = PromptAdapterMapping( + [-1] * batch_size, + [-1] * batch_size, + ) + self.set_active_prompt_adapters(set(), + prompt_adapter_mapping) + graph_runner = CUDAGraphRunner(self.model, self.attn_backend.get_name()) @@ -1090,6 +1181,13 @@ def execute_model( self.set_active_loras(model_input.lora_requests, model_input.lora_mapping) + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert 
model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + if self.attn_backend.get_name() == "flashinfer": assert model_input.attn_metadata is not None assert model_input.input_tokens is not None diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index cc27d06b511f..808c756ee8a8 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -7,14 +7,16 @@ import torch.distributed from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, SchedulerConfig, - SpeculativeConfig, VisionLanguageConfig) + ModelConfig, ParallelConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig, + VisionLanguageConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, set_custom_all_reduce) from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed from vllm.model_executor.model_loader.tensorizer import TensorizerConfig +from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import ExecuteModelRequest from vllm.utils import get_device_capability_stateless from vllm.worker.cache_engine import CacheEngine @@ -45,6 +47,7 @@ def __init__( lora_config: Optional[LoRAConfig] = None, vision_language_config: Optional[VisionLanguageConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, + prompt_adapter_config: Optional[PromptAdapterConfig] = None, is_driver_worker: bool = False, model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, ) -> None: @@ -58,6 +61,7 @@ def __init__( self.distributed_init_method = distributed_init_method self.lora_config = lora_config self.load_config = load_config + self.prompt_adapter_config = prompt_adapter_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." @@ -95,6 +99,7 @@ def __init__( kv_cache_dtype=self.cache_config.cache_dtype, is_driver_worker=is_driver_worker, vision_language_config=vision_language_config, + prompt_adapter_config=prompt_adapter_config, **speculative_args, ) # Uninitialized cache engine. 
Will be initialized by
@@ -288,6 +293,19 @@ def pin_lora(self, lora_id: int) -> bool:
     def list_loras(self) -> Set[int]:
         return self.model_runner.list_loras()
 
+    def add_prompt_adapter(
+            self, prompt_adapter_request: PromptAdapterRequest) -> bool:
+        return self.model_runner.add_prompt_adapter(prompt_adapter_request)
+
+    def remove_prompt_adapter(self, prompt_adapter_id: int) -> bool:
+        return self.model_runner.remove_prompt_adapter(prompt_adapter_id)
+
+    def pin_prompt_adapter(self, prompt_adapter_id: int) -> bool:
+        return self.model_runner.pin_prompt_adapter(prompt_adapter_id)
+
+    def list_prompt_adapters(self) -> Set[int]:
+        return self.model_runner.list_prompt_adapters()
+
     @property
     def max_model_len(self) -> int:
         return self.model_config.max_model_len
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index 99fd7da5edda..81775bee32a5 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -6,8 +6,8 @@
 from vllm.attention import get_attn_backend
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, SchedulerConfig,
-                         VisionLanguageConfig)
+                         ModelConfig, ParallelConfig, PromptAdapterConfig,
+                         SchedulerConfig, VisionLanguageConfig)
 from vllm.distributed import broadcast_tensor_dict
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
@@ -82,6 +82,7 @@ def __init__(
         lora_config: Optional[LoRAConfig],
         vision_language_config: Optional[VisionLanguageConfig],
         kv_cache_dtype: Optional[str] = "auto",
+        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
         *args,
         **kwargs,
@@ -93,6 +94,7 @@ def __init__(
         self.load_config = load_config
         self.cache_config = cache_config
         self.vision_language_config = vision_language_config
+        self.prompt_adapter_config = prompt_adapter_config
         self.is_driver_worker = is_driver_worker
 
         self.sliding_window = model_config.get_sliding_window()
diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py
index 773ee9f8159e..ee0336829578 100644
--- a/vllm/worker/xpu_worker.py
+++ b/vllm/worker/xpu_worker.py
@@ -9,8 +9,9 @@
 import torch.distributed
 
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
-                         ModelConfig, ParallelConfig, SchedulerConfig,
-                         SpeculativeConfig, VisionLanguageConfig)
+                         ModelConfig, ParallelConfig, PromptAdapterConfig,
+                         SchedulerConfig, SpeculativeConfig,
+                         VisionLanguageConfig)
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment)
 from vllm.logger import init_logger
@@ -47,6 +48,7 @@ def __init__(
         lora_config: Optional[LoRAConfig] = None,
         vision_language_config: Optional[VisionLanguageConfig] = None,
         speculative_config: Optional[SpeculativeConfig] = None,
+        prompt_adapter_config: Optional[PromptAdapterConfig] = None,
         is_driver_worker: bool = False,
     ) -> None:
         assert device_config.device_type == "xpu"
@@ -62,6 +64,7 @@ def __init__(
         self.rank = rank
         self.distributed_init_method = distributed_init_method
         self.lora_config = lora_config
+        self.prompt_adapter_config = prompt_adapter_config
         self.is_driver_worker = is_driver_worker
         if self.is_driver_worker:
             assert self.rank == 0, "The driver worker must have rank 0."
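Taken end to end, the plumbing above is meant to be driven roughly as follows. This is a usage sketch only: the engine arguments (enable_prompt_adapter, max_prompt_adapter_token), the PromptAdapterRequest field order, the model name, and the checkpoint path are assumptions about the public API this patch wires up and may differ from the final interface.

# Usage sketch only -- the argument names below are assumptions, not a spec.
from vllm import LLM, SamplingParams
from vllm.prompt_adapter.request import PromptAdapterRequest

# Hypothetical local prompt-tuning checkpoint with 8 virtual tokens.
pa_path = "/path/to/prompt_adapter"

llm = LLM(model="bigscience/bloomz-560m",
          enable_prompt_adapter=True,   # assumed engine arg
          max_prompt_adapter_token=8)   # assumed engine arg

# (name, id, local path, num virtual tokens) -- assumed field order.
pa_request = PromptAdapterRequest("example_adapter", 1, pa_path, 8)

outputs = llm.generate(["Tweet text: the store was closed again.\nLabel: "],
                       SamplingParams(temperature=0.0, max_tokens=3),
                       prompt_adapter_request=pa_request)
print(outputs[0].outputs[0].text)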