diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index b39dce2659a54..0412c5f37952d 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -1,36 +1,43 @@ import os +import sys import zipfile -MAX_SIZE_MB = 250 +# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 250 MB +VLLM_MAX_SIZE_MB = int(os.environ.get('VLLM_MAX_SIZE_MB', 250)) def print_top_10_largest_files(zip_file): + """Print the top 10 largest files in the given zip file.""" with zipfile.ZipFile(zip_file, 'r') as z: file_sizes = [(f, z.getinfo(f).file_size) for f in z.namelist()] file_sizes.sort(key=lambda x: x[1], reverse=True) for f, size in file_sizes[:10]: - print(f"{f}: {size/(1024*1024)} MBs uncompressed.") + print(f"{f}: {size / (1024 * 1024):.2f} MBs uncompressed.") def check_wheel_size(directory): + """Check the size of .whl files in the given directory.""" for root, _, files in os.walk(directory): - for f in files: - if f.endswith(".whl"): - wheel_path = os.path.join(root, f) - wheel_size = os.path.getsize(wheel_path) - wheel_size_mb = wheel_size / (1024 * 1024) - if wheel_size_mb > MAX_SIZE_MB: - print( - f"Wheel {wheel_path} is too large ({wheel_size_mb} MB) " - f"compare to the allowed size ({MAX_SIZE_MB} MB).") + for file_name in files: + if file_name.endswith(".whl"): + wheel_path = os.path.join(root, file_name) + wheel_size_mb = os.path.getsize(wheel_path) / (1024 * 1024) + if wheel_size_mb > VLLM_MAX_SIZE_MB: + print(f"Not allowed: Wheel {wheel_path} is larger " + f"({wheel_size_mb:.2f} MB) than the limit " + f"({VLLM_MAX_SIZE_MB} MB).") print_top_10_largest_files(wheel_path) return 1 else: print(f"Wheel {wheel_path} is within the allowed size " - f"({wheel_size_mb} MB).") + f"({wheel_size_mb:.2f} MB).") return 0 if __name__ == "__main__": - import sys - sys.exit(check_wheel_size(sys.argv[1])) + if len(sys.argv) < 2: + print("Usage: python check-wheel-size.py ") + sys.exit(1) + + directory = sys.argv[1] + sys.exit(check_wheel_size(directory)) \ No newline at end of file diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index bb9cd43e2df04..064883859218a 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -1,5 +1,4 @@ Meta-Llama-3-8B-Instruct.yaml -Meta-Llama-3-8B-Instruct-FP8.yaml Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh old mode 100644 new mode 100755 index 5548071390aff..972c62a091aea --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -1,5 +1,5 @@ # This script runs test inside the corresponding ROCm docker container. -set -ex +set -o pipefail # Print ROCm version echo "--- Confirming Clean Initial State" @@ -70,16 +70,51 @@ HF_CACHE="$(realpath ~)/huggingface" mkdir -p ${HF_CACHE} HF_MOUNT="/root/.cache/huggingface" -docker run \ +commands=$@ +PARALLEL_JOB_COUNT=8 +# check if the command contains shard flag, we will run all shards in parallel because the host have 8 GPUs. +if [[ $commands == *"--shard-id="* ]]; then + for GPU in $(seq 0 $(($PARALLEL_JOB_COUNT-1))); do + #replace shard arguments + commands=${@//"--shard-id= "/"--shard-id=${GPU} "} + commands=${commands//"--num-shards= "/"--num-shards=${PARALLEL_JOB_COUNT} "} + docker run \ --device /dev/kfd --device /dev/dri \ --network host \ --shm-size=16gb \ --rm \ - -e HIP_VISIBLE_DEVICES=0 \ + -e HIP_VISIBLE_DEVICES=${GPU} \ -e HF_TOKEN \ -v ${HF_CACHE}:${HF_MOUNT} \ -e HF_HOME=${HF_MOUNT} \ - --name ${container_name} \ + --name ${container_name}_${GPU} \ ${image_name} \ - /bin/bash -c "${@}" - + /bin/bash -c "${commands}" \ + |& while read -r line; do echo ">>Shard $GPU: $line"; done & + PIDS+=($!) + done + #wait for all processes to finish and collect exit codes + for pid in ${PIDS[@]}; do + wait ${pid} + STATUS+=($?) + done + for st in ${STATUS[@]}; do + if [[ ${st} -ne 0 ]]; then + echo "One of the processes failed with $st" + exit ${st} + fi + done +else + docker run \ + --device /dev/kfd --device /dev/dri \ + --network host \ + --shm-size=16gb \ + --rm \ + -e HIP_VISIBLE_DEVICES=0 \ + -e HF_TOKEN \ + -v ${HF_CACHE}:${HF_MOUNT} \ + -e HF_HOME=${HF_MOUNT} \ + --name ${container_name} \ + ${image_name} \ + /bin/bash -c "${commands}" +fi diff --git a/.buildkite/run-cpu-test-ppc64le.sh b/.buildkite/run-cpu-test-ppc64le.sh new file mode 100755 index 0000000000000..a01cf3fe67489 --- /dev/null +++ b/.buildkite/run-cpu-test-ppc64le.sh @@ -0,0 +1,32 @@ +# This script build the CPU docker image and run the offline inference inside the container. +# It serves a sanity check for compilation and basic model usage. +set -ex + +# Try building the docker image +docker build -t cpu-test -f Dockerfile.ppc64le . + +# Setup cleanup +remove_docker_container() { docker rm -f cpu-test || true; } +trap remove_docker_container EXIT +remove_docker_container + +# Run the image, setting --shm-size=4g for tensor parallel. +#docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test +docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --name cpu-test cpu-test + +# Run basic model test +docker exec cpu-test bash -c " + pip install pytest matplotlib einops transformers_stream_generator + pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported + +# online inference +docker exec cpu-test bash -c " + python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m & + timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 + python3 benchmarks/benchmark_serving.py \ + --backend vllm \ + --dataset-name random \ + --model facebook/opt-125m \ + --num-prompts 20 \ + --endpoint /v1/completions \ + --tokenizer facebook/opt-125m" diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index 8e4be08f3aba0..ca9cf15780e25 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -23,7 +23,12 @@ docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py" # Run basic model test docker exec cpu-test bash -c " pip install pytest matplotlib einops transformers_stream_generator - pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_oot_registration.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported + pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py \ + --ignore=tests/models/test_oot_registration.py \ + --ignore=tests/models/test_registry.py \ + --ignore=tests/models/test_fp8.py \ + --ignore=tests/models/test_jamba.py \ + --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported # online inference docker exec cpu-test bash -c " diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh index 335ffd83fcd7a..6989c94d46a89 100644 --- a/.buildkite/run-tpu-test.sh +++ b/.buildkite/run-tpu-test.sh @@ -12,4 +12,4 @@ remove_docker_container # For HF_TOKEN. source /etc/environment # Run a simple end-to-end example. -docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" +docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 -m pip install pytest && pytest -v -s /workspace/vllm/tests/tpu/test_custom_dispatcher.py && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 9f449ff650b90..a0c7b7442b3b3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -90,6 +90,8 @@ steps: - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process - pytest -v -s entrypoints/openai + - pytest -v -s entrypoints/test_chat_utils.py + - label: Distributed Tests (4 GPUs) # 10min working_dir: "/vllm-workspace/tests" @@ -156,6 +158,7 @@ steps: - python3 offline_inference_with_prefix.py - python3 llm_engine_example.py - python3 offline_inference_vision_language.py + - python3 offline_inference_vision_language_multi_image.py - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py @@ -173,6 +176,7 @@ steps: - vllm/ commands: - pytest -v -s ./compile/test_full_graph.py + - pytest -v -s ./compile/test_wrapper.py - label: Vision Language Models Test # 42min @@ -216,9 +220,9 @@ steps: - pytest -v -s spec_decode - label: LoRA Test %N # 30min each + mirror_hardwares: [amd] source_file_dependencies: - vllm/lora - - csrc/punica - tests/lora command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py parallelism: 4 @@ -269,6 +273,15 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - bash ./run-tests.sh -c configs/models-small.txt -t 1 +- label: OpenAI-Compatible Tool Use # 20 min + fast_check: false + mirror_hardwares: [ amd ] + source_file_dependencies: + - vllm/ + - tests/tool_use + commands: + - pytest -v -s tool_use + ##### 1 GPU test ##### ##### multi gpus test ##### @@ -356,9 +369,9 @@ steps: - label: LoRA Long Context (Distributed) # 11min # This test runs llama 13B, so it is required to run on 4 GPUs. num_gpus: 4 + soft_fail: true source_file_dependencies: - vllm/lora - - csrc/punica - tests/lora/test_long_context commands: # FIXIT: find out which code initialize cuda before running the test @@ -373,7 +386,18 @@ steps: - vllm/ - tests/weight_loading commands: - - bash weight_loading/run_model_weight_loading_test.sh + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt + +- label: Weight Loading Multiple GPU Test - Large Models # optional + working_dir: "/vllm-workspace/tests" + num_gpus: 2 + gpu: a100 + optional: true + source_file_dependencies: + - vllm/ + - tests/weight_loading + commands: + - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-large.txt ##### multi gpus test ##### diff --git a/.github/workflows/add_label_ready_comment.yml b/.github/workflows/add_label_ready_comment.yml deleted file mode 100644 index 729c1452af03d..0000000000000 --- a/.github/workflows/add_label_ready_comment.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Add Ready Label on Ready Comment - -on: - issue_comment: - types: [created] - -jobs: - add-ready-label: - runs-on: ubuntu-latest - if: github.event.issue.pull_request && contains(github.event.comment.body, '/ready') - steps: - - name: Add label - uses: actions/github-script@v5 - with: - script: | - github.rest.issues.addLabels({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - labels: ['ready'] - }) - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/reminder_comment.yml b/.github/workflows/reminder_comment.yml index 390c88bb65308..99827756d2066 100644 --- a/.github/workflows/reminder_comment.yml +++ b/.github/workflows/reminder_comment.yml @@ -15,7 +15,7 @@ jobs: owner: context.repo.owner, repo: context.repo.repo, issue_number: context.issue.number, - body: 'šŸ‘‹ Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your `fast-check` build on Buildkite UI. \n\nOnce the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\nšŸš€' + body: 'šŸ‘‹ Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your `fastcheck` build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping `simon-mo` or `khluu` to add you in our Buildkite org. \n\nOnce the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.\n\n To run CI, PR reviewers can do one of these:\n- Add `ready` label to the PR\n- Enable auto-merge.\n\nšŸš€' }) env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/remove_label_not_ready_comment.yml b/.github/workflows/remove_label_not_ready_comment.yml deleted file mode 100644 index d1da7726eaee3..0000000000000 --- a/.github/workflows/remove_label_not_ready_comment.yml +++ /dev/null @@ -1,23 +0,0 @@ -name: Remove ready Label on notready Comment - -on: - issue_comment: - types: [created] - -jobs: - add-ready-label: - runs-on: ubuntu-latest - if: github.event.issue.pull_request && contains(github.event.comment.body, '/notready') - steps: - - name: Remove ready label - uses: actions/github-script@v5 - with: - script: | - github.rest.issues.removeLabel({ - owner: context.repo.owner, - repo: context.repo.repo, - issue_number: context.issue.number, - name: 'ready' - }) - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/CMakeLists.txt b/CMakeLists.txt index 5b0d0ba904c32..9c88c31c83da1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -181,7 +181,6 @@ set(VLLM_EXT_SRC "csrc/pos_encoding_kernels.cu" "csrc/activation_kernels.cu" "csrc/layernorm_kernels.cu" - "csrc/quantization/squeezellm/quant_cuda_kernel.cu" "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" @@ -203,6 +202,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC + "csrc/mamba/mamba_ssm/selective_scan_fwd.cu" + "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000..f801b5f8f5513 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ + +# vLLM Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socioeconomic status, +nationality, personal appearance, race, caste, color, religion, or sexual +identity and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the overall + community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or advances of + any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email address, + without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official email address, +posting via an official social media account, or acting as an appointed +representative at an online or offline/IRL event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement in the #code-of-conduct +channel in the [vLLM Discord](https://discord.com/invite/jz7wjKhh6g). +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series of +actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or permanent +ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within the +community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org/), +version 2.1, available at +[v2.1](https://www.contributor-covenant.org/version/2/1/code_of_conduct.html). + +Community Impact Guidelines were inspired by +[Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/inclusion). + +For answers to common questions about this code of conduct, see the +[Contributor Covenant FAQ](https://www.contributor-covenant.org/faq). Translations are available at +[Contributor Covenant translations](https://www.contributor-covenant.org/translations). + diff --git a/Dockerfile b/Dockerfile index 36fcc2f83e9fb..0ec6655ed449e 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,7 +10,7 @@ ARG CUDA_VERSION=12.4.1 # prepare basic build environment FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base ARG CUDA_VERSION=12.4.1 -ARG PYTHON_VERSION=3.10 +ARG PYTHON_VERSION=3.12 ENV DEBIAN_FRONTEND=noninteractive # Install Python and other dependencies @@ -37,14 +37,10 @@ WORKDIR /workspace # install build and runtime dependencies COPY requirements-common.txt requirements-common.txt -COPY requirements-adag.txt requirements-adag.txt COPY requirements-cuda.txt requirements-cuda.txt RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-cuda.txt -COPY requirements-mamba.txt requirements-mamba.txt -RUN python3 -m pip install packaging -RUN python3 -m pip install -r requirements-mamba.txt # cuda arch list used by torch # can be useful for both `dev` and `test` @@ -69,7 +65,6 @@ COPY setup.py setup.py COPY cmake cmake COPY CMakeLists.txt CMakeLists.txt COPY requirements-common.txt requirements-common.txt -COPY requirements-adag.txt requirements-adag.txt COPY requirements-cuda.txt requirements-cuda.txt COPY pyproject.toml pyproject.toml COPY vllm vllm @@ -111,10 +106,17 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ python3 setup.py bdist_wheel --dist-dir=dist --py-limited-api=cp38; \ fi -# check the size of the wheel, we cannot upload wheels larger than 100MB +# Check the size of the wheel if RUN_WHEEL_CHECK is true COPY .buildkite/check-wheel-size.py check-wheel-size.py -RUN python3 check-wheel-size.py dist - +# Default max size of the wheel is 250MB +ARG VLLM_MAX_SIZE_MB=250 +ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB +ARG RUN_WHEEL_CHECK=true +RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \ + python3 check-wheel-size.py dist; \ + else \ + echo "Skipping wheel size check."; \ + fi #################### EXTENSION Build IMAGE #################### #################### DEV IMAGE #################### @@ -127,27 +129,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt #################### DEV IMAGE #################### -#################### MAMBA Build IMAGE #################### -FROM dev as mamba-builder -# max jobs used for build -ARG max_jobs=2 -ENV MAX_JOBS=${max_jobs} - -WORKDIR /usr/src/mamba - -COPY requirements-mamba.txt requirements-mamba.txt - -# Download the wheel or build it if a pre-compiled release doesn't exist -RUN pip --verbose wheel -r requirements-mamba.txt \ - --no-build-isolation --no-deps --no-cache-dir - -#################### MAMBA Build IMAGE #################### - #################### vLLM installation IMAGE #################### # image with vLLM installed FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base ARG CUDA_VERSION=12.4.1 -ARG PYTHON_VERSION=3.10 +ARG PYTHON_VERSION=3.12 WORKDIR /vllm-workspace ENV DEBIAN_FRONTEND=noninteractive @@ -179,13 +165,9 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install dist/*.whl --verbose -RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamba \ - --mount=type=cache,target=/root/.cache/pip \ - python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir - RUN --mount=type=cache,target=/root/.cache/pip \ . /etc/environment && \ - python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.4/flashinfer-0.1.4+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl + python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.6/flashinfer-0.1.6+cu121torch2.4-cp${PYTHON_VERSION_STR}-cp${PYTHON_VERSION_STR}-linux_x86_64.whl #################### vLLM installation IMAGE #################### @@ -197,6 +179,10 @@ FROM vllm-base AS test ADD . /vllm-workspace/ # install development dependencies (for testing) +# A newer setuptools is required for installing some test dependencies from source that do not publish python 3.12 wheels +# This installation must complete before the test dependencies are collected and installed. +RUN --mount=type=cache,target=/root/.cache/pip \ + python3 -m pip install "setuptools>=74.1.1" RUN --mount=type=cache,target=/root/.cache/pip \ python3 -m pip install -r requirements-dev.txt diff --git a/Dockerfile.ppc64le b/Dockerfile.ppc64le index d4e4c483cada8..16780f8ab950c 100644 --- a/Dockerfile.ppc64le +++ b/Dockerfile.ppc64le @@ -2,21 +2,27 @@ FROM mambaorg/micromamba ARG MAMBA_DOCKERFILE_ACTIVATE=1 USER root -RUN apt-get update -y && apt-get install -y git wget vim numactl gcc-12 g++-12 protobuf-compiler libprotobuf-dev && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 +ENV PATH="/usr/local/cargo/bin:$PATH:/opt/conda/bin/" + +RUN apt-get update -y && apt-get install -y git wget vim libnuma-dev libsndfile-dev libprotobuf-dev build-essential # Some packages in requirements-cpu are installed here # IBM provides optimized packages for ppc64le processors in the open-ce project for mamba # Currently these may not be available for venv or pip directly -RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 pytorch-cpu=2.1.2 torchvision-cpu=0.16.2 && micromamba clean --all --yes +RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 torchvision-cpu=0.16.2 rust && micromamba clean --all --yes COPY ./ /workspace/vllm WORKDIR /workspace/vllm # These packages will be in rocketce eventually -RUN pip install -v -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing +RUN pip install -v cmake torch==2.3.1 uvloop==0.20.0 -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install -WORKDIR /vllm-workspace -ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"] +WORKDIR /workspace/ + +RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks + +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] + diff --git a/Dockerfile.tpu b/Dockerfile.tpu index 1cf43247e9781..3a11c6721ead9 100644 --- a/Dockerfile.tpu +++ b/Dockerfile.tpu @@ -1,4 +1,4 @@ -ARG NIGHTLY_DATE="20240808" +ARG NIGHTLY_DATE="20240828" ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE" FROM $BASE_IMAGE diff --git a/MANIFEST.in b/MANIFEST.in index 5a41e5e714184..82be639ef4d73 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,5 +1,4 @@ include LICENSE -include requirements-adag.txt include requirements-common.txt include requirements-cuda.txt include requirements-rocm.txt diff --git a/README.md b/README.md index 9ae30f8d2de55..53749cb36b972 100644 --- a/README.md +++ b/README.md @@ -17,15 +17,16 @@ Easy, fast, and cheap LLM serving for everyone --- -**vLLM & NVIDIA Triton User Meetup (Monday, September 9, 5pm-9pm PT) at Fort Mason, San Francisco** +**vLLM, AMD, Anyscale Meet & Greet at [Ray Summit 2024](http://raysummit.anyscale.com) (Monday, Sept 30th, 5-7pm PT) at Marriott Marquis San Francisco** -We are excited to announce our sixth vLLM Meetup, in collaboration with NVIDIA Triton Team. -Join us to hear the vLLM's recent update about performance. -Register now [here](https://lu.ma/87q3nvnh) and be part of the event! +We are excited to announce our special vLLM event in collaboration with AMD and Anyscale. +Join us to learn more about recent advancements of vLLM on MI300X. +Register [here](https://lu.ma/db5ld9n5) and be a part of the event! --- *Latest News* šŸ”„ +- [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing). - [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing). - [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html). - [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing). @@ -130,3 +131,10 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs year={2023} } ``` + +## Contact Us + +* For technical questions and feature requests, please use Github issues or discussions. +* For discussing with fellow users, please use Discord. +* For security disclosures, please use Github's security advisory feature. +* For collaborations and partnerships, please contact us at vllm-questions AT lists.berkeley.edu. \ No newline at end of file diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index f7d67692f697b..3243bb94f787c 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -24,6 +24,7 @@ class RequestFuncInput: model: str best_of: int = 1 use_beam_search: bool = False + logprobs: Optional[int] = None @dataclass @@ -236,6 +237,7 @@ async def async_request_openai_completions( "temperature": 0.0, "best_of": request_func_input.best_of, "max_tokens": request_func_input.output_len, + "logprobs": request_func_input.logprobs, "stream": True, } headers = { diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index fe687da492901..9ba3f649810b7 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -56,20 +56,27 @@ class BenchmarkMetrics: total_input: int total_output: int request_throughput: float - input_throughput: float output_throughput: float + total_token_throughput: float mean_ttft_ms: float median_ttft_ms: float std_ttft_ms: float - p99_ttft_ms: float + percentiles_ttft_ms: List[Tuple[float, float]] mean_tpot_ms: float median_tpot_ms: float std_tpot_ms: float - p99_tpot_ms: float + percentiles_tpot_ms: List[Tuple[float, float]] mean_itl_ms: float median_itl_ms: float std_itl_ms: float - p99_itl_ms: float + percentiles_itl_ms: List[Tuple[float, float]] + # E2EL stands for end-to-end latency per request. + # It is the time taken on the client side from sending + # a request to receiving a complete response. + mean_e2el_ms: float + median_e2el_ms: float + std_e2el_ms: float + percentiles_e2el_ms: List[Tuple[float, float]] def sample_sharegpt_requests( @@ -188,8 +195,16 @@ def sample_sonnet_requests( def sample_random_requests( - input_len: int, output_len: int, num_prompts: int, range_ratio: float, - tokenizer: PreTrainedTokenizerBase) -> List[Tuple[str, int, int]]: + prefix_len: int, + input_len: int, + output_len: int, + num_prompts: int, + range_ratio: float, + tokenizer: PreTrainedTokenizerBase, +) -> List[Tuple[str, int, int]]: + prefix_token_ids = np.random.randint(0, + tokenizer.vocab_size, + size=prefix_len).tolist() input_lens = np.random.randint( int(input_len * range_ratio), @@ -204,10 +219,12 @@ def sample_random_requests( offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts) input_requests = [] for i in range(num_prompts): - prompt = tokenizer.decode([(offsets[i] + i + j) % tokenizer.vocab_size + prompt = tokenizer.decode(prefix_token_ids + + [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]) + input_requests.append( - (prompt, int(input_lens[i]), int(output_lens[i]))) + (prompt, int(prefix_len + input_lens[i]), int(output_lens[i]))) return input_requests @@ -235,6 +252,8 @@ def calculate_metrics( outputs: List[RequestFuncOutput], dur_s: float, tokenizer: PreTrainedTokenizerBase, + selected_percentile_metrics: List[str], + selected_percentiles: List[float], ) -> Tuple[BenchmarkMetrics, List[int]]: actual_output_lens: List[int] = [] total_input = 0 @@ -242,6 +261,7 @@ def calculate_metrics( itls: List[float] = [] tpots: List[float] = [] ttfts: List[float] = [] + e2els: List[float] = [] for i in range(len(outputs)): if outputs[i].success: # We use the tokenizer to count the number of output tokens for all @@ -258,6 +278,7 @@ def calculate_metrics( (outputs[i].latency - outputs[i].ttft) / (output_len - 1)) itls += outputs[i].itl ttfts.append(outputs[i].ttft) + e2els.append(outputs[i].latency) completed += 1 else: actual_output_lens.append(0) @@ -272,21 +293,29 @@ def calculate_metrics( total_input=total_input, total_output=sum(actual_output_lens), request_throughput=completed / dur_s, - input_throughput=total_input / dur_s, output_throughput=sum(actual_output_lens) / dur_s, + total_token_throughput=(total_input + sum(actual_output_lens)) / dur_s, mean_ttft_ms=np.mean(ttfts or 0) * 1000, # ttfts is empty if streaming is not supported by backend - median_ttft_ms=np.median(ttfts or 0) * 1000, std_ttft_ms=np.std(ttfts or 0) * 1000, - p99_ttft_ms=np.percentile(ttfts or 0, 99) * 1000, + median_ttft_ms=np.median(ttfts or 0) * 1000, + percentiles_ttft_ms=[(p, np.percentile(ttfts or 0, p) * 1000) + for p in selected_percentiles], mean_tpot_ms=np.mean(tpots or 0) * 1000, - median_tpot_ms=np.median(tpots or 0) * 1000, std_tpot_ms=np.std(tpots or 0) * 1000, - p99_tpot_ms=np.percentile(tpots or 0, 99) * 1000, + median_tpot_ms=np.median(tpots or 0) * 1000, + percentiles_tpot_ms=[(p, np.percentile(tpots or 0, p) * 1000) + for p in selected_percentiles], mean_itl_ms=np.mean(itls or 0) * 1000, - median_itl_ms=np.median(itls or 0) * 1000, std_itl_ms=np.std(itls or 0) * 1000, - p99_itl_ms=np.percentile(itls or 0, 99) * 1000, + median_itl_ms=np.median(itls or 0) * 1000, + percentiles_itl_ms=[(p, np.percentile(itls or 0, p) * 1000) + for p in selected_percentiles], + mean_e2el_ms=np.median(e2els or 0) * 1000, + std_e2el_ms=np.std(e2els or 0) * 1000, + median_e2el_ms=np.mean(e2els or 0) * 1000, + percentiles_e2el_ms=[(p, np.percentile(e2els or 0, p) * 1000) + for p in selected_percentiles], ) return metrics, actual_output_lens @@ -299,11 +328,14 @@ async def benchmark( model_id: str, tokenizer: PreTrainedTokenizerBase, input_requests: List[Tuple[str, int, int]], + logprobs: Optional[int], best_of: int, use_beam_search: bool, request_rate: float, disable_tqdm: bool, profile: bool, + selected_percentile_metrics: List[str], + selected_percentiles: List[str], ): if backend in ASYNC_REQUEST_FUNCS: request_func = ASYNC_REQUEST_FUNCS[backend] @@ -318,6 +350,7 @@ async def benchmark( api_url=api_url, prompt_len=test_prompt_len, output_len=test_output_len, + logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, ) @@ -337,6 +370,7 @@ async def benchmark( api_url=base_url + "/start_profile", prompt_len=test_prompt_len, output_len=test_output_len, + logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, ) @@ -358,6 +392,7 @@ async def benchmark( api_url=api_url, prompt_len=prompt_len, output_len=output_len, + logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, ) @@ -375,6 +410,7 @@ async def benchmark( api_url=base_url + "/stop_profile", prompt_len=test_prompt_len, output_len=test_output_len, + logprobs=logprobs, best_of=best_of, use_beam_search=use_beam_search, ) @@ -392,6 +428,8 @@ async def benchmark( outputs=outputs, dur_s=benchmark_duration, tokenizer=tokenizer, + selected_percentile_metrics=selected_percentile_metrics, + selected_percentiles=selected_percentiles, ) print("{s:{c}^{n}}".format(s=' Serving Benchmark Result ', n=50, c='=')) @@ -403,27 +441,10 @@ async def benchmark( metrics.total_output)) print("{:<40} {:<10.2f}".format("Request throughput (req/s):", metrics.request_throughput)) - print("{:<40} {:<10.2f}".format("Input token throughput (tok/s):", - metrics.input_throughput)) print("{:<40} {:<10.2f}".format("Output token throughput (tok/s):", metrics.output_throughput)) - print("{s:{c}^{n}}".format(s='Time to First Token', n=50, c='-')) - print("{:<40} {:<10.2f}".format("Mean TTFT (ms):", metrics.mean_ttft_ms)) - print("{:<40} {:<10.2f}".format("Median TTFT (ms):", - metrics.median_ttft_ms)) - print("{:<40} {:<10.2f}".format("P99 TTFT (ms):", metrics.p99_ttft_ms)) - print("{s:{c}^{n}}".format(s='Time per Output Token (excl. 1st token)', - n=50, - c='-')) - print("{:<40} {:<10.2f}".format("Mean TPOT (ms):", metrics.mean_tpot_ms)) - print("{:<40} {:<10.2f}".format("Median TPOT (ms):", - metrics.median_tpot_ms)) - print("{:<40} {:<10.2f}".format("P99 TPOT (ms):", metrics.p99_tpot_ms)) - print("{s:{c}^{n}}".format(s='Inter-token Latency', n=50, c='-')) - print("{:<40} {:<10.2f}".format("Mean ITL (ms):", metrics.mean_itl_ms)) - print("{:<40} {:<10.2f}".format("Median ITL (ms):", metrics.median_itl_ms)) - print("{:<40} {:<10.2f}".format("P99 ITL (ms):", metrics.p99_itl_ms)) - print("=" * 50) + print("{:<40} {:<10.2f}".format("Total Token throughput (tok/s):", + metrics.total_token_throughput)) result = { "duration": benchmark_duration, @@ -431,20 +452,8 @@ async def benchmark( "total_input_tokens": metrics.total_input, "total_output_tokens": metrics.total_output, "request_throughput": metrics.request_throughput, - "input_throughput": metrics.input_throughput, "output_throughput": metrics.output_throughput, - "mean_ttft_ms": metrics.mean_ttft_ms, - "median_ttft_ms": metrics.median_ttft_ms, - "std_ttft_ms": metrics.std_ttft_ms, - "p99_ttft_ms": metrics.p99_ttft_ms, - "mean_tpot_ms": metrics.mean_tpot_ms, - "median_tpot_ms": metrics.median_tpot_ms, - "std_tpot_ms": metrics.std_tpot_ms, - "p99_tpot_ms": metrics.p99_tpot_ms, - "mean_itl_ms": metrics.mean_itl_ms, - "median_itl_ms": metrics.median_itl_ms, - "std_itl_ms": metrics.std_itl_ms, - "p99_itl_ms": metrics.p99_itl_ms, + "total_token_throughput": metrics.total_token_throughput, "input_lens": [output.prompt_len for output in outputs], "output_lens": actual_output_lens, "ttfts": [output.ttft for output in outputs], @@ -452,6 +461,47 @@ async def benchmark( "generated_texts": [output.generated_text for output in outputs], "errors": [output.error for output in outputs], } + + def process_one_metric( + # E.g., "ttft" + metric_attribute_name: str, + # E.g., "TTFT" + metric_name: str, + # E.g., "Time to First Token" + metric_header: str, + ): + # This function print and add statistics of the specified + # metric. + if metric_attribute_name not in selected_percentile_metrics: + return + print("{s:{c}^{n}}".format(s=metric_header, n=50, c='-')) + print("{:<40} {:<10.2f}".format( + f"Mean {metric_name} (ms):", + getattr(metrics, f"mean_{metric_attribute_name}_ms"))) + print("{:<40} {:<10.2f}".format( + f"Median {metric_name} (ms):", + getattr(metrics, f"median_{metric_attribute_name}_ms"))) + result[f"mean_{metric_attribute_name}_ms"] = getattr( + metrics, f"mean_{metric_attribute_name}_ms") + result[f"median_{metric_attribute_name}_ms"] = getattr( + metrics, f"median_{metric_attribute_name}_ms") + result[f"std_{metric_attribute_name}_ms"] = getattr( + metrics, f"std_{metric_attribute_name}_ms") + for p, value in getattr(metrics, + f"percentiles_{metric_attribute_name}_ms"): + p_word = str(int(p)) if int(p) == p else str(p) + print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", + value)) + result[f"p{p_word}_{metric_attribute_name}_ms"] = value + + process_one_metric("ttft", "TTFT", "Time to First Token") + process_one_metric("tpot", "TPOT", + "Time per Output Token (excl. 1st token)") + process_one_metric("itl", "ITL", "Inter-token Latency") + process_one_metric("e2el", "E2EL", "End-to-end Latency") + + print("=" * 50) + return result @@ -527,6 +577,7 @@ def main(args: argparse.Namespace): elif args.dataset_name == "random": input_requests = sample_random_requests( + prefix_len=args.random_prefix_len, input_len=args.random_input_len, output_len=args.random_output_len, num_prompts=args.num_prompts, @@ -545,11 +596,16 @@ def main(args: argparse.Namespace): model_id=model_id, tokenizer=tokenizer, input_requests=input_requests, + logprobs=args.logprobs, best_of=args.best_of, use_beam_search=args.use_beam_search, request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, profile=args.profile, + selected_percentile_metrics=args.percentile_metrics.split(","), + selected_percentiles=[ + float(p) for p in args.metric_percentiles.split(",") + ], )) # Save config and results to json @@ -682,6 +738,16 @@ def main(args: argparse.Namespace): help= "Number of output tokens per request, used only for sonnet dataset.", ) + parser.add_argument( + "--logprobs", + type=int, + default=None, + help=("Number of logprobs-per-token to compute & return as part of " + "the request. If unspecified, then either (1) if beam search " + "is disabled, no logprobs are computed & a single dummy " + "logprob is returned for each token; or (2) if beam search " + "is enabled 1 logprob per token is computed"), + ) parser.add_argument( "--sonnet-prefix-len", type=int, @@ -710,6 +776,14 @@ def main(args: argparse.Namespace): help="Range of sampled ratio of input/output length, " "used only for random sampling.", ) + parser.add_argument( + "--random-prefix-len", + type=int, + default=0, + help="Number of fixed prefix tokens before random " + " context. The length range of context in a random " + " request is [random-prefix-len, " + " random-prefix-len + random-prefix-len * random-range-ratio).") parser.add_argument( "--request-rate", type=float, @@ -765,6 +839,23 @@ def main(args: argparse.Namespace): "{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json" " format.", ) + parser.add_argument( + "--percentile-metrics", + type=str, + default="ttft,tpot,itl", + help="Comma-seperated list of selected metrics to report percentils. " + "This argument specifies the metrics to report percentiles. " + "Allowed metric names are \"ttft\", \"tpot\", \"itl\", \"e2el\". " + "Default value is \"ttft,tpot,itl\".") + parser.add_argument( + "--metric-percentiles", + type=str, + default="99", + help="Comma-seperated list of percentiles for selected metrics. " + "To report 25-th, 50-th, and 75-th percentiles, use \"25,50,75\". " + "Default value is \"99\". " + "Use \"--percentile-metrics\" to select metrics.", + ) args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index eaf256f7cb8c2..94549d84fb4e4 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -6,13 +6,16 @@ from typing import List, Optional, Tuple import torch +import uvloop from tqdm import tqdm from transformers import (AutoModelForCausalLM, AutoTokenizer, PreTrainedTokenizerBase) -from vllm.engine.arg_utils import EngineArgs +from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs +from vllm.entrypoints.openai.api_server import ( + build_async_engine_client_from_engine_args) from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS -from vllm.utils import FlexibleArgumentParser +from vllm.utils import FlexibleArgumentParser, merge_async_iterators def sample_requests( @@ -135,6 +138,93 @@ def run_vllm( return end - start +async def run_vllm_async( + requests: List[Tuple[str, int, int]], + model: str, + tokenizer: str, + quantization: Optional[str], + tensor_parallel_size: int, + seed: int, + n: int, + use_beam_search: bool, + trust_remote_code: bool, + dtype: str, + max_model_len: Optional[int], + enforce_eager: bool, + kv_cache_dtype: str, + quantization_param_path: Optional[str], + device: str, + enable_prefix_caching: bool, + enable_chunked_prefill: bool, + max_num_batched_tokens: int, + distributed_executor_backend: Optional[str], + gpu_memory_utilization: float = 0.9, + num_scheduler_steps: int = 1, + use_v2_block_manager: bool = False, + download_dir: Optional[str] = None, + load_format: str = EngineArgs.load_format, + disable_async_output_proc: bool = False, + disable_frontend_multiprocessing: bool = False, +) -> float: + from vllm import SamplingParams + engine_args = AsyncEngineArgs( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + quantization_param_path=quantization_param_path, + device=device, + enable_prefix_caching=enable_prefix_caching, + download_dir=download_dir, + enable_chunked_prefill=enable_chunked_prefill, + max_num_batched_tokens=max_num_batched_tokens, + distributed_executor_backend=distributed_executor_backend, + load_format=load_format, + num_scheduler_steps=num_scheduler_steps, + use_v2_block_manager=use_v2_block_manager, + disable_async_output_proc=disable_async_output_proc, + worker_use_ray=False, + engine_use_ray=False, + disable_log_requests=True, + ) + + async with build_async_engine_client_from_engine_args( + engine_args, disable_frontend_multiprocessing) as llm: + + # Add the requests to the engine. + prompts: List[str] = [] + sampling_params: List[SamplingParams] = [] + for prompt, _, output_len in requests: + prompts.append(prompt) + sampling_params.append( + SamplingParams( + n=n, + temperature=0.0 if use_beam_search else 1.0, + top_p=1.0, + use_beam_search=use_beam_search, + ignore_eos=True, + max_tokens=output_len, + )) + + generators = [] + start = time.perf_counter() + for i, (prompt, sp) in enumerate(zip(prompts, sampling_params)): + generator = llm.generate(prompt, sp, request_id=f"test{i}") + generators.append(generator) + all_gens = merge_async_iterators(*generators) + async for i, res in all_gens: + pass + end = time.perf_counter() + return end - start + + def run_hf( requests: List[Tuple[str, int, int]], model: str, @@ -230,7 +320,7 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - elapsed_time = run_vllm( + run_args = [ requests, args.model, args.tokenizer, args.quantization, args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, args.max_model_len, @@ -240,7 +330,14 @@ def main(args: argparse.Namespace): args.max_num_batched_tokens, args.distributed_executor_backend, args.gpu_memory_utilization, args.num_scheduler_steps, args.use_v2_block_manager, args.download_dir, args.load_format, - args.disable_async_output_proc) + args.disable_async_output_proc + ] + + if args.async_engine: + run_args.append(args.disable_frontend_multiprocessing) + elapsed_time = uvloop.run(run_vllm_async(*run_args)) + else: + elapsed_time = run_vllm(*run_args) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -426,6 +523,14 @@ def main(args: argparse.Namespace): action='store_true', default=False, help="Disable async output processor for vLLM backend.") + parser.add_argument("--async-engine", + action='store_true', + default=False, + help="Use vLLM async engine rather than LLM class.") + parser.add_argument("--disable-frontend-multiprocessing", + action='store_true', + default=False, + help="Disable decoupled async engine frontend.") args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu new file mode 100644 index 0000000000000..88a64a8ece585 --- /dev/null +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -0,0 +1,700 @@ +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu +// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu +#include +#include +#include + +#include "causal_conv1d.h" +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#include +#include + +#include "static_switch.h" + + + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + using weight_t = at::Half; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + using weight_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +void set_conv_params_fwd(ConvParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t width, + // device pointers + const at::Tensor x, + const at::Tensor weight, + const at::Tensor out, + void* bias_ptr, + bool silu_activation) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.width = width; + + params.silu_activation = silu_activation; + + // Set the pointers and strides. + params.x_ptr = x.data_ptr(); + params.weight_ptr = weight.data_ptr(); + params.bias_ptr = bias_ptr; + params.out_ptr = out.data_ptr(); + // All stride are in elements, not bytes. + params.x_batch_stride = x.stride(0); + params.x_c_stride = x.stride(1); + params.x_l_stride = x.stride(-1); + params.weight_c_stride = weight.stride(0); + params.weight_width_stride = weight.stride(1); + params.out_batch_stride = out.stride(0); + params.out_c_stride = out.stride(1); + params.out_l_stride = out.stride(-1); +} + + +at::Tensor +causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, + const c10::optional &bias_, + const c10::optional &seq_idx_, + const c10::optional &initial_states_, + const c10::optional &final_states_out_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim, seqlen); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); + const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; + + if (is_channel_last) { + TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); + TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8"); + } + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + if (seq_idx_.has_value()) { + TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout"); + auto seq_idx = seq_idx_.value(); + TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32); + TORCH_CHECK(seq_idx.is_cuda()); + TORCH_CHECK(seq_idx.is_contiguous()); + CHECK_SHAPE(seq_idx, batch_size, seqlen); + } + + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, + bias_.has_value() ? bias_.value().data_ptr() : nullptr, + silu_activation); + + if (seq_idx_.has_value()) { + params.seq_idx_ptr = seq_idx_.value().data_ptr(); + } else { + params.seq_idx_ptr = nullptr; + } + + if (initial_states_.has_value()) { + TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout"); + auto initial_states = initial_states_.value(); + TORCH_CHECK(initial_states.scalar_type() == input_type); + TORCH_CHECK(initial_states.is_cuda()); + CHECK_SHAPE(initial_states, batch_size, dim, width - 1); + TORCH_CHECK(initial_states.stride(1) == 1); + params.initial_states_ptr = initial_states.data_ptr(); + params.initial_states_batch_stride = initial_states.stride(0); + params.initial_states_c_stride = initial_states.stride(1); + params.initial_states_l_stride = initial_states.stride(2); + } else { + params.initial_states_ptr = nullptr; + } + + if (final_states_out_.has_value()) { + TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout"); + auto final_states = final_states_out_.value(); + TORCH_CHECK(final_states.scalar_type() == input_type); + TORCH_CHECK(final_states.is_cuda()); + CHECK_SHAPE(final_states, batch_size, dim, width - 1); + TORCH_CHECK(final_states.stride(1) == 1); + params.final_states_ptr = final_states.data_ptr(); + params.final_states_batch_stride = final_states.stride(0); + params.final_states_c_stride = final_states.stride(1); + params.final_states_l_stride = final_states.stride(2); + } else { + params.final_states_ptr = nullptr; + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { + if (!is_channel_last) { + causal_conv1d_fwd_cuda(params, stream); + } else { + causal_conv1d_channellast_fwd_cuda(params, stream); + } + }); + return out; +} + + +at::Tensor +causal_conv1d_update(const at::Tensor &x, + const at::Tensor &conv_state, + const at::Tensor &weight, + const c10::optional &bias_, + bool silu_activation) { + auto input_type = x.scalar_type(); + auto weight_type = weight.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations"); + TORCH_CHECK(conv_state.scalar_type() == input_type); + + TORCH_CHECK(x.is_cuda()); + TORCH_CHECK(conv_state.is_cuda()); + TORCH_CHECK(weight.is_cuda()); + + const auto sizes = x.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int width = weight.size(-1); + + CHECK_SHAPE(x, batch_size, dim); + CHECK_SHAPE(conv_state, batch_size, dim, width); + CHECK_SHAPE(weight, dim, width); + + TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); + + if (bias_.has_value()) { + auto bias = bias_.value(); + TORCH_CHECK(bias.scalar_type() == weight_type); + TORCH_CHECK(bias.is_cuda()); + TORCH_CHECK(bias.stride(-1) == 1); + CHECK_SHAPE(bias, dim); + } + + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, + bias_.has_value() ? bias_.value().data_ptr() : nullptr, + silu_activation); + params.conv_state_ptr = conv_state.data_ptr(); + // All stride are in elements, not bytes. + params.conv_state_batch_stride = conv_state.stride(0); + params.conv_state_c_stride = conv_state.stride(1); + params.conv_state_l_stride = conv_state.stride(2); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)x.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] { + causal_conv1d_update_cuda(params, stream); + }); + return out; +} + +template +struct Causal_conv1d_fwd_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static_assert(kWidth <= kNElts); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + static constexpr int kSmemIOSize = kIsVecLoad + ? 0 + : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)}); + static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts; + static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + extern __shared__ char smem_[]; + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_vec = reinterpret_cast(smem_); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_store_vec = reinterpret_cast(smem_); + vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. + if (tidx == 0) { + input_t zeros[kNElts] = {0}; + smem_exchange[kNThreads - 1] = reinterpret_cast(zeros)[0]; + } + + float weight_vals[kWidth]; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + + constexpr int kChunkSize = kNThreads * kNElts; + const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; + for (int chunk = 0; chunk < n_chunks; ++chunk) { + input_t x_vals_load[2 * kNElts] = {0}; + if constexpr(kIsVecLoad) { + typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); + } else { + __syncthreads(); + typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); + } + x += kChunkSize; + __syncthreads(); + // Thread kNThreads - 1 don't write yet, so that thread 0 can read + // the last elements of the previous chunk. + if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + __syncthreads(); + reinterpret_cast(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1]; + __syncthreads(); + // Now thread kNThreads - 1 can write the last elements of the current chunk. + if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast(x_vals_load)[1]; } + + float x_vals[2 * kNElts]; + #pragma unroll + for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); } + + float out_vals[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = bias_val; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)]; + } + } + + if (params.silu_activation) { + #pragma unroll + for (int i = 0; i < kNElts; ++i) { + out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); + } + } + + input_t out_vals_store[kNElts]; + #pragma unroll + for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } + if constexpr(kIsVecLoad) { + typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); + } else { + typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); + } + out += kChunkSize; + } +} + + +template +void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; + BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { + using Ktraits = Causal_conv1d_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize; + dim3 grid(params.batch, params.dim); + + auto kernel = &causal_conv1d_fwd_kernel; + + if (kSmemSize >= 48 * 1024) { + #ifndef USE_ROCM + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + #else + // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function. + C10_CUDA_CHECK(cudaFuncSetAttribute( + (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl; + #endif + } + kernel<<>>(params); + + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + +template +struct Causal_conv1d_channellast_fwd_kernel_traits { + // The cache line is 128 bytes, and we try to read 16 bytes per thread. + // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. + // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 + // threads). Each each load is 16 x 32|64 elements in the L x C dimensions. + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static_assert(kNThreads % 32 == 0); + static constexpr int kNWarps = kNThreads / 32; + static constexpr int kWidth = kWidth_; + static constexpr int kChunkSizeL = kChunkSizeL_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : 8; + static constexpr int kNEltsPerRow = 128 / kNBytes; + static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now + static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); + static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now + static_assert(kNColsPerWarp * kNThreadsPerRow == 32); + static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; + static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; + static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); + static constexpr bool kIsVecLoad = kIsVecLoad_; + using vec_t = typename BytesToType::Type; + // using BlockLoadT = cub::BlockLoad; + // using BlockStoreT = cub::BlockStore; + // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), + // sizeof(typename BlockStoreT::TempStorage)}); + // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNElts = Ktraits::kNElts; + constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; + constexpr int kLPerLoad = Ktraits::kNColsPerLoad; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + using input_t = typename Ktraits::input_t; + using vec_t = typename Ktraits::vec_t; + using weight_t = typename Ktraits::weight_t; + + // Shared memory. + __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts]; + + const int batch_id = blockIdx.x; + const int chunk_l_id = blockIdx.y; + const int chunk_c_id = blockIdx.z; + const int tid = threadIdx.x; + const int l_idx = tid / kNThreadsPerC; + const int c_idx = tid % kNThreadsPerC; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + weight_t *weight = reinterpret_cast(params.weight_ptr) + + chunk_c_id * kChunkSizeC * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast(params.seq_idx_ptr) + + batch_id * params.seqlen + chunk_l_id * kChunkSizeL; + input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr + : reinterpret_cast(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + // The last L-chunk will also have enough info to write to final states, since it also contain a few x values + // from the previous L-chunk. + input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr + : reinterpret_cast(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); + } + reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; + } + // Load the elements from the previous chunk that are needed for convolution. + if (l_idx < kWidth - 1) { + input_t x_vals_load[kNElts] = {0}; + if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 + && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); + } else if (initial_states != nullptr + && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(initial_states); + } + reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; + } + + __syncthreads(); + + if (final_states != nullptr + && l_idx < kWidth - 1 + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1) + // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx] + *reinterpret_cast(final_states) = reinterpret_cast(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx]; + } + + constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); + static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); + constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; + static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); + // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity + static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); + static_assert((kLPerThread & (kLPerThread - 1)) == 0); + static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); + static_assert(kNThreadsPerRow <= 32); + + const int row_idx = tid / kNThreadsPerRow; + const int col_idx = tid % kNThreadsPerRow; + + float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); + float weight_vals[kWidth] = {0}; + if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; + } + } + float x_vals[kWidth - 1 + kLPerThread]; + #pragma unroll + for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { + x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); + } + int seq_idx_thread[kWidth - 1 + kLPerThread]; + if constexpr (kHasSeqIdx) { + #pragma unroll + for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { + seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1; + } + } + + float out_vals[kLPerThread]; + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { + out_vals[i] = bias_val; + const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; + #pragma unroll + for (int w = 0; w < kWidth; ++w) { + if constexpr (!kHasSeqIdx) { + out_vals[i] += weight_vals[w] * x_vals[i + w]; + } else { + out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f; + } + } + if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } + } + + __syncthreads(); + #pragma unroll + for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; } + __syncthreads(); + + #pragma unroll + for (int l = 0; l < Ktraits::kNLoads; ++l) { + input_t out_vals_store[kNElts]; + reinterpret_cast(out_vals_store)[0] = reinterpret_cast(x_smem[l * kLPerLoad + l_idx])[c_idx]; + if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen + && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { + *reinterpret_cast(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast(out_vals_store)[0]; + } + } + +} + +template +void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { + using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits; + // constexpr int kSmemSize = Ktraits::kSmemSize; + constexpr int kChunkSizeL = Ktraits::kChunkSizeL; + constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; + const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; + const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; + dim3 grid(params.batch, n_chunks_L, n_chunks_C); + dim3 block(Ktraits::kNThreads); + auto kernel = &causal_conv1d_channellast_fwd_kernel; + // if (kSmemSize >= 48 * 1024) { + // C10_CUDA_CHECK(cudaFuncSetAttribute( + // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + // } + // kernel<<>>(params); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); +} + +template +void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream); + } +} + +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); + +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +/////// + + + + +template +struct Causal_conv1d_update_kernel_traits { + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + static constexpr int kWidth = kWidth_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads) +void causal_conv1d_update_kernel(ConvParamsBase params) { + constexpr int kWidth = Ktraits::kWidth; + constexpr int kNThreads = Ktraits::kNThreads; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + + const int tidx = threadIdx.x; + const int batch_id = blockIdx.x; + const int channel_id = blockIdx.y * kNThreads + tidx; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + + channel_id * params.x_c_stride; + input_t *conv_state = reinterpret_cast(params.conv_state_ptr) + batch_id * params.conv_state_batch_stride + + channel_id * params.conv_state_c_stride; + weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + channel_id * params.out_c_stride; + float bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + float weight_vals[kWidth] = {0}; + if (channel_id < params.dim) { + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } + } + + float x_vals[kWidth] = {0}; + if (channel_id < params.dim) { + #pragma unroll + for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } + x_vals[kWidth - 1] = float(x[0]); + #pragma unroll + for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } + } + + float out_val = bias_val; + #pragma unroll + for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } + if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } + if (channel_id < params.dim) { out[0] = input_t(out_val); } +} + +template +void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { + using Ktraits = Causal_conv1d_update_kernel_traits; + dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); + auto kernel = &causal_conv1d_update_kernel; + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); +} + +template +void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { + if (params.width == 2) { + causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream); + } else if (params.width == 3) { + causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream); + } else if (params.width == 4) { + causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream); + } +} + +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); +template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h new file mode 100644 index 0000000000000..bb25314c8bbbd --- /dev/null +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -0,0 +1,144 @@ +/****************************************************************************** + * Copyright (c) 2024, Tri Dao. + ******************************************************************************/ +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h +#pragma once + +#include +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct ConvParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, width; + bool silu_activation; + + index_t x_batch_stride; + index_t x_c_stride; + index_t x_l_stride; + index_t weight_c_stride; + index_t weight_width_stride; + index_t out_batch_stride; + index_t out_c_stride; + index_t out_l_stride; + + index_t conv_state_batch_stride; + index_t conv_state_c_stride; + index_t conv_state_l_stride; + + // Common data pointers. + void *__restrict__ x_ptr; + void *__restrict__ weight_ptr; + void *__restrict__ bias_ptr; + void *__restrict__ out_ptr; + + void *__restrict__ conv_state_ptr; + + void *__restrict__ seq_idx_ptr; + + // No __restrict__ since initial_states could be the same as final_states. + void * initial_states_ptr; + index_t initial_states_batch_stride; + index_t initial_states_l_stride; + index_t initial_states_c_stride; + + void * final_states_ptr; + index_t final_states_batch_stride; + index_t final_states_l_stride; + index_t final_states_c_stride; +}; + + +#ifndef USE_ROCM + #include + + template + __device__ inline T shuffle_xor(T val, int offset) { + return __shfl_xor_sync(uint32_t(-1), val, offset); + } + + constexpr size_t custom_max(std::initializer_list ilist) + { + return std::max(ilist); + } + + template + constexpr T constexpr_min(T a, T b) { + return std::min(a, b); + } + +#else + #include + + template + __device__ inline T shuffle_xor(T val, int offset) { + return __shfl_xor(val, offset); + } + constexpr size_t custom_max(std::initializer_list ilist) + { + return *std::max_element(ilist.begin(), ilist.end()); + } + + template + constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; + } +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct SumOp { +__device__ inline T operator()(T const & x, T const & y) { return x + y; } +}; + +template +struct Allreduce { + static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4); + template + static __device__ inline T run(T x, Operator &op) { + constexpr int OFFSET = THREADS / 2; + x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET)); + return Allreduce::run(x, op); + } +}; + +template<> +struct Allreduce<2> { +template +static __device__ inline T run(T x, Operator &op) { + x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1)); + return x; +} +}; diff --git a/csrc/mamba/causal_conv1d/static_switch.h b/csrc/mamba/causal_conv1d/static_switch.h new file mode 100644 index 0000000000000..ef74bf447f840 --- /dev/null +++ b/csrc/mamba/causal_conv1d/static_switch.h @@ -0,0 +1,28 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h +// clang-format off +// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h + +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + static constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + static constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h new file mode 100644 index 0000000000000..0070c92f6cd0f --- /dev/null +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -0,0 +1,276 @@ +/****************************************************************************** + * Copyright (c) 2023, Tri Dao. + ******************************************************************************/ +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h + +#pragma once + +#ifndef USE_ROCM + #include +#else + #include +#endif +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SSMParamsBase { + using index_t = uint32_t; + + int batch, dim, seqlen, dstate, n_groups, n_chunks; + int dim_ngroups_ratio; + bool is_variable_B; + bool is_variable_C; + + bool delta_softplus; + + index_t A_d_stride; + index_t A_dstate_stride; + index_t B_batch_stride; + index_t B_d_stride; + index_t B_dstate_stride; + index_t B_group_stride; + index_t C_batch_stride; + index_t C_d_stride; + index_t C_dstate_stride; + index_t C_group_stride; + index_t u_batch_stride; + index_t u_d_stride; + index_t delta_batch_stride; + index_t delta_d_stride; + index_t z_batch_stride; + index_t z_d_stride; + index_t out_batch_stride; + index_t out_d_stride; + index_t out_z_batch_stride; + index_t out_z_d_stride; + + // Common data pointers. + void *__restrict__ A_ptr; + void *__restrict__ B_ptr; + void *__restrict__ C_ptr; + void *__restrict__ D_ptr; + void *__restrict__ u_ptr; + void *__restrict__ delta_ptr; + void *__restrict__ delta_bias_ptr; + void *__restrict__ out_ptr; + void *__restrict__ x_ptr; + void *__restrict__ z_ptr; + void *__restrict__ out_z_ptr; + void *__restrict__ index_ptr; +}; + + + + +#ifndef USE_ROCM + + constexpr size_t custom_max(std::initializer_list ilist) + { + return std::max(ilist); + } + + template + constexpr T constexpr_min(T a, T b) { + return std::min(a, b); + } + +#else + constexpr size_t custom_max(std::initializer_list ilist) + { + return *std::max_element(ilist.begin(), ilist.end()); + } + + template + constexpr T constexpr_min(T a, T b) { + return a < b ? a : b; + } +#endif + + +#define MAX_DSTATE 256 + + +inline __device__ float2 operator+(const float2 & a, const float2 & b){ + return {a.x + b.x, a.y + b.y}; +} + +inline __device__ float3 operator+(const float3 &a, const float3 &b) { + return {a.x + b.x, a.y + b.y, a.z + b.z}; +} + +inline __device__ float4 operator+(const float4 & a, const float4 & b){ + return {a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w}; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct BytesToType {}; + +template<> struct BytesToType<16> { + using Type = uint4; + static_assert(sizeof(Type) == 16); +}; + +template<> struct BytesToType<8> { + using Type = uint64_t; + static_assert(sizeof(Type) == 8); +}; + +template<> struct BytesToType<4> { + using Type = uint32_t; + static_assert(sizeof(Type) == 4); +}; + +template<> struct BytesToType<2> { + using Type = uint16_t; + static_assert(sizeof(Type) == 2); +}; + +template<> struct BytesToType<1> { + using Type = uint8_t; + static_assert(sizeof(Type) == 1); +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Converter{ + static inline __device__ void to_float(const scalar_t (&src)[N], float (&dst)[N]) { + #pragma unroll + for (int i = 0; i < N; ++i) { dst[i] = src[i]; } + } +}; + +template +struct Converter{ + static inline __device__ void to_float(const at::Half (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __half22float2(src2[i]); } + } +}; + +#if __CUDA_ARCH__ >= 800 +template +struct Converter{ + static inline __device__ void to_float(const at::BFloat16 (&src)[N], float (&dst)[N]) { + static_assert(N % 2 == 0); + auto &src2 = reinterpret_cast(src); + auto &dst2 = reinterpret_cast(dst); + #pragma unroll + for (int i = 0; i < N / 2; ++i) { dst2[i] = __bfloat1622float2(src2[i]); } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +template struct SSMScanOp; + +template<> +struct SSMScanOp { + __device__ __forceinline__ float2 operator()(const float2 &ab0, const float2 &ab1) const { + return make_float2(ab1.x * ab0.x, ab1.x * ab0.y + ab1.y); + } +}; + +// A stateful callback functor that maintains a running prefix to be applied +// during consecutive scan operations. +template struct SSMScanPrefixCallbackOp { + using scan_t = std::conditional_t, float2, float4>; + scan_t running_prefix; + // Constructor + __device__ SSMScanPrefixCallbackOp(scan_t running_prefix_) : running_prefix(running_prefix_) {} + // Callback operator to be entered by the first warp of threads in the block. + // Thread-0 is responsible for returning a value for seeding the block-wide scan. + __device__ scan_t operator()(scan_t block_aggregate) { + scan_t old_prefix = running_prefix; + running_prefix = SSMScanOp()(running_prefix, block_aggregate); + return old_prefix; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void load_input(typename Ktraits::input_t *u, + typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadT::TempStorage &smem_load, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_vec = reinterpret_cast(smem_load); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadVecT(smem_load_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + #ifdef USE_ROCM + , Ktraits::kNThreads * Ktraits::kNLoads + #endif + + ); + } else { + typename Ktraits::BlockLoadT(smem_load).Load(u, u_vals, seqlen, 0.f); + } +} + +template +inline __device__ void load_index(int *u, + int (&u_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index, + int seqlen) { + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_index_vec = reinterpret_cast(smem_load_index); + Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load( + reinterpret_cast(u), + reinterpret_cast(u_vals) + ); + } else { + Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0); + } +} + +template +inline __device__ void load_weight(typename Ktraits::input_t *Bvar, + typename Ktraits::weight_t (&B_vals)[Ktraits::kNItems], + typename Ktraits::BlockLoadWeightT::TempStorage &smem_load_weight, + int seqlen) { + constexpr int kNItems = Ktraits::kNItems; + typename Ktraits::input_t B_vals_load[kNItems]; + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( + reinterpret_cast(Bvar), + reinterpret_cast(B_vals_load) + ); + } else { + typename Ktraits::BlockLoadWeightT(smem_load_weight).Load(Bvar, B_vals_load, seqlen, 0.f); + } + // #pragma unroll + // for (int i = 0; i < kNItems; ++i) { B_vals[i] = B_vals_load[i]; } + Converter::to_float(B_vals_load, B_vals); +} + +template +inline __device__ void store_output(typename Ktraits::input_t *out, + const float (&out_vals)[Ktraits::kNItems], + typename Ktraits::BlockStoreT::TempStorage &smem_store, + int seqlen) { + typename Ktraits::input_t write_vals[Ktraits::kNItems]; + #pragma unroll + for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } + if constexpr (Ktraits::kIsEvenLen) { + auto& smem_store_vec = reinterpret_cast(smem_store); + using vec_t = typename Ktraits::vec_t; + typename Ktraits::BlockStoreVecT(smem_store_vec).Store( + reinterpret_cast(out), + reinterpret_cast(write_vals) + ); + } else { + typename Ktraits::BlockStoreT(smem_store).Store(out, write_vals, seqlen); + } +} diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu new file mode 100644 index 0000000000000..df968dda92adc --- /dev/null +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -0,0 +1,593 @@ +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh +#include +#include +#include +#include "selective_scan.h" + +#include +#include +#include // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK + +#ifndef USE_ROCM + #include + #include + #include +#else + #include + namespace cub = hipcub; +#endif + +#include "selective_scan.h" +#include "static_switch.h" + +template +struct Selective_Scan_fwd_kernel_traits { + static_assert(kNItems_ % 4 == 0); + using input_t = input_t_; + using weight_t = weight_t_; + static constexpr int kNThreads = kNThreads_; + // Setting MinBlocksPerMP to be 3 (instead of 2) for 128 threads improves occupancy. + static constexpr int kMinBlocks = kNThreads < 128 ? 5 : 3; + static constexpr int kNItems = kNItems_; + static constexpr int kNRows = kNRows_; + static constexpr int kNBytes = sizeof(input_t); + static_assert(kNBytes == 2 || kNBytes == 4); + static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); + static_assert(kNItems % kNElts == 0); + static constexpr int kNLoads = kNItems / kNElts; + static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsVariableB = kIsVariableB_; + static constexpr bool kIsVariableC = kIsVariableC_; + static constexpr bool kHasZ = kHasZ_; + static constexpr bool kUseIndex = kUseIndex_; + + static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + static constexpr int kNLoadsIndex = kNItems / 4; + using vec_t = typename BytesToType::Type; + using scan_t = float2; + using BlockLoadT = cub::BlockLoad; + using BlockLoadVecT = cub::BlockLoad; + using BlockLoadIndexT = cub::BlockLoad; + using BlockLoadIndexVecT = cub::BlockLoad; + using BlockLoadWeightT = cub::BlockLoad; + using BlockLoadWeightVecT = cub::BlockLoad; + using BlockStoreT = cub::BlockStore; + using BlockStoreVecT = cub::BlockStore; + // using BlockScanT = cub::BlockScan; + // using BlockScanT = cub::BlockScan; + using BlockScanT = cub::BlockScan; + static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), + sizeof(typename BlockLoadVecT::TempStorage), + sizeof(typename BlockLoadIndexT::TempStorage), + sizeof(typename BlockLoadIndexVecT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), + (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), + sizeof(typename BlockStoreT::TempStorage), + sizeof(typename BlockStoreVecT::TempStorage)}); + static constexpr int kSmemSize = kSmemIOSize + sizeof(typename BlockScanT::TempStorage); +}; + +template +__global__ __launch_bounds__(Ktraits::kNThreads, Ktraits::kMinBlocks) +void selective_scan_fwd_kernel(SSMParamsBase params) { + constexpr bool kIsVariableB = Ktraits::kIsVariableB; + constexpr bool kIsVariableC = Ktraits::kIsVariableC; + constexpr bool kHasZ = Ktraits::kHasZ; + constexpr bool kUseIndex = Ktraits::kUseIndex; + constexpr int kNThreads = Ktraits::kNThreads; + constexpr int kNItems = Ktraits::kNItems; + constexpr int kNRows = Ktraits::kNRows; + constexpr bool kDirectIO = Ktraits::kDirectIO; + using input_t = typename Ktraits::input_t; + using weight_t = typename Ktraits::weight_t; + using scan_t = typename Ktraits::scan_t; + + // Shared memory. + extern __shared__ char smem_[]; + // cast to lvalue reference of expected type + // char *smem_loadstorescan = smem_ + 2 * MAX_DSTATE * sizeof(weight_t); + // auto& smem_load = reinterpret_cast(smem_ + 2 * MAX_DSTATE * sizeof(weight_t)); + // auto& smem_load = reinterpret_cast(smem_loadstorescan); + auto& smem_load = reinterpret_cast(smem_); + auto& smem_load_weight = reinterpret_cast(smem_); + auto& smem_load_index = reinterpret_cast(smem_); + auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); + auto& smem_store = reinterpret_cast(smem_); + auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + // weight_t *smem_a = reinterpret_cast(smem_ + smem_loadstorescan_size); + // weight_t *smem_bc = reinterpret_cast(smem_a + MAX_DSTATE); + scan_t *smem_running_prefix = reinterpret_cast(smem_ + Ktraits::kSmemSize); + + const int batch_id = blockIdx.x; + const int dim_id = blockIdx.y; + const int group_id = dim_id / (params.dim_ngroups_ratio); + input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + + dim_id * kNRows * params.u_d_stride; + input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + + dim_id * kNRows * params.delta_d_stride; + weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; + weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; + input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; + scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; + int *index = !kUseIndex ? nullptr :reinterpret_cast(params.index_ptr) + batch_id * params.seqlen; + + float D_val[kNRows] = {0}; + if (params.D_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + D_val[r] = reinterpret_cast(params.D_ptr)[dim_id * kNRows + r]; + } + } + float delta_bias[kNRows] = {0}; + if (params.delta_bias_ptr != nullptr) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + delta_bias[r] = reinterpret_cast(params.delta_bias_ptr)[dim_id * kNRows + r]; + } + } + + + // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { + // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; + // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; + // } + + constexpr int kChunkSize = kNThreads * kNItems; + for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; + int index_vals_load[kNRows][kNItems]; + + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (!kDirectIO) { __syncthreads(); } + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); + if constexpr (kUseIndex) { + load_index(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize); + } + } + if constexpr (kUseIndex) { + index += kChunkSize; + } + u += kChunkSize; + delta += kChunkSize; + + float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float u_val = float(u_vals[r][i]); + delta_vals[r][i] = float(delta_vals_load[r][i]) + delta_bias[r]; + if (params.delta_softplus) { + delta_vals[r][i] = delta_vals[r][i] <= 20.f ? log1pf(expf(delta_vals[r][i])) : delta_vals[r][i]; + } + delta_u_vals[r][i] = delta_vals[r][i] * u_val; + out_vals[r][i] = D_val[r] * u_val; + } + } + + __syncthreads(); + for (int state_idx = 0; state_idx < params.dstate; ++state_idx) { + weight_t A_val[kNRows]; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + A_val[r] = A[state_idx * params.A_dstate_stride + r * params.A_d_stride]; + // Multiply the real part of A with LOG2E so we can use exp2f instead of expf. + constexpr float kLog2e = M_LOG2E; + A_val[r] *= kLog2e; + } + // This variable holds B * C if both B and C are constant across seqlen. If only B varies + // across seqlen, this holds C. If only C varies across seqlen, this holds B. + // If both B and C vary, this is unused. + weight_t BC_val[kNRows]; + weight_t B_vals[kNItems], C_vals[kNItems]; + if constexpr (kIsVariableB) { + load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, + smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1)); + if constexpr (!kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + } + if constexpr (kIsVariableC) { + auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; + load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, + smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1 )); + if constexpr (!kIsVariableB) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride]; + } + } + } + if constexpr (!kIsVariableB && !kIsVariableC) { + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + BC_val[r] = B[state_idx * params.B_dstate_stride + r * params.B_d_stride] * C[state_idx * params.C_dstate_stride + r * params.C_d_stride]; + } + } + + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if (r > 0) { __syncthreads(); } // Scan could be using the same smem + scan_t thread_data[kNItems]; + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), + !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); + + // Reset A bar for cumulative sequences (Real) + if constexpr (kUseIndex) { + if (index_vals_load[r][i] == 0) { + thread_data[i].x = 0.f; + } + } + + if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + thread_data[i] = make_float2(1.f, 0.f); + } + } + } + // Initialize running total + scan_t running_prefix; + // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read + running_prefix = chunk == 0 ? x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f)); + // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + SSMScanPrefixCallbackOp prefix_op(running_prefix); + typename Ktraits::BlockScanT(smem_scan).InclusiveScan( + thread_data, thread_data, SSMScanOp(), prefix_op + ); + // There's a syncthreads in the scan op, so we don't need to sync here. + // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. + if (threadIdx.x == 0) { + smem_running_prefix[state_idx] = prefix_op.running_prefix; + x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; + } + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + const weight_t C_val = !kIsVariableC + ? BC_val[r] + : (!kIsVariableB ? BC_val[r] * C_vals[i] : C_vals[i]); + out_vals[r][i] += thread_data[i].y * C_val; + } + } + } + + input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + __syncthreads(); + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + if constexpr (!kDirectIO) { + if (r > 0) { __syncthreads(); } + } + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + + if constexpr (kHasZ) { + input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + #pragma unroll + for (int r = 0; r < kNRows; ++r) { + input_t z_vals[kNItems]; + __syncthreads(); + load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + #pragma unroll + for (int i = 0; i < kNItems; ++i) { + float z_val = z_vals[i]; + out_vals[r][i] *= z_val / (1 + expf(-z_val)); + } + __syncthreads(); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + } + } + + Bvar += kChunkSize * 1; + Cvar += kChunkSize * 1; + } +} + +template +void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { + // Only kNRows == 1 is tested for now, which ofc doesn't differ from previously when we had each block + // processing 1 row. + constexpr int kNRows = 1; + // kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size + constexpr bool kIsVariableB = true; + constexpr bool kIsVariableC = true; + constexpr bool kHasZ = true; + BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { + BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; + constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); + dim3 grid(params.batch, params.dim / kNRows); + auto kernel = &selective_scan_fwd_kernel; + if (kSmemSize >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); +} + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream) { + + #ifndef USE_ROCM + if (params.seqlen <= 128) { + selective_scan_fwd_launch<32, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 256) { + selective_scan_fwd_launch<32, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<32, 16, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #else + if (params.seqlen <= 256) { + selective_scan_fwd_launch<64, 4, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 512) { + selective_scan_fwd_launch<64, 8, input_t, weight_t>(params, stream); + } else if (params.seqlen <= 1024) { + selective_scan_fwd_launch<64, 16, input_t, weight_t>(params, stream); + } else { + selective_scan_fwd_launch<128, 16, input_t, weight_t>(params, stream); + } + #endif +} + +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); +template void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") + +#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \ + if (ITYPE == at::ScalarType::Half) { \ + using input_t = at::Half; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::BFloat16) { \ + using input_t = at::BFloat16; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else if (ITYPE == at::ScalarType::Float) { \ + using input_t = float; \ + using weight_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \ + } + + +template +void selective_scan_fwd_cuda(SSMParamsBase ¶ms, cudaStream_t stream); + +void set_ssm_params_fwd(SSMParamsBase ¶ms, + // sizes + const size_t batch, + const size_t dim, + const size_t seqlen, + const size_t dstate, + const size_t n_groups, + const size_t n_chunks, + const bool is_variable_B, + const bool is_variable_C, + // device pointers + const torch::Tensor u, + const torch::Tensor delta, + const torch::Tensor A, + const torch::Tensor B, + const torch::Tensor C, + const torch::Tensor out, + const torch::Tensor z, + const torch::Tensor out_z, + void* D_ptr, + void* delta_bias_ptr, + void* x_ptr, + bool has_z, + bool delta_softplus, + void* index_ptr) { + + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + + params.batch = batch; + params.dim = dim; + params.seqlen = seqlen; + params.dstate = dstate; + params.n_groups = n_groups; + params.n_chunks = n_chunks; + params.dim_ngroups_ratio = dim / n_groups; + + params.delta_softplus = delta_softplus; + + params.is_variable_B = is_variable_B; + params.is_variable_C = is_variable_C; + + // Set the pointers and strides. + params.u_ptr = u.data_ptr(); + params.delta_ptr = delta.data_ptr(); + params.A_ptr = A.data_ptr(); + params.B_ptr = B.data_ptr(); + params.C_ptr = C.data_ptr(); + params.D_ptr = D_ptr; + params.delta_bias_ptr = delta_bias_ptr; + params.out_ptr = out.data_ptr(); + params.x_ptr = x_ptr; + params.z_ptr = has_z ? z.data_ptr() : nullptr; + params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + + params.index_ptr = index_ptr; + + // All stride are in elements, not bytes. + params.A_d_stride = A.stride(0); + params.A_dstate_stride = A.stride(1); + if (!is_variable_B) { + params.B_d_stride = B.stride(0); + } else { + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + } + params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); + if (!is_variable_C) { + params.C_d_stride = C.stride(0); + } else { + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + } + params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + if (has_z) { + params.z_batch_stride = z.stride(0); + params.z_d_stride = z.stride(1); + params.out_z_batch_stride = out_z.stride(0); + params.out_z_d_stride = out_z.stride(1); + } + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); +} + +std::vector +selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, + const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, + const c10::optional &D_, + const c10::optional &z_, + const c10::optional &delta_bias_, + bool delta_softplus, + const c10::optional &index_, + const c10::optional &x) { + auto input_type = u.scalar_type(); + auto weight_type = A.scalar_type(); + TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); + TORCH_CHECK(weight_type == at::ScalarType::Float); + + const bool is_variable_B = B.dim() >= 3; + const bool is_variable_C = C.dim() >= 3; + + TORCH_CHECK(delta.scalar_type() == input_type); + TORCH_CHECK(B.scalar_type() == (!is_variable_B ? weight_type : input_type)); + TORCH_CHECK(C.scalar_type() == (!is_variable_C ? weight_type : input_type)); + + TORCH_CHECK(u.is_cuda()); + TORCH_CHECK(delta.is_cuda()); + TORCH_CHECK(A.is_cuda()); + TORCH_CHECK(B.is_cuda()); + TORCH_CHECK(C.is_cuda()); + + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + + const auto sizes = u.sizes(); + const int batch_size = sizes[0]; + const int dim = sizes[1]; + const int seqlen = sizes[2]; + const int dstate = A.size(1); + const int n_groups = is_variable_B ? B.size(1) : 1; + + TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); + + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + CHECK_SHAPE(A, dim, dstate); + TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size") + CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen ); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); + + TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size") + CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); + + if (D_.has_value()) { + auto D = D_.value(); + TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(D.is_cuda()); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); + CHECK_SHAPE(D, dim); + } + + if (delta_bias_.has_value()) { + auto delta_bias = delta_bias_.value(); + TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); + TORCH_CHECK(delta_bias.is_cuda()); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); + CHECK_SHAPE(delta_bias, dim); + } + if (index_.has_value()) { + auto index = index_.value(); + TORCH_CHECK(index.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(index.is_cuda()); + CHECK_SHAPE(index, batch_size, seqlen); + } + + at::Tensor z, out_z; + const bool has_z = z_.has_value(); + TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size") + z = z_.value(); + TORCH_CHECK(z.scalar_type() == input_type); + TORCH_CHECK(z.is_cuda()); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); + CHECK_SHAPE(z, batch_size, dim, seqlen); + out_z = torch::empty_like(z); + + const int n_chunks = (seqlen + 2048 - 1) / 2048; + // const int n_chunks = (seqlen + 1024 - 1) / 1024; + // at::Tensor out = torch::empty_like(u); + // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout + at::Tensor out = torch::empty_like(delta); + if (x.has_value()){ + auto _x = x.value(); + TORCH_CHECK(_x.scalar_type() == weight_type); + TORCH_CHECK(_x.is_cuda()); + TORCH_CHECK(_x.stride(-1) == 1); + CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2); + } + + SSMParamsBase params; + set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, + u, delta, A, B, C, out, z, out_z, + D_.has_value() ? D_.value().data_ptr() : nullptr, + delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, + x.value().data_ptr(), + has_z, + delta_softplus, + index_.has_value() ? index_.value().data_ptr() : nullptr); + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)u.get_device()}; + auto stream = at::cuda::getCurrentCUDAStream().stream(); + DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { + selective_scan_fwd_cuda(params, stream); + }); + std::vector result = {out, x.value()}; + if (has_z) { result.push_back(out_z); } + return result; +} + diff --git a/csrc/mamba/mamba_ssm/static_switch.h b/csrc/mamba/mamba_ssm/static_switch.h new file mode 100644 index 0000000000000..840cb2374a2f0 --- /dev/null +++ b/csrc/mamba/mamba_ssm/static_switch.h @@ -0,0 +1,28 @@ +// Inspired by +// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h +// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h + +// clang-format off +// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/static_switch.h +#pragma once + +/// @param COND - a boolean expression to switch by +/// @param CONST_NAME - a name given for the constexpr bool variable. +/// @param ... - code to execute for true and false +/// +/// Usage: +/// ``` +/// BOOL_SWITCH(flag, BoolConst, [&] { +/// some_function(...); +/// }); +/// ``` +#define BOOL_SWITCH(COND, CONST_NAME, ...) \ + [&] { \ + if (COND) { \ + constexpr bool CONST_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool CONST_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index 1e170e80d2f70..92184f43c9eb0 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -1737,4 +1737,4 @@ torch::Tensor marlin_gemm_moe( moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms, max_par, replicate_input, apply_weights); return c; -} \ No newline at end of file +} diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h index 01ba8ff69850d..43d264e0770d6 100644 --- a/csrc/moe/marlin_moe_ops.h +++ b/csrc/moe/marlin_moe_ops.h @@ -9,4 +9,4 @@ torch::Tensor marlin_gemm_moe( const torch::Tensor& g_idx, const torch::Tensor& perm, torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, - bool replicate_input, bool apply_weights); \ No newline at end of file + bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index d4d43e2c601b5..8a0e625b43fa1 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -16,7 +16,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "g_idx, Tensor! perm, Tensor! workspace, int size_m, int size_n, int " "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " "bool replicate_input, bool apply_weights) -> Tensor"); - m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); #endif } diff --git a/csrc/ops.h b/csrc/ops.h index 6bf0cff232528..45a3868395d12 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -170,9 +170,6 @@ void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scales); -void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, - torch::Tensor lookup_table); - torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, torch::Tensor b_gptq_qzeros, torch::Tensor b_gptq_scales, torch::Tensor b_g_idx, @@ -195,6 +192,28 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); +std::vector selective_scan_fwd( + const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, + const torch::Tensor& B, const torch::Tensor& C, + const c10::optional& D_, + const c10::optional& z_, + const c10::optional& delta_bias_, bool delta_softplus, + const c10::optional& index_, + const c10::optional& x); + +at::Tensor causal_conv1d_update(const at::Tensor& x, + const at::Tensor& conv_state, + const at::Tensor& weight, + const c10::optional& bias_, + bool silu_activation); + +at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, + const c10::optional& bias_, + const c10::optional& seq_idx_, + const c10::optional& initial_states_, + const c10::optional& final_states_out_, + bool silu_activation); + #ifndef USE_ROCM using fptr_t = int64_t; fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, diff --git a/csrc/quantization/squeezellm/quant_cuda_kernel.cu b/csrc/quantization/squeezellm/quant_cuda_kernel.cu deleted file mode 100644 index 8ed918b3d7c27..0000000000000 --- a/csrc/quantization/squeezellm/quant_cuda_kernel.cu +++ /dev/null @@ -1,216 +0,0 @@ -#include -#include -#include -#include - -// half-tensor -#include -#include -#include - -#define BLOCKWIDTH 128 -#define BLOCKHEIGHT4 16 - -namespace vllm { -namespace squeezellm { - -__device__ inline unsigned int as_unsigned(int i) { - return *reinterpret_cast(&i); -} - -// 4-bit matvec kernel (LUT-based) -__global__ void NUQ4MatMulKernel( -#ifndef USE_ROCM - const half2* __restrict__ vec, -#else - const __half2* __restrict__ vec, -#endif - const int* __restrict__ mat, -#ifndef USE_ROCM - half2* __restrict__ mul, -#else - float2* __restrict__ mul, -#endif - const __half* __restrict__ lookup_table, int height, int width, int batch, - int vec_height) { - - const int blockwidth2 = BLOCKWIDTH / 2; - - int row = BLOCKHEIGHT4 * blockIdx.x; - int col = BLOCKWIDTH * blockIdx.y + threadIdx.x; - -#ifndef USE_ROCM - __shared__ half2 blockvec[blockwidth2]; -#else - __shared__ __half2 blockvec[blockwidth2]; -#endif - - __shared__ __half deq2[16][BLOCKWIDTH]; - int off = threadIdx.x; - int column_offset = col * 16; - for (int val = 0; val < 16; val += 1) { - int lut_index = column_offset + val; - deq2[val][off] = lookup_table[lut_index]; - } - - __half res; -#ifndef USE_ROCM - half2 res2; - half2 tmp2; -#else - __half2 res2; - __half2 tmp2; -#endif - - int i; - int k; - - unsigned int tmp1; - unsigned int lut_index1, lut_index2; - - for (int b = 0; b < batch; ++b) { - i = width * row + col; - res = __int2half_rd(0); - k = 0; - - __syncthreads(); - if (threadIdx.x < blockwidth2) - blockvec[threadIdx.x] = - vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 + - threadIdx.x]; - __syncthreads(); - - while (k < blockwidth2) { - tmp1 = as_unsigned(mat[i]); - -#ifndef USE_ROCM - res2 = {}; - tmp2 = {}; -#else - res2.x = __half_as_ushort(__float2half(0)); - res2.y = __half_as_ushort(__float2half(0)); - tmp2.x = __half_as_ushort(__float2half(0)); - tmp2.y = __half_as_ushort(__float2half(0)); -#endif - - lut_index1 = tmp1 & 0xF; - lut_index2 = (tmp1 >> 4) & 0xF; -#ifndef USE_ROCM - tmp2.x = deq2[lut_index1][off]; - tmp2.y = deq2[lut_index2][off]; -#else - tmp2.x = __half_as_ushort(deq2[lut_index1][off]); - tmp2.y = __half_as_ushort(deq2[lut_index2][off]); -#endif - res2 = __hfma2(tmp2, blockvec[k + 0], res2); - - lut_index1 = (tmp1 >> 8) & 0xF; - lut_index2 = (tmp1 >> 12) & 0xF; -#ifndef USE_ROCM - tmp2.x = deq2[lut_index1][off]; - tmp2.y = deq2[lut_index2][off]; -#else - tmp2.x = __half_as_ushort(deq2[lut_index1][off]); - tmp2.y = __half_as_ushort(deq2[lut_index2][off]); -#endif - res2 = __hfma2(tmp2, blockvec[k + 1], res2); - - lut_index1 = (tmp1 >> 16) & 0xF; - lut_index2 = (tmp1 >> 20) & 0xF; -#ifndef USE_ROCM - tmp2.x = deq2[lut_index1][off]; - tmp2.y = deq2[lut_index2][off]; -#else - tmp2.x = __half_as_ushort(deq2[lut_index1][off]); - tmp2.y = __half_as_ushort(deq2[lut_index2][off]); -#endif - res2 = __hfma2(tmp2, blockvec[k + 2], res2); - - lut_index1 = (tmp1 >> 24) & 0xF; - lut_index2 = (tmp1 >> 28) & 0xF; -#ifndef USE_ROCM - tmp2.x = deq2[lut_index1][off]; - tmp2.y = deq2[lut_index2][off]; -#else - tmp2.x = __half_as_ushort(deq2[lut_index1][off]); - tmp2.y = __half_as_ushort(deq2[lut_index2][off]); -#endif - res2 = __hfma2(tmp2, blockvec[k + 3], res2); - -#ifndef USE_ROCM - res = __hadd(__hadd(res2.x, res2.y), res); -#else - res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), - res); -#endif - - i += width; - k += 4; - } - - // col%2 -> only set one of the two values -#ifndef USE_ROCM - half2 res3 = {}; - if (col % 2 == 0) { - res3.x = res; - } else { - res3.y = res; - } -#else - __half2 res3; - res3.x = __half_as_ushort(__float2half(0)); - res3.y = __half_as_ushort(__float2half(0)); - if (col % 2 == 0) { - res3.x = __half_as_ushort(res); - } else { - res3.y = __half_as_ushort(res); - } -#endif - -#ifndef USE_ROCM - atomicAdd(&mul[b * width / 2 + col / 2], res3); -#else - int tmp_addr = b * width / 2 + col / 2; - atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x))); - atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y))); -#endif - } -} - -} // namespace squeezellm -} // namespace vllm - -// 4-bit matvec kernel (LUT-based) -void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul, - torch::Tensor lookup_table) { - int height = mat.size(0); - int width = mat.size(1); - - int batch = vec.size(0); - int vec_height = vec.size(1); - - dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4, - (width + BLOCKWIDTH - 1) / BLOCKWIDTH); - dim3 threads(BLOCKWIDTH); - - const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - vllm::squeezellm::NUQ4MatMulKernel<<>>( -#ifndef USE_ROCM - (half2*)vec.data_ptr(), -#else - (__half2*)vec.data_ptr(), -#endif - mat.data_ptr(), -#ifndef USE_ROCM - (half2*)mul.data_ptr(), - (__half*)lookup_table.data_ptr(), -#else - (float2*)mul.data_ptr(), - (__half*)lookup_table.data_ptr(), -#endif - height, width, batch, vec_height); -} - -#undef BLOCKWIDTH -#undef BLOCKHEIGHT4 diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 6d1f53b75f4e2..07b14e7a6ff63 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -202,6 +202,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("cutlass_scaled_mm_supports_fp8", &cutlass_scaled_mm_supports_fp8); ops.impl("cutlass_scaled_mm_supports_fp8", torch::kCUDA, &cutlass_scaled_mm_supports_fp8); + // Mamba selective scan kernel + ops.def( + "selective_scan_fwd(Tensor! u, Tensor! delta," + "Tensor! A, Tensor! B, Tensor! C," + "Tensor? D_, Tensor? z_, Tensor? delta_bias_," + "bool delta_softplus," + "Tensor? index_, Tensor? x) -> Tensor[]"); + ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); + + ops.def( + "causal_conv1d_update(Tensor! x," + "Tensor! conv_state," + "Tensor! weight," + "Tensor? bias_," + "bool silu_activation) -> Tensor"); + ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); + + ops.def( + "causal_conv1d_fwd(Tensor! x, Tensor! weight," + "Tensor? bias_," + "Tensor? seq_idx_," + "Tensor? initial_states_," + "Tensor? final_states_out_," + "bool silu_activation) -> Tensor"); + ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); #endif // Quantized GEMM for GPTQ. @@ -212,12 +237,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); - // Quantized GEMM for SqueezeLLM. - ops.def( - "squeezellm_gemm(Tensor vec, Tensor mat, Tensor! mul, Tensor " - "lookup_table) -> ()"); - ops.impl("squeezellm_gemm", torch::kCUDA, &squeezellm_gemm); - // Compute FP8 quantized tensor for given scaling factor. ops.def( "static_scaled_fp8_quant(Tensor! out, Tensor input, Tensor scale) -> ()"); diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 95a9be7806633..6687929c0bebe 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -12,4 +12,4 @@ torch py-cpuinfo transformers mistral_common >= 1.3.4 -openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args +openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args \ No newline at end of file diff --git a/docs/source/community/meetups.rst b/docs/source/community/meetups.rst index 3b01b109ebf2c..a3962e96e7913 100644 --- a/docs/source/community/meetups.rst +++ b/docs/source/community/meetups.rst @@ -5,6 +5,7 @@ vLLM Meetups We host regular meetups in San Francisco Bay Area every 2 months. We will share the project updates from the vLLM team and have guest speakers from the industry to share their experience and insights. Please find the materials of our previous meetups below: +- `The sixth vLLM meetup `__, with NVIDIA, September 9th 2024. `[Slides] `__ - `The fifth vLLM meetup `__, with AWS, July 24th 2024. `[Slides] `__ - `The fourth vLLM meetup `__, with Cloudflare and BentoML, June 11th 2024. `[Slides] `__ - `The third vLLM meetup `__, with Roblox, April 2nd 2024. `[Slides] `__ diff --git a/docs/source/dev/profiling/profiling_index.rst b/docs/source/dev/profiling/profiling_index.rst index af3c78c3b5a55..e22d547293445 100644 --- a/docs/source/dev/profiling/profiling_index.rst +++ b/docs/source/dev/profiling/profiling_index.rst @@ -17,14 +17,28 @@ Traces can be visualized using https://ui.perfetto.dev/. .. tip:: Only send a few requests through vLLM when profiling, as the traces can get quite large. Also, no need to untar the traces, they can be viewed directly. - -Example commands: + +.. tip:: + + To stop the profiler - it flushes out all the profile trace files to the directory. This takes time, for example for about 100 requests worth of data for a llama 70b, it takes about 10 minutes to flush out on a H100. + Set the env variable VLLM_RPC_GET_DATA_TIMEOUT_MS to a big number before you start the server. Say something like 30 minutes. + ``export VLLM_RPC_GET_DATA_TIMEOUT_MS=1800000`` + +Example commands and usage: +=========================== + +Offline Inference: +------------------ + +Refer to `examples/offline_inference_with_profiler.py `_ for an example. + OpenAI Server: +-------------- .. code-block:: bash - VLLM_TORCH_PROFILER_DIR=/mnt/traces/ python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B + VLLM_TORCH_PROFILER_DIR=./vllm_profile python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B benchmark_serving.py: diff --git a/docs/source/getting_started/debugging.rst b/docs/source/getting_started/debugging.rst index 117a9dd666481..31ecca1332e5d 100644 --- a/docs/source/getting_started/debugging.rst +++ b/docs/source/getting_started/debugging.rst @@ -21,7 +21,7 @@ If you have already taken care of the above issues, but the vLLM instance still With more logging, hopefully you can find the root cause of the issue. -If it crashes, and the error trace shows somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a cuda error inside cudagraph. To know the particular cuda operation that causes the error, you can add ``--enforce-eager`` to the command line, or ``enforce_eager=True`` to the ``LLM`` class, to disable the cudagraph optimization. This way, you can locate the exact cuda operation that causes the error. +If it crashes, and the error trace shows somewhere around ``self.graph.replay()`` in ``vllm/worker/model_runner.py``, it is a cuda error inside cudagraph. To know the particular cuda operation that causes the error, you can add ``--enforce-eager`` to the command line, or ``enforce_eager=True`` to the :class:`~vllm.LLM` class, to disable the cudagraph optimization. This way, you can locate the exact cuda operation that causes the error. Here are some common issues that can cause hangs: diff --git a/docs/source/getting_started/quickstart.rst b/docs/source/getting_started/quickstart.rst index 89bdc247c5e8e..80b19ac672936 100644 --- a/docs/source/getting_started/quickstart.rst +++ b/docs/source/getting_started/quickstart.rst @@ -24,7 +24,9 @@ Offline Batched Inference We first show an example of using vLLM for offline batched inference on a dataset. In other words, we use vLLM to generate texts for a list of input prompts. -Import ``LLM`` and ``SamplingParams`` from vLLM. The ``LLM`` class is the main class for running offline inference with vLLM engine. The ``SamplingParams`` class specifies the parameters for the sampling process. +Import :class:`~vllm.LLM` and :class:`~vllm.SamplingParams` from vLLM. +The :class:`~vllm.LLM` class is the main class for running offline inference with vLLM engine. +The :class:`~vllm.SamplingParams` class specifies the parameters for the sampling process. .. code-block:: python @@ -42,7 +44,7 @@ Define the list of input prompts and the sampling parameters for generation. The ] sampling_params = SamplingParams(temperature=0.8, top_p=0.95) -Initialize vLLM's engine for offline inference with the ``LLM`` class and the `OPT-125M model `_. The list of supported models can be found at :ref:`supported models `. +Initialize vLLM's engine for offline inference with the :class:`~vllm.LLM` class and the `OPT-125M model `_. The list of supported models can be found at :ref:`supported models `. .. code-block:: python diff --git a/docs/source/getting_started/tpu-installation.rst b/docs/source/getting_started/tpu-installation.rst index 31ae30ad302b3..217028839e347 100644 --- a/docs/source/getting_started/tpu-installation.rst +++ b/docs/source/getting_started/tpu-installation.rst @@ -56,9 +56,10 @@ First, install the dependencies: $ pip uninstall torch torch-xla -y $ # Install PyTorch and PyTorch XLA. - $ export DATE="+20240808" - $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly${DATE}-cp310-cp310-linux_x86_64.whl - $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly${DATE}-cp310-cp310-linux_x86_64.whl + $ export DATE="20240828" + $ export TORCH_VERSION="2.5.0" + $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl + $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl $ # Install JAX and Pallas. $ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html diff --git a/docs/source/models/lora.rst b/docs/source/models/lora.rst index f08773fe59d92..b3821ebdfceca 100644 --- a/docs/source/models/lora.rst +++ b/docs/source/models/lora.rst @@ -107,3 +107,55 @@ The following is an example request "max_tokens": 7, "temperature": 0 }' | jq + + +Dynamically serving LoRA Adapters +--------------------------------- + +In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading +LoRA adapters at runtime through dedicated API endpoints. This feature can be particularly useful when the flexibility +to change models on-the-fly is needed. + +Note: Enabling this feature in production environments is risky as user may participate model adapter management. + +To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING` +is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active. + +.. code-block:: bash + + export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True + + +Loading a LoRA Adapter: + +To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary +details of the adapter to be loaded. The request payload should include the name and path to the LoRA adapter. + +Example request to load a LoRA adapter: + +.. code-block:: bash + + curl -X POST http://localhost:8000/v1/load_lora_adapter \ + -H "Content-Type: application/json" \ + -d '{ + "lora_name": "sql_adapter", + "lora_path": "/path/to/sql-lora-adapter" + }' + +Upon a successful request, the API will respond with a 200 OK status code. If an error occurs, such as if the adapter +cannot be found or loaded, an appropriate error message will be returned. + +Unloading a LoRA Adapter: + +To unload a LoRA adapter that has been previously loaded, send a POST request to the `/v1/unload_lora_adapter` endpoint +with the name or ID of the adapter to be unloaded. + +Example request to unload a LoRA adapter: + +.. code-block:: bash + + curl -X POST http://localhost:8000/v1/unload_lora_adapter \ + -H "Content-Type: application/json" \ + -d '{ + "lora_name": "sql_adapter" + }' diff --git a/docs/source/models/spec_decode.rst b/docs/source/models/spec_decode.rst index d3c196faff25d..50468f25b922a 100644 --- a/docs/source/models/spec_decode.rst +++ b/docs/source/models/spec_decode.rst @@ -161,6 +161,46 @@ A variety of speculative models of this type are available on HF hub: * `granite-7b-instruct-accelerator `_ * `granite-20b-code-instruct-accelerator `_ +Lossless guarantees of Speculative Decoding +------------------------------------------- +In vLLM, speculative decoding aims to enhance inference efficiency while maintaining accuracy. This section addresses the lossless guarantees of +speculative decoding, breaking down the guarantees into three key areas: + +1. **Theoretical Losslessness** + - Speculative decoding sampling is theoretically lossless up to the precision limits of hardware numerics. Floating-point errors might + cause slight variations in output distributions, as discussed + in `Accelerating Large Language Model Decoding with Speculative Sampling `_ + +2. **Algorithmic Losslessness** + - vLLMā€™s implementation of speculative decoding is algorithmically validated to be lossless. Key validation tests include: + + - **Rejection Sampler Convergence**: Ensures that samples from vLLMā€™s rejection sampler align with the target + distribution. `View Test Code `_ + + - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling + without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler, + provides a lossless guarantee. Almost all of the tests in `this directory `_ + verify this property using `this assertion implementation `_ + +3. **vLLM Logprob Stability** + - vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the + same request across runs. For more details, see the FAQ section + titled *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq.rst>`_. + + +**Conclusion** + +While vLLM strives to ensure losslessness in speculative decoding, variations in generated outputs with and without speculative decoding +can occur due to following factors: + +- **Floating-Point Precision**: Differences in hardware numerical precision may lead to slight discrepancies in the output distribution. + +- **Batch Size and Numerical Stability**: Changes in batch size may cause variations in logprobs and output probabilities, potentially + due to non-deterministic behavior in batched operations or numerical instability. + +**Mitigation Strategies** + +For mitigation strategies, please refer to the FAQ entry *Can the output of a prompt vary across runs in vLLM?* in the `FAQs <../serving/faq.rst>`_. Resources for vLLM contributors ------------------------------- diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 223c68b40766e..1bb3a448f2c92 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -51,6 +51,10 @@ Decoder-only Language Models - DeciLM - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. - + * - :code:`ExaoneForCausalLM` + - EXAONE-3 + - :code:`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. + - āœ…ļøŽ * - :code:`FalconForCausalLM` - Falcon - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. @@ -143,6 +147,10 @@ Decoder-only Language Models - Phi-3-Small - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc. - + * - :code:`PhiMoEForCausalLM` + - Phi-3.5-MoE + - :code:`microsoft/Phi-3.5-MoE-instruct`, etc. + - * - :code:`PersimmonForCausalLM` - Persimmon - :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc. @@ -186,12 +194,12 @@ Multimodal Language Models * - Architecture - Models - - Supported Modalities + - Modalities - Example HuggingFace Models - :ref:`LoRA ` * - :code:`Blip2ForConditionalGeneration` - BLIP-2 - - Image + - Image\ :sup:`E` - :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc. - * - :code:`ChameleonForConditionalGeneration` @@ -206,40 +214,48 @@ Multimodal Language Models - * - :code:`InternVLChatModel` - InternVL2 - - Image + - Image\ :sup:`E+` - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc. - * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - - Image + - Image\ :sup:`E+` - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. - * - :code:`LlavaNextForConditionalGeneration` - LLaVA-NeXT - - Image + - Image\ :sup:`E+` - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - + * - :code:`MiniCPMV` + - MiniCPM-V + - Image\ :sup:`+` + - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. + - * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - - Image + - Image\ :sup:`E` - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc. - * - :code:`Phi3VForCausalLM` - Phi-3-Vision, Phi-3.5-Vision - - Image + - Image\ :sup:`E+` - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - - * - :code:`MiniCPMV` - - MiniCPM-V - - Image - - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. + * - :code:`QWenLMHeadModel` + - Qwen-VL + - Image\ :sup:`E` + - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - * - :code:`UltravoxModel` - Ultravox - - Audio + - Audio\ :sup:`E+` - :code:`fixie-ai/ultravox-v0_3` - +| :sup:`E` Pre-computed embeddings can be inputted for this modality. +| :sup:`+` Multiple items can be inputted per text prompt for this modality. + .. note:: For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 236e37b51d470..08db891665044 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -9,26 +9,23 @@ This document shows you how to run and serve these models using vLLM. .. important:: We are actively iterating on VLM support. Expect breaking changes to VLM usage and development in upcoming releases without prior deprecation. - Currently, the support for vision language models on vLLM has the following limitations: - - * Only single image input is supported per text prompt. - We are continuously improving user & developer experience for VLMs. Please `open an issue on GitHub `_ if you have any feedback or feature requests. -Offline Batched Inference -------------------------- +Offline Inference +----------------- + +Single-image input +^^^^^^^^^^^^^^^^^^ -To initialize a VLM, the aforementioned arguments must be passed to the ``LLM`` class for instantiating the engine. +The :class:`~vllm.LLM` class can be instantiated in much the same way as language-only models. .. code-block:: python llm = LLM(model="llava-hf/llava-1.5-7b-hf") -.. important:: +.. note:: We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow - the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that - internally for each model. - + the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: @@ -86,61 +83,117 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptI A code example can be found in `examples/offline_inference_vision_language.py `_. +Multi-image input +^^^^^^^^^^^^^^^^^ -Online OpenAI Vision API Compatible Inference ----------------------------------------------- +Multi-image input is only supported for a subset of VLMs, as shown :ref:`here `. -You can serve vision language models with vLLM's HTTP server that is compatible with `OpenAI Vision API `_. +To enable multiple multi-modal items per text prompt, you have to set ``limit_mm_per_prompt`` for the :class:`~vllm.LLM` class. -.. note:: - Currently, vLLM supports only **single** ``image_url`` input per ``messages``. Support for multi-image inputs will be - added in the future. +.. code-block:: python -Below is an example on how to launch the same ``llava-hf/llava-1.5-7b-hf`` with vLLM API server. + llm = LLM( + model="microsoft/Phi-3.5-vision-instruct", + trust_remote_code=True, # Required to load Phi-3.5-vision + max_model_len=4096, # Otherwise, it may not fit in smaller GPUs + limit_mm_per_prompt={"image": 2}, # The maximum number to accept + ) -.. important:: - Since OpenAI Vision API is based on `Chat `_ API, a chat template - is **required** to launch the API server if the model's tokenizer does not come with one. In this example, we use the - HuggingFace Llava chat template that you can find in the example folder `here `_. +Instead of passing in a single image, you can pass in a list of images. + +.. code-block:: python + + # Refer to the HuggingFace repo for the correct format to use + prompt = "<|user|>\n\n\nWhat is the content of each image?<|end|>\n<|assistant|>\n" + + # Load the images using PIL.Image + image1 = PIL.Image.open(...) + image2 = PIL.Image.open(...) + + outputs = llm.generate({ + "prompt": prompt, + "multi_modal_data": { + "image": [image1, image2] + }, + }) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + +A code example can be found in `examples/offline_inference_vision_language_multi_image.py `_. + +Online Inference +---------------- + +OpenAI Vision API +^^^^^^^^^^^^^^^^^ + +You can serve vision language models with vLLM's HTTP server that is compatible with `OpenAI Vision API `_. + +Below is an example on how to launch the same ``microsoft/Phi-3.5-vision-instruct`` with vLLM's OpenAI-compatible API server. .. code-block:: bash - vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja + vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \ + --trust-remote-code --limit-mm-per-prompt image=2 .. important:: - We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow - the above snippet. Specifically, ``image_feature_size`` is no longer required to be specified as we now calculate that - internally for each model. + Since OpenAI Vision API is based on `Chat Completions `_ API, + a chat template is **required** to launch the API server. + + Although Phi-3.5-Vision comes with a chat template, for other models you may have to provide one if the model's tokenizer does not come with it. + The chat template can be inferred based on the documentation on the model's HuggingFace repo. + For example, LLaVA-1.5 (``llava-hf/llava-1.5-7b-hf``) requires a chat template that can be found `here `_. To consume the server, you can use the OpenAI client like in the example below: .. code-block:: python from openai import OpenAI + openai_api_key = "EMPTY" openai_api_base = "http://localhost:8000/v1" + client = OpenAI( api_key=openai_api_key, base_url=openai_api_base, ) + + # Single-image input inference + image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + chat_response = client.chat.completions.create( - model="llava-hf/llava-1.5-7b-hf", + model="microsoft/Phi-3.5-vision-instruct", messages=[{ "role": "user", "content": [ # NOTE: The prompt formatting with the image token `` is not needed # since the prompt will be processed automatically by the API server. - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", - }, - }, + {"type": "text", "text": "Whatā€™s in this image?"}, + {"type": "image_url", "image_url": {"url": image_url}}, ], }], ) - print("Chat response:", chat_response) + print("Chat completion output:", chat_response.choices[0].message.content) + + # Multi-image input inference + image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" + image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" + + chat_response = client.chat.completions.create( + model="microsoft/Phi-3.5-vision-instruct", + messages=[{ + "role": "user", + "content": [ + {"type": "text", "text": "What are the animals in these images?"}, + {"type": "image_url", "image_url": {"url": image_url_duck}}, + {"type": "image_url", "image_url": {"url": image_url_lion}}, + ], + }], + ) + print("Chat completion output:", chat_response.choices[0].message.content) + A full code example can be found in `examples/openai_vision_api_client.py `_. diff --git a/docs/source/performance_benchmark/benchmarks.rst b/docs/source/performance_benchmark/benchmarks.rst index 9a23aab10d03d..e5c8d6a55de63 100644 --- a/docs/source/performance_benchmark/benchmarks.rst +++ b/docs/source/performance_benchmark/benchmarks.rst @@ -20,4 +20,4 @@ The performance benchmarks and nightly benchmarks can be triggered by submitting .. note:: - Please refer to `vLLM performance benchmark descriptions `_ and `vLLM nightly benchmark descriptions `_ for detailed descriptions on benchmark environment, workload and metrics. + Please refer to `vLLM performance benchmark descriptions `_ and `vLLM nightly benchmark descriptions `_ for detailed descriptions on benchmark environment, workload and metrics. diff --git a/docs/source/quantization/auto_awq.rst b/docs/source/quantization/auto_awq.rst index bbbb9aee78b3c..8eb6fa2f4cbe1 100644 --- a/docs/source/quantization/auto_awq.rst +++ b/docs/source/quantization/auto_awq.rst @@ -19,27 +19,31 @@ You can quantize your own models by installing AutoAWQ or picking one of the `40 $ pip install autoawq -After installing AutoAWQ, you are ready to quantize a model. Here is an example of how to quantize Vicuna 7B v1.5: +After installing AutoAWQ, you are ready to quantize a model. Here is an example of how to quantize `mistralai/Mistral-7B-Instruct-v0.2`: .. code-block:: python from awq import AutoAWQForCausalLM from transformers import AutoTokenizer - - model_path = 'lmsys/vicuna-7b-v1.5' - quant_path = 'vicuna-7b-v1.5-awq' + + model_path = 'mistralai/Mistral-7B-Instruct-v0.2' + quant_path = 'mistral-instruct-v0.2-awq' quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" } - + # Load model - model = AutoAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True}) + model = AutoAWQForCausalLM.from_pretrained( + model_path, **{"low_cpu_mem_usage": True, "use_cache": False} + ) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) - + # Quantize model.quantize(tokenizer, quant_config=quant_config) - + # Save quantized model model.save_quantized(quant_path) tokenizer.save_pretrained(quant_path) + + print(f'Model is quantized and saved at "{quant_path}"') To run an AWQ model with vLLM, you can use `TheBloke/Llama-2-7b-Chat-AWQ `_ with the following command: diff --git a/docs/source/quantization/supported_hardware.rst b/docs/source/quantization/supported_hardware.rst index 6341b583f0cfe..ea587e0525a74 100644 --- a/docs/source/quantization/supported_hardware.rst +++ b/docs/source/quantization/supported_hardware.rst @@ -119,17 +119,6 @@ The table below shows the compatibility of various quantization implementations - āœ— - āœ— - āœ— - * - SqueezeLLM - - āœ…ļøŽ - - āœ…ļøŽ - - āœ…ļøŽ - - āœ…ļøŽ - - āœ…ļøŽ - - āœ— - - āœ— - - āœ— - - āœ— - - āœ— Notes: ^^^^^^ diff --git a/docs/source/serving/faq.rst b/docs/source/serving/faq.rst index 7b0374be8adff..9e858e612c8bf 100644 --- a/docs/source/serving/faq.rst +++ b/docs/source/serving/faq.rst @@ -10,3 +10,22 @@ A: Assuming that you're referring to using OpenAI compatible server to serve mul Q: Which model to use for offline inference embedding? A: If you want to use an embedding model, try: https://huggingface.co/intfloat/e5-mistral-7b-instruct. Instead models, such as Llama-3-8b, Mistral-7B-Instruct-v0.3, are generation models rather than an embedding model + +---------------------------------------- + + Q: Can the output of a prompt vary across runs in vLLM? + +A: Yes, it can. vLLM does not guarantee stable log probabilities (logprobs) for the output tokens. Variations in logprobs may occur due to +numerical instability in Torch operations or non-deterministic behavior in batched Torch operations when batching changes. For more details, +see the `Numerical Accuracy section `_. + +In vLLM, the same requests might be batched differently due to factors such as other concurrent requests, +changes in batch size, or batch expansion in speculative decoding. These batching variations, combined with numerical instability of Torch operations, +can lead to slightly different logit/logprob values at each step. Such differences can accumulate, potentially resulting in +different tokens being sampled. Once a different token is sampled, further divergence is likely. + +**Mitigation Strategies** + +- For improved stability and reduced variance, use `float32`. Note that this will require more memory. +- If using `bfloat16`, switching to `float16` can also help. +- Using request seeds can aid in achieving more stable generation for temperature > 0, but discrepancies due to precision differences may still occur. diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index a06c30d9c48c6..eb4ea0fb5655e 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -110,14 +110,90 @@ directory [here](https://github.com/vllm-project/vllm/tree/main/examples/) :func: create_parser_for_docs :prog: vllm serve ``` +## Tool Calling in the Chat Completion API +### Named Function Calling +vLLM supports only named function calling in the chat completion API by default. It does so using Outlines, so this is +enabled by default, and will work with any supported model. You are guaranteed a validly-parsable function call - not a +high-quality one. + +To use a named function, you need to define the functions in the `tools` parameter of the chat completion request, and +specify the `name` of one of the tools in the `tool_choice` parameter of the chat completion request. + +### Config file + +The `serve` module can also accept arguments from a config file in +`yaml` format. The arguments in the yaml must be specified using the +long form of the argument outlined [here](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server): + +For example: + +```yaml +# config.yaml + +host: "127.0.0.1" +port: 6379 +uvicorn-log-level: "info" +``` + +```bash +$ vllm serve SOME_MODEL --config config.yaml +``` +--- +**NOTE** +In case an argument is supplied using command line and the config file, the value from the commandline will take precedence. +The order of priorities is `command line > config file values > defaults`. + +--- ## Tool calling in the chat completion API vLLM supports only named function calling in the chat completion API. The `tool_choice` options `auto` and `required` are **not yet supported** but on the roadmap. -To use a named function you need to define the function in the `tools` parameter and call it in the `tool_choice` parameter. - -It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. **This may change in the future.** +It is the callers responsibility to prompt the model with the tool information, vLLM will not automatically manipulate the prompt. vLLM will use guided decoding to ensure the response matches the tool parameter object defined by the JSON schema in the `tools` parameter. -Please refer to the OpenAI API reference documentation for more information. + +### Automatic Function Calling +To enable this feature, you should set the following flags: +* `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it +deems appropriate. +* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. Additional tool parsers +will continue to be added in the future. +* `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages +that contain previously generated tool calls. Hermes and Mistral models have tool-compatible chat templates in their +`tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat +template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates) +from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json) + +If your favorite tool-calling model is not supported, please feel free to contribute a parser & tool use chat template! + +#### Hermes Models +All Nous Research Hermes-series models newer than Hermes 2 Pro should be supported. +* `NousResearch/Hermes-2-Pro-*` +* `NousResearch/Hermes-2-Theta-*` +* `NousResearch/Hermes-3-*` + + +_Note that the Hermes 2 **Theta** models are known to have degraded tool call quality & capabilities due to the merge +step in their creation_. + +Flags: `--tool-call-parser hermes` + +#### Mistral Models +Supported models: +* `mistralai/Mistral-7B-Instruct-v0.3` (confirmed) +* Additional mistral function-calling models are compatible as well. + +Known issues: +1. Mistral 7B struggles to generate parallel tool calls correctly. +2. Mistral's `tokenizer_config.json` chat template requires tool call IDs that are exactly 9 digits, which is +much shorter than what vLLM generates. Since an exception is thrown when this condition +is not met, the following additional chat templates are provided: + +* `examples/tool_chat_template_mistral.jinja` - this is the "official" Mistral chat template, but tweaked so that +it works with vLLM's tool call IDs (provided `tool_call_id` fields are truncated to the last 9 digits) +* `examples/tool_chat_template_mistral_parallel.jinja` - this is a "better" version that adds a tool-use system prompt +when tools are provided, that results in much better reliability when working with parallel tool calling. + + +Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` diff --git a/examples/fp8/README.md b/examples/fp8/README.md index 84ad76c71862e..181c36558fcff 100644 --- a/examples/fp8/README.md +++ b/examples/fp8/README.md @@ -62,7 +62,7 @@ This script evaluates the inference throughput of language models using various python3 benchmarks/benchmark_throughput.py --help usage: benchmark_throughput.py [-h] [--backend {vllm,hf,mii}] [--dataset DATASET] [--input-len INPUT_LEN] [--output-len OUTPUT_LEN] [--model MODEL] - [--tokenizer TOKENIZER] [--quantization {awq,gptq,squeezellm,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N] + [--tokenizer TOKENIZER] [--quantization {awq,gptq,None}] [--tensor-parallel-size TENSOR_PARALLEL_SIZE] [--n N] [--use-beam-search] [--num-prompts NUM_PROMPTS] [--seed SEED] [--hf-max-batch-size HF_MAX_BATCH_SIZE] [--trust-remote-code] [--max-model-len MAX_MODEL_LEN] [--dtype {auto,half,float16,bfloat16,float,float32}] [--enforce-eager] [--kv-cache-dtype {auto,fp8}] [--quantization-param-path KV_CACHE_quantization_param_path] @@ -76,7 +76,7 @@ optional arguments: --output-len OUTPUT_LEN Output length for each request. Overrides the output length from the dataset. --model MODEL --tokenizer TOKENIZER - --quantization {awq,gptq,squeezellm,None}, -q {awq,gptq,squeezellm,None} + --quantization {awq,gptq,None}, -q {awq,gptq,None} --tensor-parallel-size TENSOR_PARALLEL_SIZE, -tp TENSOR_PARALLEL_SIZE --n N Number of generated sequences per prompt. --use-beam-search diff --git a/examples/offline_inference_audio_language.py b/examples/offline_inference_audio_language.py index 56ce8646c20c9..1c6ac06123bbb 100644 --- a/examples/offline_inference_audio_language.py +++ b/examples/offline_inference_audio_language.py @@ -11,25 +11,33 @@ from vllm.assets.audio import AudioAsset from vllm.utils import FlexibleArgumentParser -# Input audio and question -audio_and_sample_rate = AudioAsset("mary_had_lamb").audio_and_sample_rate -question = "What is recited in the audio?" +audio_assets = [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] +question_per_audio_count = [ + "What is recited in the audio?", + "What sport and what nursery rhyme are referenced?" +] # Ultravox 0.3 -def run_ultravox(question): +def run_ultravox(question, audio_count): model_name = "fixie-ai/ultravox-v0_3" tokenizer = AutoTokenizer.from_pretrained(model_name) messages = [{ - 'role': 'user', - 'content': f"<|reserved_special_token_0|>\n{question}" + 'role': + 'user', + 'content': + "<|reserved_special_token_0|>\n" * audio_count + question }] prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - llm = LLM(model=model_name) + llm = LLM(model=model_name, + enforce_eager=True, + enable_chunked_prefill=False, + max_model_len=8192, + limit_mm_per_prompt={"audio": audio_count}) stop_token_ids = None return llm, prompt, stop_token_ids @@ -44,7 +52,9 @@ def main(args): if model not in model_example_map: raise ValueError(f"Model type {model} is not supported.") - llm, prompt, stop_token_ids = model_example_map[model](question) + audio_count = args.num_audios + llm, prompt, stop_token_ids = model_example_map[model]( + question_per_audio_count[audio_count - 1], audio_count) # We set temperature to 0.2 so that outputs can be different # even when all prompts are identical when running batch inference. @@ -53,23 +63,18 @@ def main(args): stop_token_ids=stop_token_ids) assert args.num_prompts > 0 - if args.num_prompts == 1: - # Single inference - inputs = { - "prompt": prompt, - "multi_modal_data": { - "audio": audio_and_sample_rate - }, - } - - else: + inputs = { + "prompt": prompt, + "multi_modal_data": { + "audio": [ + asset.audio_and_sample_rate + for asset in audio_assets[:audio_count] + ] + }, + } + if args.num_prompts > 1: # Batch inference - inputs = [{ - "prompt": prompt, - "multi_modal_data": { - "audio": audio_and_sample_rate - }, - } for _ in range(args.num_prompts)] + inputs = [inputs] * args.num_prompts outputs = llm.generate(inputs, sampling_params=sampling_params) @@ -92,6 +97,11 @@ def main(args): type=int, default=1, help='Number of prompts to run.') + parser.add_argument("--num-audios", + type=int, + default=1, + choices=[1, 2], + help="Number of audio items per prompt.") args = parser.parse_args() main(args) diff --git a/examples/offline_inference_neuron.py b/examples/offline_inference_neuron.py index 5ecbbf020ab8b..2856be7c864ea 100644 --- a/examples/offline_inference_neuron.py +++ b/examples/offline_inference_neuron.py @@ -1,5 +1,12 @@ +import os + from vllm import LLM, SamplingParams +# creates XLA hlo graphs for all the context length buckets. +os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" +# creates XLA hlo graphs for all the token gen buckets. +os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" + # Sample prompts. prompts = [ "Hello, my name is", @@ -19,8 +26,8 @@ # Currently, this is a known limitation in continuous batching support # in transformers-neuronx. # TODO(liangfu): Support paged-attention in transformers-neuronx. - max_model_len=128, - block_size=128, + max_model_len=2048, + block_size=2048, # The device can be automatically detected when AWS Neuron SDK is installed. # The device argument can be either unspecified for automated detection, # or explicitly assigned. diff --git a/examples/offline_inference_neuron_int8_quantization.py b/examples/offline_inference_neuron_int8_quantization.py new file mode 100644 index 0000000000000..8ec17e3400953 --- /dev/null +++ b/examples/offline_inference_neuron_int8_quantization.py @@ -0,0 +1,50 @@ +import os + +from vllm import LLM, SamplingParams + +# creates XLA hlo graphs for all the context length buckets. +os.environ['NEURON_CONTEXT_LENGTH_BUCKETS'] = "128,512,1024,2048" +# creates XLA hlo graphs for all the token gen buckets. +os.environ['NEURON_TOKEN_GEN_BUCKETS'] = "128,512,1024,2048" +# Quantizes neuron model weight to int8 , +# The default config for quantization is int8 dtype. +os.environ['NEURON_QUANT_DTYPE'] = "s8" + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM( + model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", + max_num_seqs=8, + # The max_model_len and block_size arguments are required to be same as + # max sequence length when targeting neuron device. + # Currently, this is a known limitation in continuous batching support + # in transformers-neuronx. + # TODO(liangfu): Support paged-attention in transformers-neuronx. + max_model_len=2048, + block_size=2048, + # The device can be automatically detected when AWS Neuron SDK is installed. + # The device argument can be either unspecified for automated detection, + # or explicitly assigned. + device="neuron", + quantization="neuron_quant", + override_neuron_config={ + "cast_logits_dtype": "bfloat16", + }, + tensor_parallel_size=2) +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 9a0e9d4bc5362..aa1580343aee7 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -159,6 +159,20 @@ def run_blip2(question): return llm, prompt, stop_token_ids +# Qwen +def run_qwen_vl(question): + + llm = LLM( + model="Qwen/Qwen-VL", + trust_remote_code=True, + max_num_seqs=5, + ) + + prompt = f"{question}Picture 1: \n" + stop_token_ids = None + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -169,6 +183,7 @@ def run_blip2(question): "minicpmv": run_minicpmv, "blip-2": run_blip2, "internvl_chat": run_internvl, + "qwen_vl": run_qwen_vl, } diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py new file mode 100644 index 0000000000000..dd84627b9dc58 --- /dev/null +++ b/examples/offline_inference_vision_language_multi_image.py @@ -0,0 +1,153 @@ +""" +This example shows how to use vLLM for running offline inference with +multi-image input on vision language models, using the chat template defined +by the model. +""" +from argparse import Namespace +from typing import List + +from transformers import AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.multimodal.utils import fetch_image +from vllm.utils import FlexibleArgumentParser + +QUESTION = "What is the content of each image?" +IMAGE_URLS = [ + "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg", + "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg", +] + + +def load_phi3v(question, image_urls: List[str]): + llm = LLM( + model="microsoft/Phi-3.5-vision-instruct", + trust_remote_code=True, + max_model_len=4096, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + placeholders = "\n".join(f"<|image_{i}|>" + for i, _ in enumerate(image_urls, start=1)) + prompt = f"<|user|>\n{placeholders}\n{question}<|end|>\n<|assistant|>\n" + stop_token_ids = None + return llm, prompt, stop_token_ids + + +def load_internvl(question, image_urls: List[str]): + model_name = "OpenGVLab/InternVL2-2B" + + llm = LLM( + model=model_name, + trust_remote_code=True, + max_num_seqs=5, + max_model_len=4096, + limit_mm_per_prompt={"image": len(image_urls)}, + ) + + placeholders = "\n".join(f"Image-{i}: \n" + for i, _ in enumerate(image_urls, start=1)) + messages = [{'role': 'user', 'content': f"{placeholders}\n{question}"}] + + tokenizer = AutoTokenizer.from_pretrained(model_name, + trust_remote_code=True) + prompt = tokenizer.apply_chat_template(messages, + tokenize=False, + add_generation_prompt=True) + + # Stop tokens for InternVL + # models variants may have different stop tokens + # please refer to the model card for the correct "stop words": + # https://huggingface.co/OpenGVLab/InternVL2-2B#service + stop_tokens = ["<|endoftext|>", "<|im_start|>", "<|im_end|>", "<|end|>"] + stop_token_ids = [tokenizer.convert_tokens_to_ids(i) for i in stop_tokens] + return llm, prompt, stop_token_ids + + +model_example_map = { + "phi3_v": load_phi3v, + "internvl_chat": load_internvl, +} + + +def run_generate(model, question: str, image_urls: List[str]): + llm, prompt, stop_token_ids = model_example_map[model](question, + image_urls) + + sampling_params = SamplingParams(temperature=0.0, + max_tokens=128, + stop_token_ids=stop_token_ids) + + outputs = llm.generate( + { + "prompt": prompt, + "multi_modal_data": { + "image": [fetch_image(url) for url in image_urls] + }, + }, + sampling_params=sampling_params) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + + +def run_chat(model: str, question: str, image_urls: List[str]): + llm, _, stop_token_ids = model_example_map[model](question, image_urls) + + sampling_params = SamplingParams(temperature=0.0, + max_tokens=128, + stop_token_ids=stop_token_ids) + + outputs = llm.chat([{ + "role": + "user", + "content": [ + { + "type": "text", + "text": question, + }, + *({ + "type": "image_url", + "image_url": { + "url": image_url + }, + } for image_url in image_urls), + ], + }], + sampling_params=sampling_params) + + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + + +def main(args: Namespace): + model = args.model_type + method = args.method + + if method == "generate": + run_generate(model, QUESTION, IMAGE_URLS) + elif method == "chat": + run_chat(model, QUESTION, IMAGE_URLS) + else: + raise ValueError(f"Invalid method: {method}") + + +if __name__ == "__main__": + parser = FlexibleArgumentParser( + description='Demo on using vLLM for offline inference with ' + 'vision language models that support multi-image input') + parser.add_argument('--model-type', + '-m', + type=str, + default="phi3_v", + choices=model_example_map.keys(), + help='Huggingface "model_type".') + parser.add_argument("--method", + type=str, + default="generate", + choices=["generate", "chat"], + help="The method to run in `vllm.LLM`.") + + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference_with_profiler.py b/examples/offline_inference_with_profiler.py new file mode 100644 index 0000000000000..906c9502800d8 --- /dev/null +++ b/examples/offline_inference_with_profiler.py @@ -0,0 +1,33 @@ +import os + +from vllm import LLM, SamplingParams + +# enable torch profiler, can also be set on cmd line +os.environ["VLLM_TORCH_PROFILER_DIR"] = "./vllm_profile" + +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="facebook/opt-125m") + +llm.start_profile() + +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) + +llm.stop_profile() + +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/openai_chat_completion_client_with_tools.py b/examples/openai_chat_completion_client_with_tools.py new file mode 100644 index 0000000000000..2bbe42b6bd2ef --- /dev/null +++ b/examples/openai_chat_completion_client_with_tools.py @@ -0,0 +1,162 @@ +""" +Set up this example by starting a vLLM OpenAI-compatible server with tool call +options enabled. For example: + +IMPORTANT: for mistral, you must use one of the provided mistral tool call +templates, or your own - the model default doesn't work for tool calls with vLLM +See the vLLM docs on OpenAI server & tool calling for more details. + +vllm serve --model mistralai/Mistral-7B-Instruct-v0.3 \ + --chat-template examples/tool_chat_template_mistral.jinja \ + --enable-auto-tool-choice --tool-call-parser mistral + +OR +vllm serve --model NousResearch/Hermes-2-Pro-Llama-3-8B \ + --chat-template examples/tool_chat_template_hermes.jinja \ + --enable-auto-tool-choice --tool-call-parser hermes +""" +import json + +from openai import OpenAI + +# Modify OpenAI's API key and API base to use vLLM's API server. +openai_api_key = "EMPTY" +openai_api_base = "http://localhost:8000/v1" + +client = OpenAI( + # defaults to os.environ.get("OPENAI_API_KEY") + api_key=openai_api_key, + base_url=openai_api_base, +) + +models = client.models.list() +model = models.data[0].id + +tools = [{ + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, e.g. 'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["city", "state", "unit"] + } + } +}] + +messages = [{ + "role": "user", + "content": "Hi! How are you doing today?" +}, { + "role": "assistant", + "content": "I'm doing well! How can I help you?" +}, { + "role": + "user", + "content": + "Can you tell me what the temperate will be in Dallas, in fahrenheit?" +}] + +chat_completion = client.chat.completions.create(messages=messages, + model=model, + tools=tools) + +print("Chat completion results:") +print(chat_completion) +print("\n\n") + +tool_calls_stream = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + stream=True) + +chunks = [] +for chunk in tool_calls_stream: + chunks.append(chunk) + if chunk.choices[0].delta.tool_calls: + print(chunk.choices[0].delta.tool_calls[0]) + else: + print(chunk.choices[0].delta) + +arguments = [] +tool_call_idx = -1 +for chunk in chunks: + + if chunk.choices[0].delta.tool_calls: + tool_call = chunk.choices[0].delta.tool_calls[0] + + if tool_call.index != tool_call_idx: + if tool_call_idx >= 0: + print( + f"streamed tool call arguments: {arguments[tool_call_idx]}" + ) + tool_call_idx = chunk.choices[0].delta.tool_calls[0].index + arguments.append("") + if tool_call.id: + print(f"streamed tool call id: {tool_call.id} ") + + if tool_call.function: + if tool_call.function.name: + print(f"streamed tool call name: {tool_call.function.name}") + + if tool_call.function.arguments: + arguments[tool_call_idx] += tool_call.function.arguments + +if len(arguments): + print(f"streamed tool call arguments: {arguments[-1]}") + +print("\n\n") + +messages.append({ + "role": "assistant", + "tool_calls": chat_completion.choices[0].message.tool_calls +}) + + +# Now, simulate a tool call +def get_current_weather(city: str, state: str, unit: 'str'): + return ("The weather in Dallas, Texas is 85 degrees fahrenheit. It is " + "partly cloudly, with highs in the 90's.") + + +available_tools = {"get_current_weather": get_current_weather} + +completion_tool_calls = chat_completion.choices[0].message.tool_calls +for call in completion_tool_calls: + tool_to_call = available_tools[call.function.name] + args = json.loads(call.function.arguments) + result = tool_to_call(**args) + print(result) + messages.append({ + "role": "tool", + "content": result, + "tool_call_id": call.id, + "name": call.function.name + }) + +chat_completion_2 = client.chat.completions.create(messages=messages, + model=model, + tools=tools, + stream=False) +print("\n\n") +print(chat_completion_2) diff --git a/examples/openai_vision_api_client.py b/examples/openai_vision_api_client.py index be90394511f89..1ba702ef019e4 100644 --- a/examples/openai_vision_api_client.py +++ b/examples/openai_vision_api_client.py @@ -1,7 +1,13 @@ """An example showing how to use vLLM to serve VLMs. Launch the vLLM server with the following command: + +(single image inference with Llava) vllm serve llava-hf/llava-1.5-7b-hf --chat-template template_llava.jinja + +(multi-image inference with Phi-3.5-vision-instruct) +vllm serve microsoft/Phi-3.5-vision-instruct --max-model-len 4096 \ + --trust-remote-code --limit-mm-per-prompt image=2 """ import base64 @@ -21,9 +27,10 @@ models = client.models.list() model = models.data[0].id +# Single-image input inference image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" -# Use image url in the payload +## Use image url in the payload chat_completion_from_url = client.chat.completions.create( messages=[{ "role": @@ -46,10 +53,10 @@ ) result = chat_completion_from_url.choices[0].message.content -print(f"Chat completion output:{result}") +print("Chat completion output:", result) -# Use base64 encoded image in the payload +## Use base64 encoded image in the payload def encode_image_base64_from_url(image_url: str) -> str: """Encode an image retrieved from a remote url to base64 format.""" @@ -84,3 +91,36 @@ def encode_image_base64_from_url(image_url: str) -> str: result = chat_completion_from_base64.choices[0].message.content print(f"Chat completion output:{result}") + +# Multi-image input inference +image_url_duck = "https://upload.wikimedia.org/wikipedia/commons/d/da/2015_Kaczka_krzy%C5%BCowka_w_wodzie_%28samiec%29.jpg" +image_url_lion = "https://upload.wikimedia.org/wikipedia/commons/7/77/002_The_lion_king_Snyggve_in_the_Serengeti_National_Park_Photo_by_Giles_Laurent.jpg" +chat_completion_from_url = client.chat.completions.create( + messages=[{ + "role": + "user", + "content": [ + { + "type": "text", + "text": "What are the animals in these images?" + }, + { + "type": "image_url", + "image_url": { + "url": image_url_duck + }, + }, + { + "type": "image_url", + "image_url": { + "url": image_url_lion + }, + }, + ], + }], + model=model, + max_tokens=64, +) + +result = chat_completion_from_url.choices[0].message.content +print("Chat completion output:", result) diff --git a/examples/tool_chat_template_hermes.jinja b/examples/tool_chat_template_hermes.jinja new file mode 100644 index 0000000000000..0b0902c8e7497 --- /dev/null +++ b/examples/tool_chat_template_hermes.jinja @@ -0,0 +1,130 @@ +{%- macro json_to_python_type(json_spec) %} + {%- set basic_type_map = { + "string": "str", + "number": "float", + "integer": "int", + "boolean": "bool" +} %} + + {%- if basic_type_map[json_spec.type] is defined %} + {{- basic_type_map[json_spec.type] }} + {%- elif json_spec.type == "array" %} + {{- "list[" + json_to_python_type(json_spec|items) + "]" }} + {%- elif json_spec.type == "object" %} + {%- if json_spec.additionalProperties is defined %} + {{- "dict[str, " + json_to_python_type(json_spec.additionalProperties) + ']' }} + {%- else %} + {{- "dict" }} + {%- endif %} + {%- elif json_spec.type is iterable %} + {{- "Union[" }} + {%- for t in json_spec.type %} + {{- json_to_python_type({"type": t}) }} + {%- if not loop.last %} + {{- "," }} + {%- endif %} + {%- endfor %} + {{- "]" }} + {%- else %} + {{- "Any" }} + {%- endif %} +{%- endmacro %} + + +{{- bos_token }} +{{- "<|im_start|>system\nYou are a function calling AI model. You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. Here are the available tools: " }} +{%- if tools is iterable and tools | length > 0 %} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- '{"type": "function", "function": ' }} + {{- '{"name": "' + tool.name + '", ' }} + {{- '"description": "' + tool.name + '(' }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- param_name + ": " + json_to_python_type(param_fields) }} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- ")" }} + {%- if tool.return is defined %} + {{- " -> " + json_to_python_type(tool.return) }} + {%- endif %} + {{- " - " + tool.description + "\n\n" }} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {%- if loop.first %} + {{- " Args:\n" }} + {%- endif %} + {{- " " + param_name + "(" + json_to_python_type(param_fields) + "): " + param_fields.description|trim }} + {%- endfor %} + {%- if tool.return is defined and tool.return.description is defined %} + {{- "\n Returns:\n " + tool.return.description }} + {%- endif %} + {{- '"' }} + {{- ', "parameters": ' }} + {%- if tool.parameters.properties | length == 0 %} + {{- "{}" }} + {%- else %} + {{- tool.parameters|tojson }} + {%- endif %} + {{- "}" }} + {%- if not loop.last %} + {{- "\n" }} + {%- endif %} + {%- endfor %} +{%- endif %} +{{- " " }} +{{- 'Use the following pydantic model json schema for each tool call you will make: {"properties": {"name": {"title": "Name", "type": "string"}, "arguments": {"title": "Arguments", "type": "object"}}, "required": ["name", "arguments"], "title": "FunctionCall", "type": "object"}} +' }} +{{- "For each function call return a json object with function name and arguments within XML tags as follows: +" }} +{{- " +" }} +{{- '{"name": , "arguments": } +' }} +{{- '<|im_end|>' }} +{%- for message in messages %} + {%- if message.role == "user" or message.role == "system" or (message.role == "assistant" and message.tool_calls is not defined) %} + {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" and message.tool_calls is defined %} + {{- '<|im_start|>' + message.role }} + {%- for tool_call in message.tool_calls %} + {{- '\n\n' }} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '{' }} + {{- '"name": "' }} + {{- tool_call.name }} + {{- '"' }} + {%- if tool_call.arguments is defined %} + {{- ', ' }} + {{- '"arguments": ' }} + {{- tool_call.arguments|tojson }} + {%- endif %} + {{- '}' }} + {{- '\n' }} + {%- endfor %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.previtem and loop.previtem.role != "tool" %} + {{- '<|im_start|>tool\n' }} + {%- endif %} + {{- '\n' }} + {{- message.content }} + {%- if not loop.last %} + {{- '\n\n' }} + {%- else %} + {{- '\n' }} + {%- endif %} + {%- if not loop.last and loop.nextitem.role != "tool" %} + {{- '<|im_end|>' }} + {%- elif loop.last %} + {{- '<|im_end|>' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} +{%- endif %} diff --git a/examples/tool_chat_template_mistral.jinja b/examples/tool_chat_template_mistral.jinja new file mode 100644 index 0000000000000..49691f59c2f2c --- /dev/null +++ b/examples/tool_chat_template_mistral.jinja @@ -0,0 +1,86 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST] " + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} + {%- endif %} + {{- ', "id": "' + tool_call.id[-9:] + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- " " + message["content"] + eos_token }} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/examples/tool_chat_template_mistral_parallel.jinja b/examples/tool_chat_template_mistral_parallel.jinja new file mode 100644 index 0000000000000..a294cbfd026be --- /dev/null +++ b/examples/tool_chat_template_mistral_parallel.jinja @@ -0,0 +1,94 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{%- if tools is defined %} + {%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %} + {%- if system_message is defined %} + {%- set system_message = parallel_tool_prompt + "\n\n" + system_message %} + {%- else %} + {%- set system_message = parallel_tool_prompt %} + {%- endif %} +{%- endif %} +{%- set user_messages = loop_messages | selectattr("role", "equalto", "user") | list %} + +{%- for message in loop_messages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("After the optional system message, conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {%- if loop.last and system_message is defined %} + {{- "[INST] " + system_message + "\n\n" + message["content"] + "[/INST]" }} + {%- else %} + {{- "[INST] " + message["content"] + "[/INST]" }} + {%- endif %} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out[:-1] }} + {%- if not tool_call.id is defined or tool_call.id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (1)" + tool_call.id) }} + {%- endif %} + {{- ', "id": "' + tool_call.id[-9:] + '"}' }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" + eos_token }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- " " + message["content"] + eos_token }} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + ", " }} + {%- if not message.tool_call_id is defined or message.tool_call_id|length < 9 %} + {{- raise_exception("Tool call IDs should be alphanumeric strings with length >= 9! (2)" + message.tool_call_id) }} + {%- endif %} + {{- '"call_id": "' + message.tool_call_id[-9:] + '"}[/TOOL_RESULTS]' }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} diff --git a/requirements-adag.txt b/requirements-adag.txt deleted file mode 100644 index e77f90fb8f85d..0000000000000 --- a/requirements-adag.txt +++ /dev/null @@ -1,3 +0,0 @@ -# Dependencies for Ray accelerated DAG -cupy-cuda12x -ray >= 2.32 \ No newline at end of file diff --git a/requirements-common.txt b/requirements-common.txt index 61daf99819756..49a290317f818 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -9,7 +9,7 @@ tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. fastapi aiohttp -openai >= 1.0 # Ensure modern openai package (ensure types module present) +openai >= 1.40.0 # Ensure modern openai package (ensure types module present) uvicorn[standard] pydantic >= 2.8 # Required for OpenAI server. pillow # Required for image processing @@ -20,10 +20,11 @@ lm-format-enforcer == 0.10.6 outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 typing_extensions >= 4.10 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 +partial-json-parser # used for parsing partial JSON outputs pyzmq msgspec -librosa # Required for audio processing -soundfile # Required for audio processing gguf == 0.9.1 importlib_metadata mistral_common >= 1.3.4 +pyyaml +six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 diff --git a/requirements-mamba.txt b/requirements-mamba.txt deleted file mode 100644 index 1838e87d063da..0000000000000 --- a/requirements-mamba.txt +++ /dev/null @@ -1,3 +0,0 @@ -# Mamba dependencies -mamba-ssm>=1.2.2 -causal-conv1d>=1.2.0 diff --git a/requirements-test.txt b/requirements-test.txt index cdbc3e50cc9ec..44ba99fe84bd4 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,6 +1,3 @@ -# Needed for Ray accelerated DAG tests --r requirements-adag.txt - # testing pytest tensorizer>=2.9.0 @@ -11,12 +8,14 @@ pytest-shard # testing utils awscli -einops # required for MPT and qwen-vl +einops # required for MPT, qwen-vl and Mamba httpx +librosa # required for audio test peft requests -ray +ray[adag]>=2.35 sentence-transformers # required for embedding +soundfile # required for audio test compressed-tensors==0.4.0 # required for compressed-tensors timm # required for internvl test transformers_stream_generator # required for qwen-vl test @@ -30,4 +29,4 @@ aiohttp # quantization bitsandbytes==0.42.0 -buildkite-test-collector==0.1.8 \ No newline at end of file +buildkite-test-collector==0.1.8 diff --git a/requirements-tpu.txt b/requirements-tpu.txt index 5eb27b39eb623..4c606cf0a9105 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -4,4 +4,4 @@ # Dependencies for TPU # Currently, the TPU backend uses a nightly version of PyTorch XLA. # You can install the dependencies in Dockerfile.tpu. -ray +ray[default] diff --git a/setup.py b/setup.py index 21b0422c0f0bd..1e08a5bd70cd3 100644 --- a/setup.py +++ b/setup.py @@ -362,7 +362,8 @@ def get_vllm_version() -> str: version = find_version(get_path("vllm", "version.py")) if _no_device(): - version += "+empty" + if envs.VLLM_TARGET_DEVICE == "empty": + version += "+empty" elif _is_cuda(): cuda_version = str(get_nvcc_cuda_version()) if cuda_version != MAIN_CUDA_VERSION: @@ -501,6 +502,7 @@ def _read_requirements(filename: str) -> List[str]: ext_modules=ext_modules, extras_require={ "tensorizer": ["tensorizer>=2.9.0"], + "audio": ["librosa", "soundfile"] # Required for audio processing }, cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {}, package_data=package_data, diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 1211e6ba5aafc..9c34b2a13fd53 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -6,6 +6,7 @@ Run `pytest tests/models/test_chunked_prefill.py`. """ +from contextlib import nullcontext import pytest @@ -15,18 +16,6 @@ "facebook/opt-125m", "meta-llama/Llama-2-7b-hf", ] -E5M2_KV_MODELS = [ - "facebook/opt-125m", - "meta-llama/Llama-2-7b-chat-hf", -] -E4M3_KV_MODELS = [ - "meta-llama/Llama-2-7b-chat-hf", "nm-testing/Qwen2-1.5B-Instruct-FP8-K-V", - "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme" -] -KV_CACHE_QUANTIZATION_PATHS = { - "meta-llama/Llama-2-7b-chat-hf": - "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json" -} @pytest.mark.parametrize("model", MODELS) @@ -77,10 +66,10 @@ def test_models( ) -@pytest.mark.parametrize("kv_cache_dtype,model", - [("fp8_e5m2", m) - for m in E5M2_KV_MODELS] + [("fp8_e4m3", m) - for m in E4M3_KV_MODELS]) +@pytest.mark.parametrize( + "kv_cache_dtype,model", + [("fp8_e4m3", + "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")]) # Due to low-precision numerical divergence, we only test logprob of 4 tokens @pytest.mark.parametrize("max_tokens", [4]) @pytest.mark.parametrize("chunked_prefill_token_size", [4, 16]) @@ -103,27 +92,15 @@ def test_models_with_fp8_kv_cache( disable_async_output_proc: bool, ) -> None: """ - Only checks log probs match between chunked-prefill and - non-chunked-prefill version of vLLM model runner. - - This test is used when there is discrepancy in kernels - / numerics (e.g. when using lower-precision types like FP8). + Check output logprobs match between no_chunked_prefill and chunked_prefill + with fp8 kv cache. General fp8 kv-cache tests are covered in test_fp8.py, + so here we only check chunked prefill. """ NUM_LOG_PROBS = 8 - if model == "facebook/opt-125m": - pytest.skip( - "#7378: CUDA illegal memory access (undiagnosed) facebook/opt-125m" - ) - max_num_seqs = chunked_prefill_token_size max_num_batched_tokens = chunked_prefill_token_size - extra_kwargs = {} - if model in KV_CACHE_QUANTIZATION_PATHS: - extra_kwargs["quantization_param_path"] = KV_CACHE_QUANTIZATION_PATHS[ - model] - with vllm_runner( model, tensor_parallel_size=tensor_parallel_size, @@ -131,7 +108,6 @@ def test_models_with_fp8_kv_cache( max_num_seqs=max_num_seqs, kv_cache_dtype=kv_cache_dtype, disable_async_output_proc=disable_async_output_proc, - **extra_kwargs, ) as vllm_model: no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) @@ -145,7 +121,6 @@ def test_models_with_fp8_kv_cache( max_num_seqs=max_num_seqs, kv_cache_dtype=kv_cache_dtype, disable_async_output_proc=disable_async_output_proc, - **extra_kwargs, ) as vllm_model: chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, NUM_LOG_PROBS) @@ -156,3 +131,68 @@ def test_models_with_fp8_kv_cache( name_0="no_chunked_prefill", name_1="chunked_prefill", ) + + +@pytest.mark.parametrize("max_tokens", [16]) +@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("chunk_size", [30, 32]) +@pytest.mark.parametrize("use_v2_block_manager", [False, True]) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) +def test_with_prefix_caching( + vllm_runner, + max_tokens: int, + enforce_eager: bool, + chunk_size: int, + use_v2_block_manager: bool, + tensor_parallel_size: int, +) -> None: + """ + Checks exact match decode with and without prefix caching + with chunked prefill enabled. + """ + model = "meta-llama/Llama-2-7b-chat-hf" + # The common prompt has 142 tokens with Llama-2 tokenizer. + common_prompt = "You are a helpful AI assistant " * 20 + unique_prompts = [ + "Question", # Warmup + "Question", # Fully cached + "Another question", # Partial cached + ] + full_prompts = [f"{common_prompt}\n{p}" for p in unique_prompts] + + max_num_batched_tokens = max_num_seqs = chunk_size + outputs = {} # type: ignore + check_result = True + for enable in (True, False): + with vllm_runner( + model, + dtype="half", + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=True, + enable_prefix_caching=enable, + tensor_parallel_size=tensor_parallel_size, + use_v2_block_manager=use_v2_block_manager, + enforce_eager=enforce_eager, + max_num_seqs=max_num_seqs, + ) as vllm_model: + # It should fail when prefix caching is enable and chunk + # size is not a multiple of block size (16). + should_fail = chunk_size % 16 != 0 and enable + check_result &= not should_fail + outputs[enable] = [] + # Send the request one-by-one to ensure the cache is populated. + with pytest.raises(ValueError) if should_fail else nullcontext(): + for prompt in full_prompts: + outputs[enable] += vllm_model.generate_greedy([prompt], + max_tokens) + + # Check results only if we did not expect a failure. + if check_result: + check_outputs_equal( + outputs_0_lst=outputs[False], + outputs_1_lst=outputs[True], + name_0="w/o prefix caching", + name_1="with prefix caching", + ) diff --git a/tests/compile/test_wrapper.py b/tests/compile/test_wrapper.py new file mode 100644 index 0000000000000..3668c1fab6b89 --- /dev/null +++ b/tests/compile/test_wrapper.py @@ -0,0 +1,59 @@ +from typing import Optional + +import torch + +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher + + +class MyMod(torch.nn.Module): + + def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + if cache is not None: + return x + cache + return x * 2 + + +class MyWrapper(TorchCompileWrapperWithCustomDispatcher): + + def __init__(self, model): + self.model = model + compiled_callable = torch.compile(self.forward, backend="eager") + super().__init__(compiled_callable) + + def forward(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + # this is the function to be compiled + return self.model(x, cache) + + def __call__(self, x: torch.Tensor, cache: Optional[torch.Tensor] = None): + # let torch.compile compile twice + if len(self.compiled_codes) == 2: + dispatch_id = 0 if cache is None else 1 + with self.dispatch_to_code(dispatch_id): + return self.forward(x, cache) + else: + return self.compiled_callable(x, cache) + + +def test_torch_compile_wrapper(): + mod = MyMod() + wrappers = [] + for i in range(3): + torch._dynamo.reset() + wrapper = MyWrapper(mod) + wrappers.append(wrapper) + x = torch.tensor([1]) + wrapper(x, None) # profile run, compile + # create a cache tensor + cache = torch.tensor([2]) + wrapper(x, cache) # warm up with cache, recompile + + # for new input, dispatch to the compiled code directly + new_x = torch.tensor([3]) + assert wrapper(new_x, + None).item() == 6 # dispatch to the first compiled code + assert wrapper( + new_x, cache).item() == 5 # dispatch to the second compiled code + + for wrapper in wrappers: + # make sure they have independent compiled codes + assert len(wrapper.compiled_codes) == 2 diff --git a/tests/conftest.py b/tests/conftest.py index d8264f65b6149..cd0091b7cba68 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -209,8 +209,14 @@ class HfRunner: def wrap_device(self, input: _T) -> _T: if not is_cpu(): + # Check if the input is already on the GPU + if hasattr(input, 'device') and input.device.type == "cuda": + return input # Already on GPU, no need to move return input.to("cuda") else: + # Check if the input is already on the CPU + if hasattr(input, 'device') and input.device.type == "cpu": + return input # Already on CPU, no need to move return input.to("cpu") def __init__( @@ -272,7 +278,7 @@ def __init__( def generate( self, prompts: List[str], - images: Optional[List[Image.Image]] = None, + images: Optional[PromptImageInput] = None, **kwargs: Any, ) -> List[Tuple[List[List[int]], List[str]]]: if images: @@ -308,7 +314,7 @@ def generate_greedy( self, prompts: List[str], max_tokens: int, - images: Optional[List[Image.Image]] = None, + images: Optional[PromptImageInput] = None, **kwargs: Any, ) -> List[Tuple[List[int], str]]: outputs = self.generate(prompts, @@ -345,7 +351,7 @@ def generate_greedy_logprobs( self, prompts: List[str], max_tokens: int, - images: Optional[List[Image.Image]] = None, + images: Optional[PromptImageInput] = None, **kwargs: Any, ) -> List[List[torch.Tensor]]: all_logprobs: List[List[torch.Tensor]] = [] @@ -427,8 +433,8 @@ def generate_greedy_logprobs_limit( prompts: List[str], max_tokens: int, num_logprobs: int, - images: Optional[List[Image.Image]] = None, - audios: Optional[List[Tuple[np.ndarray, int]]] = None, + images: Optional[PromptImageInput] = None, + audios: Optional[PromptAudioInput] = None, **kwargs: Any, ) -> List[Tuple[List[int], str, List[Dict[int, float]]]]: all_logprobs: List[List[Dict[int, float]]] = [] @@ -665,7 +671,7 @@ def generate_greedy( self, prompts: List[str], max_tokens: int, - images: Optional[List[Image.Image]] = None, + images: Optional[PromptImageInput] = None, ) -> List[Tuple[List[int], str]]: greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens) outputs = self.generate(prompts, greedy_params, images=images) diff --git a/tests/core/test_block_manager.py b/tests/core/test_block_manager.py index cd306b9e4d3cc..2ee9f20824f2f 100644 --- a/tests/core/test_block_manager.py +++ b/tests/core/test_block_manager.py @@ -595,3 +595,43 @@ def test_sliding_window_multi_seq(): # assert all blocks are free now assert block_manager.get_num_free_gpu_blocks() == num_gpu_blocks + + +def test_mark_blocks_as_computed_with_prefix_cache_and_chunked_prefill(): + """When prefix cache and chunked prefill are enabled, the block manager + should only mark a chunk of blocks as computed instead of all blocks. + """ + + block_size = 4 + num_cpu_blocks = 0 + num_gpu_blocks = 16 + block_manager = BlockSpaceManagerV1(block_size, + num_gpu_blocks, + num_cpu_blocks, + watermark=0, + enable_caching=True) + + # Set prompt size to have num_gpu_blocks - 1 full blocks. + prompt_length = block_size * num_gpu_blocks - 1 + + # Allocate (reserve) all blocks. + _, seq_group = create_dummy_prompt("0", + prompt_length, + block_size=block_size) + block_manager.allocate(seq_group) + assert seq_group.seqs[0].n_blocks == num_gpu_blocks + + # 1st chunk: Compute 2 and half blocks. Should mark 2 blocks as computed. + token_chunk_size = int(block_size * 2.5) + block_manager.mark_blocks_as_computed(seq_group, token_chunk_size) + computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0]) + assert len(computed_blocks) == 2 + + # Actual computed tokens. + seq_group.seqs[0].data.update_num_computed_tokens(token_chunk_size) + + # 2nd chunk: Complete 3rd block and additional 4 blocks. + token_chunk_size = int(block_size * 4.5) + block_manager.mark_blocks_as_computed(seq_group, token_chunk_size) + computed_blocks = block_manager.get_all_computed_blocks(seq_group.seqs[0]) + assert len(computed_blocks) == 7 diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index 6d9c2f3ebba4a..2f6ea632a5d9b 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -562,3 +562,42 @@ def test_chunked_prefill_max_seqs(): assert len(get_sequence_groups(out)) == max_seqs assert not running[0].is_prefill() assert not running[1].is_prefill() + + +def test_perfix_caching(): + """Verify allocating full blocks when prefix caching is enabled.""" + block_size = 4 + max_seqs = 10 + max_model_len = 80 + max_num_batched_tokens = 64 + scheduler_config = SchedulerConfig(max_num_batched_tokens, + max_seqs, + max_model_len, + enable_chunked_prefill=True) + cache_config = CacheConfig(block_size, + 1.0, + 1, + "auto", + enable_prefix_caching=True) + cache_config.num_cpu_blocks = 0 + cache_config.num_gpu_blocks = 32 + scheduler = Scheduler(scheduler_config, cache_config, None) + running: List[SequenceGroup] = [] + + # Add seq groups to scheduler. + for i in range(2): + _, seq_group = create_dummy_prompt(str(i), + block_size=block_size, + prompt_length=50) + scheduler.add_seq_group(seq_group) + running.append(seq_group) + + seq_group_meta, out = schedule_and_update_computed_tokens(scheduler) + assert set(get_sequence_groups(out)) == set(running) + assert seq_group_meta[0].token_chunk_size == 50 + # Verify it is chunked. Note that although the budget is 64-50=14, + # we only allocate full blocks for prefix caching, so only 4*(14//4)=12 + # tokens are allocated. + assert seq_group_meta[1].token_chunk_size == 12 + assert out.num_prefill_groups == 2 + assert out.num_batched_tokens == 62 diff --git a/tests/data/test_config.yaml b/tests/data/test_config.yaml new file mode 100644 index 0000000000000..20d499624de2e --- /dev/null +++ b/tests/data/test_config.yaml @@ -0,0 +1,2 @@ +port: 12312 +tensor_parallel_size: 2 diff --git a/tests/distributed/test_multimodal_broadcast.py b/tests/distributed/test_multimodal_broadcast.py index e7723a7ae2480..73ef863c2f193 100644 --- a/tests/distributed/test_multimodal_broadcast.py +++ b/tests/distributed/test_multimodal_broadcast.py @@ -35,9 +35,11 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str, if model.startswith("llava-hf/llava-1.5"): from ..models.test_llava import models, run_test elif model.startswith("llava-hf/llava-v1.6"): - from ..models.test_llava_next import models, run_test + from ..models.test_llava_next import run_test # type: ignore[no-redef] + from ..models.test_llava_next import models elif model.startswith("facebook/chameleon"): - from ..models.test_chameleon import models, run_test + from ..models.test_chameleon import run_test # type: ignore[no-redef] + from ..models.test_chameleon import models else: raise NotImplementedError(f"Unsupported model: {model}") diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 4d54e43d5788c..637d2b30f6b1f 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -18,23 +18,26 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" -@pytest.mark.parametrize(("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, " - "MODEL_NAME, DIST_BACKEND"), - [ - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"), - ]) +@pytest.mark.parametrize( + ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, " + "MODEL_NAME, DIST_BACKEND"), + [ + (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), + (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), + (2, 2, 1, 1, 1, "internlm/internlm2_5-7b-chat", "ray"), + ], +) @fork_new_process_for_each_test -def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, - DIST_BACKEND): +def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, + TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND): if VLLM_MULTI_NODE and DIST_BACKEND == "mp": pytest.skip("Skipping multi-node pipeline parallel test for " "multiprocessing distributed backend") @@ -71,6 +74,9 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, if EAGER_MODE: pp_args.append("--enforce-eager") tp_args.append("--enforce-eager") + if TRUST_REMOTE_CODE: + pp_args.append("--trust-remote-code") + tp_args.append("--trust-remote-code") pp_env = None if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2 and CHUNKED_PREFILL): diff --git a/tests/engine/test_multiproc_workers.py b/tests/engine/test_multiproc_workers.py index 610ad9732fb91..e07dd6deef5bf 100644 --- a/tests/engine/test_multiproc_workers.py +++ b/tests/engine/test_multiproc_workers.py @@ -83,7 +83,7 @@ def execute_workers(worker_input: str) -> None: workers[3].process.kill() # Other workers should get shut down here - worker_monitor.join(2) + worker_monitor.join(20) # Ensure everything is stopped assert not worker_monitor.is_alive() @@ -108,7 +108,7 @@ def test_local_workers_clean_shutdown() -> None: # Clean shutdown worker_monitor.close() - worker_monitor.join(5) + worker_monitor.join(20) # Ensure everything is stopped assert not worker_monitor.is_alive() @@ -161,7 +161,7 @@ async def execute_workers(worker_input: str) -> None: workers[3].process.kill() # Other workers should get shut down here - worker_monitor.join(2) + worker_monitor.join(20) # Ensure everything is stopped assert not worker_monitor.is_alive() diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index c426e9b4ee899..ef34bebbb0f8c 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -6,6 +6,7 @@ from vllm import LLM, RequestOutput, SamplingParams from ...conftest import cleanup +from ..openai.test_vision import TEST_IMAGE_URLS MODEL_NAME = "facebook/opt-125m" @@ -159,3 +160,36 @@ def test_chat(): ] outputs = llm.chat(messages) assert len(outputs) == 1 + + +@pytest.mark.parametrize("image_urls", + [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]]) +def test_chat_multi_image(image_urls: List[str]): + llm = LLM( + model="microsoft/Phi-3.5-vision-instruct", + dtype="bfloat16", + max_model_len=4096, + max_num_seqs=5, + enforce_eager=True, + trust_remote_code=True, + limit_mm_per_prompt={"image": 2}, + ) + + messages = [{ + "role": + "user", + "content": [ + *({ + "type": "image_url", + "image_url": { + "url": image_url + } + } for image_url in image_urls), + { + "type": "text", + "text": "What's in this image?" + }, + ], + }] + outputs = llm.chat(messages) + assert len(outputs) >= 0 diff --git a/tests/entrypoints/llm/test_generate_multiple_loras.py b/tests/entrypoints/llm/test_generate_multiple_loras.py index 35eabf079964a..9f5727ecd0406 100644 --- a/tests/entrypoints/llm/test_generate_multiple_loras.py +++ b/tests/entrypoints/llm/test_generate_multiple_loras.py @@ -50,7 +50,7 @@ def zephyr_lora_files(): @pytest.mark.skip_global_cleanup def test_multiple_lora_requests(llm: LLM, zephyr_lora_files): lora_request = [ - LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files) + LoRARequest(LORA_NAME + str(idx), idx + 1, zephyr_lora_files) for idx in range(len(PROMPTS)) ] # Multiple SamplingParams should be matched with each prompt diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index 3783b7cd66a6a..c3a6c65be1d90 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from unittest.mock import MagicMock +from vllm.config import MultiModalConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.entrypoints.openai.serving_chat import OpenAIServingChat @@ -20,6 +21,7 @@ class MockModelConfig: max_model_len = 100 tokenizer_revision = None embedding_mode = False + multimodal_config = MultiModalConfig() @dataclass diff --git a/tests/entrypoints/openai/test_serving_engine.py b/tests/entrypoints/openai/test_serving_engine.py new file mode 100644 index 0000000000000..325bc03434287 --- /dev/null +++ b/tests/entrypoints/openai/test_serving_engine.py @@ -0,0 +1,107 @@ +from http import HTTPStatus +from unittest.mock import MagicMock + +import pytest + +from vllm.config import ModelConfig +from vllm.engine.protocol import AsyncEngineClient +from vllm.entrypoints.openai.protocol import (ErrorResponse, + LoadLoraAdapterRequest, + UnloadLoraAdapterRequest) +from vllm.entrypoints.openai.serving_engine import OpenAIServing + +MODEL_NAME = "meta-llama/Llama-2-7b" +LORA_LOADING_SUCCESS_MESSAGE = ( + "Success: LoRA adapter '{lora_name}' added successfully.") +LORA_UNLOADING_SUCCESS_MESSAGE = ( + "Success: LoRA adapter '{lora_name}' removed successfully.") + + +async def _async_serving_engine_init(): + mock_engine_client = MagicMock(spec=AsyncEngineClient) + mock_model_config = MagicMock(spec=ModelConfig) + # Set the max_model_len attribute to avoid missing attribute + mock_model_config.max_model_len = 2048 + + serving_engine = OpenAIServing(mock_engine_client, + mock_model_config, + served_model_names=[MODEL_NAME], + lora_modules=None, + prompt_adapters=None, + request_logger=None) + return serving_engine + + +@pytest.mark.asyncio +async def test_load_lora_adapter_success(): + serving_engine = await _async_serving_engine_init() + request = LoadLoraAdapterRequest(lora_name="adapter", + lora_path="/path/to/adapter2") + response = await serving_engine.load_lora_adapter(request) + assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter') + assert len(serving_engine.lora_requests) == 1 + assert serving_engine.lora_requests[0].lora_name == "adapter" + + +@pytest.mark.asyncio +async def test_load_lora_adapter_missing_fields(): + serving_engine = await _async_serving_engine_init() + request = LoadLoraAdapterRequest(lora_name="", lora_path="") + response = await serving_engine.load_lora_adapter(request) + assert isinstance(response, ErrorResponse) + assert response.type == "InvalidUserInput" + assert response.code == HTTPStatus.BAD_REQUEST + + +@pytest.mark.asyncio +async def test_load_lora_adapter_duplicate(): + serving_engine = await _async_serving_engine_init() + request = LoadLoraAdapterRequest(lora_name="adapter1", + lora_path="/path/to/adapter1") + response = await serving_engine.load_lora_adapter(request) + assert response == LORA_LOADING_SUCCESS_MESSAGE.format( + lora_name='adapter1') + assert len(serving_engine.lora_requests) == 1 + + request = LoadLoraAdapterRequest(lora_name="adapter1", + lora_path="/path/to/adapter1") + response = await serving_engine.load_lora_adapter(request) + assert isinstance(response, ErrorResponse) + assert response.type == "InvalidUserInput" + assert response.code == HTTPStatus.BAD_REQUEST + assert len(serving_engine.lora_requests) == 1 + + +@pytest.mark.asyncio +async def test_unload_lora_adapter_success(): + serving_engine = await _async_serving_engine_init() + request = LoadLoraAdapterRequest(lora_name="adapter1", + lora_path="/path/to/adapter1") + response = await serving_engine.load_lora_adapter(request) + assert len(serving_engine.lora_requests) == 1 + + request = UnloadLoraAdapterRequest(lora_name="adapter1") + response = await serving_engine.unload_lora_adapter(request) + assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format( + lora_name='adapter1') + assert len(serving_engine.lora_requests) == 0 + + +@pytest.mark.asyncio +async def test_unload_lora_adapter_missing_fields(): + serving_engine = await _async_serving_engine_init() + request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None) + response = await serving_engine.unload_lora_adapter(request) + assert isinstance(response, ErrorResponse) + assert response.type == "InvalidUserInput" + assert response.code == HTTPStatus.BAD_REQUEST + + +@pytest.mark.asyncio +async def test_unload_lora_adapter_not_found(): + serving_engine = await _async_serving_engine_init() + request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter") + response = await serving_engine.unload_lora_adapter(request) + assert isinstance(response, ErrorResponse) + assert response.type == "InvalidUserInput" + assert response.code == HTTPStatus.BAD_REQUEST diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index d2ef3c2071efb..f61fa127b7d06 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -6,11 +6,10 @@ from vllm.multimodal.utils import encode_image_base64, fetch_image -from ...utils import VLLM_PATH, RemoteOpenAIServer +from ...utils import RemoteOpenAIServer -MODEL_NAME = "llava-hf/llava-1.5-7b-hf" -LLAVA_CHAT_TEMPLATE = VLLM_PATH / "examples/template_llava.jinja" -assert LLAVA_CHAT_TEMPLATE.exists() +MODEL_NAME = "microsoft/Phi-3.5-vision-instruct" +MAXIMUM_IMAGES = 2 # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) TEST_IMAGE_URLS = [ @@ -24,13 +23,9 @@ @pytest.fixture(scope="module") def server(): args = [ - "--dtype", - "bfloat16", - "--max-model-len", - "4096", - "--enforce-eager", - "--chat-template", - str(LLAVA_CHAT_TEMPLATE), + "--dtype", "bfloat16", "--max-model-len", "4096", "--max-num-seqs", + "5", "--enforce-eager", "--trust-remote-code", "--limit-mm-per-prompt", + f"image={MAXIMUM_IMAGES}" ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: @@ -84,7 +79,7 @@ async def test_single_chat_session_image(client: openai.AsyncOpenAI, choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=596, total_tokens=606) + completion_tokens=10, prompt_tokens=772, total_tokens=782) message = choice.message message = chat_completion.choices[0].message @@ -139,7 +134,7 @@ async def test_single_chat_session_image_base64encoded( choice = chat_completion.choices[0] assert choice.finish_reason == "length" assert chat_completion.usage == openai.types.CompletionUsage( - completion_tokens=10, prompt_tokens=596, total_tokens=606) + completion_tokens=10, prompt_tokens=772, total_tokens=782) message = choice.message message = chat_completion.choices[0].message @@ -217,26 +212,22 @@ async def test_chat_streaming_image(client: openai.AsyncOpenAI, @pytest.mark.asyncio @pytest.mark.parametrize("model_name", [MODEL_NAME]) -@pytest.mark.parametrize("image_url", TEST_IMAGE_URLS) +@pytest.mark.parametrize( + "image_urls", + [TEST_IMAGE_URLS[:i] for i in range(2, len(TEST_IMAGE_URLS))]) async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, - image_url: str): + image_urls: List[str]): messages = [{ "role": "user", "content": [ - { - "type": "image_url", - "image_url": { - "url": image_url - } - }, - { + *({ "type": "image_url", "image_url": { "url": image_url } - }, + } for image_url in image_urls), { "type": "text", "text": "What's in this image?" @@ -244,20 +235,30 @@ async def test_multi_image_input(client: openai.AsyncOpenAI, model_name: str, ], }] - with pytest.raises(openai.BadRequestError): # test multi-image input - await client.chat.completions.create( + if len(image_urls) > MAXIMUM_IMAGES: + with pytest.raises(openai.BadRequestError): # test multi-image input + await client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=10, + temperature=0.0, + ) + + # the server should still work afterwards + completion = await client.completions.create( + model=model_name, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + completion = completion.choices[0].text + assert completion is not None and len(completion) >= 0 + else: + chat_completion = await client.chat.completions.create( model=model_name, messages=messages, max_tokens=10, temperature=0.0, ) - - # the server should still work afterwards - completion = await client.completions.create( - model=model_name, - prompt=[0, 0, 0, 0, 0], - max_tokens=5, - temperature=0.0, - ) - completion = completion.choices[0].text - assert completion is not None and len(completion) >= 0 + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py new file mode 100644 index 0000000000000..6ded5102c9314 --- /dev/null +++ b/tests/entrypoints/test_chat_utils.py @@ -0,0 +1,389 @@ +import warnings +from typing import Optional + +import pytest +from PIL import Image + +from vllm.assets.image import ImageAsset +from vllm.config import ModelConfig +from vllm.entrypoints.chat_utils import (parse_chat_messages, + parse_chat_messages_futures) +from vllm.multimodal import MultiModalDataDict +from vllm.multimodal.utils import encode_image_base64 +from vllm.transformers_utils.tokenizer_group import TokenizerGroup + +PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct" + + +@pytest.fixture(scope="module") +def phi3v_model_config(): + return ModelConfig(PHI3V_MODEL_ID, + PHI3V_MODEL_ID, + tokenizer_mode="auto", + trust_remote_code=True, + dtype="bfloat16", + seed=0, + limit_mm_per_prompt={ + "image": 2, + }) + + +@pytest.fixture(scope="module") +def phi3v_tokenizer(): + return TokenizerGroup( + tokenizer_id=PHI3V_MODEL_ID, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + ) + + +@pytest.fixture(scope="module") +def image_url(): + image = ImageAsset('cherry_blossom') + base64 = encode_image_base64(image.pil_image) + return f"data:image/jpeg;base64,{base64}" + + +def _assert_mm_data_is_image_input( + mm_data: Optional[MultiModalDataDict], + image_count: int, +) -> None: + assert mm_data is not None + assert set(mm_data.keys()) == {"image"} + + image_data = mm_data.get("image") + assert image_data is not None + + if image_count == 1: + assert isinstance(image_data, Image.Image) + else: + assert isinstance(image_data, list) and len(image_data) == image_count + + +def test_parse_chat_messages_single_image( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_data = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in the image?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + _assert_mm_data_is_image_input(mm_data, 1) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_single_image_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_future = parse_chat_messages_futures([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in the image?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": "user", + "content": "<|image_1|>\nWhat's in the image?" + }] + _assert_mm_data_is_image_input(await mm_future, 1) + + +def test_parse_chat_messages_multiple_images( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_data = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in these images?" + }] + _assert_mm_data_is_image_input(mm_data, 2) + + +@pytest.mark.asyncio +async def test_parse_chat_messages_multiple_images_async( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_future = parse_chat_messages_futures([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_1|>\n<|image_2|>\nWhat's in these images?" + }] + _assert_mm_data_is_image_input(await mm_future, 2) + + +def test_parse_chat_messages_placeholder_already_in_prompt( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_data = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to <|image_2|>?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": + "user", + "content": + "What's in <|image_1|> and how does it compare to <|image_2|>?" + }] + _assert_mm_data_is_image_input(mm_data, 2) + + +def test_parse_chat_messages_placeholder_one_already_in_prompt( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_data = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": + "text", + "text": + "What's in <|image_1|> and how does it compare to the other one?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [{ + "role": + "user", + "content": + "<|image_2|>\nWhat's in <|image_1|> and how does it compare to the " + "other one?" + }] + _assert_mm_data_is_image_input(mm_data, 2) + + +def test_parse_chat_messages_multiple_images_across_messages( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + conversation, mm_data = parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in this image?" + }] + }, { + "role": "assistant", + "content": "Some stuff." + }, { + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What about this one?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + assert conversation == [ + { + "role": "user", + "content": "<|image_1|>\nWhat's in this image?" + }, + { + "role": "assistant", + "content": "Some stuff." + }, + { + "role": "user", + "content": "<|image_2|>\nWhat about this one?" + }, + ] + _assert_mm_data_is_image_input(mm_data, 2) + + +def test_parse_chat_messages_rejects_too_many_images_in_one_message( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="coroutine 'async_get_and_parse_image' was never awaited") + with pytest.raises( + ValueError, + match="At most 2 image\\(s\\) may be provided in one request\\." + ): + parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in these images?" + }] + }], phi3v_model_config, phi3v_tokenizer) + + +def test_parse_chat_messages_rejects_too_many_images_across_messages( + phi3v_model_config, + phi3v_tokenizer, + image_url, +): + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", + message="coroutine 'async_get_and_parse_image' was never awaited") + with pytest.raises( + ValueError, + match="At most 2 image\\(s\\) may be provided in one request\\." + ): + parse_chat_messages([{ + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What's in this image?" + }] + }, { + "role": "assistant", + "content": "Some stuff." + }, { + "role": + "user", + "content": [{ + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "image_url", + "image_url": { + "url": image_url + } + }, { + "type": "text", + "text": "What about these two?" + }] + }], phi3v_model_config, phi3v_tokenizer) diff --git a/tests/kernels/test_awq_triton.py b/tests/kernels/test_awq_triton.py new file mode 100644 index 0000000000000..198d40a155ccb --- /dev/null +++ b/tests/kernels/test_awq_triton.py @@ -0,0 +1,169 @@ +"""Tests for the AWQ Triton kernel. + +Run `pytest tests/kernels/test_awq_triton.py`. +""" +import pytest +import torch + +from vllm.model_executor.layers.quantization.awq_triton import ( + AWQ_TRITON_SUPPORTED_GROUP_SIZES, awq_dequantize_triton, awq_gemm_triton) + +device = "cuda" + + +def reverse_awq_order(t: torch.Tensor): + bits = 4 + AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] + reverse_order_tensor = torch.arange( + t.shape[-1], + dtype=torch.int32, + device=t.device, + ) + reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits) + reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER] + reverse_order_tensor = reverse_order_tensor.view(-1) + + t = t[:, reverse_order_tensor] & 0xF + return t + + +# qweights - [R , C // 8], int32 +# scales - [R // G, C ], float16 +# zeros - [R // G, C // 8], int32 +def awq_dequantize_torch(qweight: torch.Tensor, scales: torch.Tensor, + qzeros: torch.Tensor, + group_size: int) -> torch.Tensor: + + if group_size == -1: + group_size = qweight.shape[0] + + bits = 4 + shifts = torch.arange(0, 32, bits, device=qzeros.device) + + iweights = torch.bitwise_right_shift(qweight[:, :, None], + shifts[None, None, :]).to(torch.int8) + + iweights = iweights.view(iweights.shape[0], -1) + + zeros = torch.bitwise_right_shift(qzeros[:, :, None], + shifts[None, None, :]).to(torch.int8) + zeros = zeros.view(qzeros.shape[0], -1) + zeros = reverse_awq_order(zeros) + + iweights = reverse_awq_order(iweights) + + iweights = torch.bitwise_and(iweights, (2**bits) - 1) + zeros = torch.bitwise_and(zeros, (2**bits) - 1) + + scales = scales.repeat_interleave(group_size, dim=0) + zeros = zeros.repeat_interleave(group_size, dim=0) + return (iweights - zeros) * scales + + +# qweights - [R , C // 8], int32 +# scales - [R // G, C ], float16 +# zeros - [R // G, C // 8], int32 +@pytest.mark.parametrize("qweight_rows", [3584, 18944, 128, 256, 512, 1024]) +@pytest.mark.parametrize("qweight_cols", [448, 576, 4736, 16, 32, 64, 128]) +@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES) +def test_dequantize(qweight_rows, qweight_cols, group_size): + + if group_size == -1: + group_size = qweight_rows + + qweight_dtype = torch.int32 + scales_rows = qweight_rows // group_size + scales_cols = qweight_cols * 8 + scales_dtype = torch.float16 + zeros_rows = scales_rows + zeros_cols = qweight_cols + zeros_dtype = torch.int32 + + torch.manual_seed(0) + + qweight = torch.randint(0, + torch.iinfo(torch.int32).max, + (qweight_rows, qweight_cols), + dtype=qweight_dtype, + device=device) + scales = torch.rand(scales_rows, + scales_cols, + dtype=scales_dtype, + device=device) + zeros = torch.randint(0, + torch.iinfo(torch.int32).max, + (zeros_rows, zeros_cols), + dtype=zeros_dtype, + device=device) + + iweights_triton = awq_dequantize_triton(qweight, scales, zeros) + + assert (not torch.any(torch.isinf(iweights_triton)) + and not torch.any(torch.isnan(iweights_triton))) + + iweights_torch = awq_dequantize_torch(qweight, scales, zeros, group_size) + + torch.testing.assert_close(iweights_triton, iweights_torch) + + +# input - [N, K] +# qweight - [K, M // 8] +# qzeros - [K // G, M // 8] +# scales - [K // G, M] +@pytest.mark.parametrize("N", [1, 2, 4, 8, 14, 17, 23, 32]) +@pytest.mark.parametrize("K", [128]) +@pytest.mark.parametrize("M", [16, 24, 32]) +@pytest.mark.parametrize("group_size", AWQ_TRITON_SUPPORTED_GROUP_SIZES) +@pytest.mark.parametrize("splitK", [1, 8]) +def test_gemm(N, K, M, splitK, group_size): + + if group_size == -1: + group_size = K + + split_k_iters = splitK + + input_rows = N + input_cols = K + input_dtype = torch.float32 + qweight_rows = input_cols + qweight_cols = M // 8 + scales_rows = qweight_rows // group_size + scales_cols = M + scales_dtype = torch.float32 + qzeros_rows = scales_rows + qzeros_cols = qweight_cols + + torch.manual_seed(0) + + input = torch.rand((input_rows, input_cols), + dtype=input_dtype, + device=device) + qweight = torch.randint(0, + torch.iinfo(torch.int32).max, + (qweight_rows, qweight_cols), + device=device) + qzeros = torch.randint(0, + torch.iinfo(torch.int32).max, + (qzeros_rows, qzeros_cols), + device=device) + scales = torch.rand((scales_rows, scales_cols), + dtype=scales_dtype, + device=device) + + output_triton = awq_gemm_triton(input, qweight, scales, qzeros, + split_k_iters) + + assert (not torch.any(torch.isinf(output_triton)) + and not torch.any(torch.isnan(output_triton))) + + dequantized_weights = awq_dequantize_triton(qweight, scales, qzeros) + + output_torch = torch.matmul(input, dequantized_weights) + + assert (not torch.any(torch.isinf(output_torch)) + and not torch.any(torch.isnan(output_torch))) + + torch.testing.assert_close(output_triton.cpu(), + output_torch.cpu(), + atol=1e-1, + rtol=1e-1) diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py new file mode 100644 index 0000000000000..7bf338b36953a --- /dev/null +++ b/tests/kernels/test_causal_conv1d.py @@ -0,0 +1,205 @@ +from typing import Optional + +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange + +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) + + +def causal_conv1d_ref( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, + weight.unsqueeze(1), + bias, + padding=width - 1, + groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update_ref(x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None): + """ + x: (batch, dim) + conv_state: (batch, dim, width) + weight: (dim, width) + bias: (dim,) + + out: (batch, dim) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + batch, dim = x.shape + width = weight.shape[1] + assert conv_state.shape == (batch, dim, width) + assert weight.shape == (dim, width) + conv_state.copy_(torch.roll(conv_state, shifts=-1, + dims=-1)) # Update state (B D W) + conv_state[:, :, -1] = x + out = torch.sum(conv_state * weight, dim=-1) # (B D) + if bias is not None: + out += bias + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) + + +@pytest.mark.parametrize("return_final_states", [False, True]) +@pytest.mark.parametrize("has_initial_states", [False, True]) +@pytest.mark.parametrize("channel_last", [False, True]) +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("width", [4]) +@pytest.mark.parametrize("seqlen", [128, 512, 4096]) +@pytest.mark.parametrize('dim', [64, 4096 + 32]) +@pytest.mark.parametrize('batch', [1, 2]) +def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, + itype, channel_last, has_initial_states, + return_final_states): + if not channel_last and (has_initial_states or return_final_states): + pytest.skip( + "Only channel_last support initial_states or return_final_states") + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + # set seed + torch.random.manual_seed(0) + if not channel_last: + x = torch.randn(batch, + 4096 + dim + 64, + seqlen, + device=device, + dtype=itype)[:, 4096:4096 + dim, :] + else: + x = rearrange( + torch.randn(batch, + seqlen, + 4096 + dim + 64, + device=device, + dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s") + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None + if has_initial_states: + initial_states = torch.randn(batch, + width - 1, + dim, + device=device, + dtype=itype).transpose(1, 2) + else: + initial_states = None + x_ref = x.detach().clone() + weight_ref = weight.detach().clone() + bias_ref = bias.detach().clone() if bias is not None else None + initial_states_ref = initial_states.detach().clone( + ) if initial_states is not None else None + activation = None if not silu_activation else "silu" + out, final_states = causal_conv1d_fn( + x, + weight, + bias, + initial_states=initial_states, + return_final_states=return_final_states, + activation=activation) + out_ref, final_states_ref = causal_conv1d_ref( + x_ref, + weight_ref, + bias_ref, + initial_states=initial_states_ref, + return_final_states=return_final_states, + activation=activation) + if return_final_states: + assert final_states is not None and final_states_ref is not None + assert torch.allclose(final_states, + final_states_ref, + rtol=rtol, + atol=atol) + + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + if return_final_states: + out += F.sigmoid(final_states).sum(dim=-1, keepdim=True) + out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True) + + +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [False, True]) +@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("width", [2, 3, 4]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +@pytest.mark.parametrize("batch", [1, 2]) +def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, + itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + # set seed + torch.random.manual_seed(0) + batch = 2 + x = torch.randn(batch, dim, device=device, dtype=itype) + conv_state = torch.randn(batch, dim, width, device=device, dtype=itype) + weight = torch.randn(dim, + width, + device=device, + dtype=itype, + requires_grad=True) + if has_bias: + bias = torch.randn(dim, device=device, dtype=itype, requires_grad=True) + else: + bias = None + conv_state_ref = conv_state.detach().clone() + activation = None if not silu_activation else "silu" + out = causal_conv1d_update(x, + conv_state, + weight, + bias, + activation=activation) + out_ref = causal_conv1d_update_ref(x, + conv_state_ref, + weight, + bias, + activation=activation) + + assert torch.equal(conv_state, conv_state_ref) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py index f109792ad251b..696cc0c6cdf10 100644 --- a/tests/kernels/test_flashinfer.py +++ b/tests/kernels/test_flashinfer.py @@ -73,11 +73,14 @@ def ref_paged_attn( @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) @torch.inference_mode -def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], - num_heads: Tuple[int, - int], head_size: int, - dtype: torch.dtype, block_size: int, - soft_cap: Optional[float]) -> None: +def test_flashinfer_decode_with_paged_kv( + kv_lens: List[int], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], +) -> None: torch.set_default_device("cuda") torch.cuda.manual_seed_all(0) num_seqs = len(kv_lens) @@ -88,6 +91,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], scale = head_size**-0.5 query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + key_value_cache = torch.randn(NUM_BLOCKS, 2, block_size, @@ -125,7 +129,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int], wrapper = flashinfer.\ BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", use_tensor_cores=( - (num_query_heads//num_kv_heads) not in (1, 2, 4, 8)) + (num_query_heads//num_kv_heads) > 4) ) wrapper.begin_forward(kv_indptr, kv_indices, @@ -249,3 +253,216 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]], soft_cap=soft_cap) torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ f"{torch.max(torch.abs(output - ref_output))}" + + +@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]]) +@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)]) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) +def test_flashinfer_prefill_with_paged_fp8_kv( + seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int], + head_size: int, dtype: torch.dtype, block_size: int, + soft_cap: Optional[float]) -> None: + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(seq_lens) + query_lens = [x[0] for x in seq_lens] + kv_lens = [x[1] for x in seq_lens] + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + + kv_cache_dtype = torch.float8_e4m3fn + + query = torch.randn(sum(query_lens), + num_query_heads, + head_size, + dtype=dtype) + NUM_BLOCKS_FP8 = 2048 + key_value_cache = torch.randn(NUM_BLOCKS_FP8, + 2, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) + key_cache /= head_size**0.5 + value_cache /= head_size**0.5 + + k_scale = key_cache.amax().item() / 448.0 + v_scale = value_cache.amax().item() / 448.0 + + kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale], + dim=1).to(kv_cache_dtype) + + assert (kv_cache_fp8.shape == key_value_cache.shape) + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS_FP8, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + qo_indptr = [0] + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = kv_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + qo_indptr.append(qo_indptr[-1] + query_lens[i]) + + qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32) + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, "NHD") + wrapper.begin_forward( + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + ) + + output = wrapper.forward(query, + kv_cache_fp8, + logits_soft_cap=soft_cap, + k_scale=k_scale, + v_scale=v_scale) + + ref_output = ref_paged_attn(query=query, + key_cache=key_cache.squeeze(1), + value_cache=value_cache.squeeze(1), + query_lens=query_lens, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap) + del query + del block_tables + # verify prefill fp8 + torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" + + +@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]]) +@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)]) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0]) +@torch.inference_mode +def test_flashinfer_decode_with_paged_fp8_kv( + kv_lens: List[int], + num_heads: Tuple[int, int], + head_size: int, + dtype: torch.dtype, + block_size: int, + soft_cap: Optional[float], +) -> None: + # test doesn't work for num_heads = (16,16) + torch.set_default_device("cuda") + torch.cuda.manual_seed_all(0) + num_seqs = len(kv_lens) + num_query_heads = num_heads[0] + num_kv_heads = num_heads[1] + assert num_query_heads % num_kv_heads == 0 + max_kv_len = max(kv_lens) + scale = head_size**-0.5 + use_tensor_cores = (num_query_heads // num_kv_heads) > 4 + kv_cache_dtype = torch.float8_e4m3fn + + query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype) + NUM_BLOCKS_FP8 = 2048 + key_value_cache = torch.randn(NUM_BLOCKS_FP8, + 2, + block_size, + num_kv_heads, + head_size, + dtype=dtype) + key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1) + key_cache /= head_size**0.5 + value_cache /= head_size**0.5 + + k_scale = key_cache.amax().item() / 448.0 + v_scale = value_cache.amax().item() / 448.0 + + key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype) + value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype) + assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1) + kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1) + + max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size + block_tables = torch.randint(0, + NUM_BLOCKS_FP8, + (num_seqs, max_num_blocks_per_seq), + dtype=torch.int32) + + kv_indptr = [0] + kv_indices = [] + kv_last_page_lens = [] + for i in range(num_seqs): + seq_len = kv_lens[i] + assert seq_len > 0 + num_blocks = (seq_len + block_size - 1) // block_size + kv_indices.extend(block_tables[i, :num_blocks]) + kv_indptr.append(kv_indptr[-1] + num_blocks) + kv_last_page_len = seq_len % block_size + if kv_last_page_len == 0: + kv_last_page_len = block_size + kv_last_page_lens.append(kv_last_page_len) + + kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32) + kv_indices = torch.tensor(kv_indices, dtype=torch.int32) + kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32) + + workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8) + wrapper = flashinfer.\ + BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD", + use_tensor_cores=use_tensor_cores) + wrapper.begin_forward(kv_indptr, + kv_indices, + kv_last_page_lens, + num_query_heads, + num_kv_heads, + head_size, + block_size, + "NONE", + data_type=dtype, + q_data_type=dtype) + output = wrapper.forward(query, + kv_cache_fp8, + logits_soft_cap=soft_cap, + k_scale=k_scale, + v_scale=v_scale) + key_cache = key_value_cache[:, 0, :, :, :].squeeze(1) + value_cache = key_value_cache[:, 1, :, :, :].squeeze(1) + + ref_output = ref_paged_attn(query=query, + key_cache=key_cache, + value_cache=value_cache, + query_lens=[1] * num_seqs, + kv_lens=kv_lens, + block_tables=block_tables, + scale=scale, + soft_cap=soft_cap) + # Temporary fix: Increasing the tolerance. Seems like a flashinfer issue + torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \ + f"{torch.max(torch.abs(output - ref_output))}" diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py new file mode 100644 index 0000000000000..d3cb0a8656a02 --- /dev/null +++ b/tests/kernels/test_mamba_ssm.py @@ -0,0 +1,324 @@ +import pytest +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) + + +def selective_state_update_ref(state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + dt = dt + dt_bias + dt = F.softplus(dt) if dt_softplus else dt + dA = torch.exp(rearrange(dt, "b h d -> b h d 1") * + A) # (batch, nheads, dim, dstate) + B = repeat(B, "b g n -> b (g h) n", + h=nheads // ngroups) # (batch, nheads, dstate) + C = repeat(C, "b g n -> b (g h) n", + h=nheads // ngroups) # (batch, nheads, dstate) + dB = rearrange(dt, "b h d -> b h d 1") * rearrange( + B, "b h n -> b h 1 n") # (batch, nheads, dim, dstate) + state.copy_(state * dA + + dB * rearrange(x, "b h d -> b h d 1")) # (batch, dim, dstate + out = torch.einsum("bhdn,bhn->bhd", state.to(C.dtype), C) + if D is not None: + out += (x * D).to(out.dtype) + out = (out if z is None else out * F.silu(z)).to(x.dtype) + if not has_heads: + out = out.squeeze(1) + return out + + +def selective_scan_ref(u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + position_indices=None, + prev_state=None): + """ + u: r(B D L) + delta: r(B D L) + A: c(D N) or r(D N) + B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L) + D: r(D) + z: r(B D L) + delta_bias: r(D), fp32 + prev_state: r(B D N), fp32 + + out: r(B D L) + last_state (optional): r(B D dstate) or c(B D dstate) + """ + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + is_variable_B = B.dim() >= 3 + is_variable_C = C.dim() >= 3 + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state + ys = [] + deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A)) + if not is_variable_B: + deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u) + else: + if B.dim() == 3: + deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u) + else: + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) + if is_variable_C and C.dim() == 4: + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + if position_indices is not None and position_indices[0, i] == 0: + x = deltaB_u[:, :, i] + else: + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if not is_variable_C: + y = torch.einsum('bdn,dn->bd', x, C) + else: + if C.dim() == 3: + y = torch.einsum('bdn,bn->bd', x, C[:, :, i]) + else: + y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + ys.append(y) + y = torch.stack(ys, dim=2) # (batch dim L) + out = y if D is None else y + u * rearrange(D, "d -> d 1") + if z is not None: + out = out * F.silu(z) + out = out.to(dtype=dtype_in) + return out if not return_last_state else (out, last_state) + + +@pytest.mark.parametrize('wtype', [torch.float32]) +@pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("return_last_state", [True]) +@pytest.mark.parametrize('has_delta_bias', [True]) +@pytest.mark.parametrize('delta_softplus', [True]) +@pytest.mark.parametrize('has_z', [True]) +@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("varBC_groups", [1, 2]) +@pytest.mark.parametrize("is_variable_C", [True]) +@pytest.mark.parametrize("is_variable_B", [True]) +@pytest.mark.parametrize("scan_chunks", [1, 2, 3]) +def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, + has_z, has_delta_bias, delta_softplus, + return_last_state, seqlen, itype, wtype, scan_chunks): + if varBC_groups > 1 and (not is_variable_B or not is_variable_C): + pytest.skip() # This config is not applicable + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + if has_z: # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + batch_size = 2 + dim = 4 + dstate = 8 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) + if not is_variable_B: + B_shape = [dim, dstate] + elif varBC_groups == 1: + B_shape = [batch_size, dstate, seqlen] + else: + B_shape = [batch_size, varBC_groups, dstate, seqlen] + B = torch.randn(B_shape, + device=device, + dtype=wtype if not is_variable_B else itype) + if not is_variable_C: + C_shape = [dim, dstate] + elif varBC_groups == 1: + C_shape = [batch_size, dstate, seqlen] + else: + C_shape = [batch_size, varBC_groups, dstate, seqlen] + C = torch.randn(C_shape, + device=device, + dtype=wtype if not is_variable_C else itype) + D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None + z = torch.randn(batch_size, dim, seqlen, device=device, + dtype=itype) if has_z else None + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) + ) if has_delta_bias else None + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) + delta = (0.5 * + torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)) + state = None + state_ref = None + out = None + out_ref = None + outs = [] + for c in range(scan_chunks): + chunked_prompt_len = seqlen // scan_chunks + chunk_start = chunked_prompt_len * c + chunk_end = chunked_prompt_len * (c + 1) + if c == scan_chunks - 1: + chunk_end = seqlen + _B = B + if is_variable_B: + _B = B[..., chunk_start:chunk_end] + _C = C + if is_variable_B: + _C = C[..., chunk_start:chunk_end] + _z = z + if has_z: + assert z is not None + _z = z[..., chunk_start:chunk_end] + out, *rest = selective_scan_fn(u[..., chunk_start:chunk_end], + delta[..., chunk_start:chunk_end], + A, + _B, + _C, + D, + z=_z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state, + prev_state=state if c > 0 else None) + outs.append(out) + if return_last_state: + state = rest[0] + if len(outs) > 1: + out = torch.cat(outs, dim=-1) + out_ref, *rest = selective_scan_ref(u, + delta, + A, + B, + C, + D, + z=z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state) + if return_last_state: + state_ref = rest[0] + + assert out is not None and out_ref is not None + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + if return_last_state: + assert state is not None and state_ref is not None + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + + +@pytest.mark.parametrize("itype", + [torch.float32, torch.float16, torch.bfloat16]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("dstate", [16, 32, 64]) +@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) +def test_selective_state_update(dim, dstate, has_z, itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + if torch.version.hip: + atol *= 2 + # set seed + torch.random.manual_seed(0) + batch_size = 1 + state = torch.randn(batch_size, dim, dstate, dtype=itype, device=device) + x = torch.randn(batch_size, dim, device=device, dtype=itype) + dt = torch.randn(batch_size, dim, device=device, dtype=itype) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(batch_size, dstate, device=device) + C = torch.randn(batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + state_ref = state.detach().clone() + out = selective_state_update(state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True) + out_ref = selective_state_update_ref(state_ref, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True) + + assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index f526c381b3339..2250cf1598b8b 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -2,6 +2,8 @@ Run `pytest tests/kernels/test_moe.py`. """ +from typing import List + import pytest import torch from transformers import MixtralConfig @@ -9,7 +11,13 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe +from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) +from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk +from vllm.model_executor.layers.quantization.utils.marlin_utils_test import ( + marlin_quantize) from vllm.model_executor.models.mixtral import MixtralMoE +from vllm.scalar_type import scalar_types def torch_moe(a, w1, w2, score, topk): @@ -29,6 +37,20 @@ def torch_moe(a, w1, w2, score, topk): topk_weight.view(B, -1, 1).to(out.dtype)).sum(dim=1) +def torch_moe_single(a, w, score, topk): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + _, topk_ids = torch.topk(score, topk) + topk_ids = topk_ids.view(-1) + for i in range(w.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = a[mask] @ w[i].transpose(0, 1) + return (out.view(B, -1, w.shape[1])).sum(dim=1) + + @pytest.mark.parametrize("m", [1024 * 128, 512, 222, 33, 1]) @pytest.mark.parametrize("n", [2048, 256, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024]) @@ -43,11 +65,11 @@ def test_fused_moe( topk: int, dtype: torch.dtype, ): - a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device='cuda', dtype=dtype) + score = torch.randn((m, e), device="cuda", dtype=dtype) triton_output = fused_moe(a, w1, w2, score, topk, renormalize=False) torch_output = torch_moe(a, w1, w2, score, topk) torch.testing.assert_close(triton_output, torch_output, atol=1e-2, rtol=0) @@ -99,3 +121,194 @@ def test_mixtral_moe(dtype: torch.dtype): vllm_states, rtol=mixtral_moe_tol[dtype], atol=mixtral_moe_tol[dtype]) + + +def stack_and_dev(tensors: List[torch.Tensor]): + dev = tensors[0].device + return torch.stack(tensors, dim=0).to(dev) + + +def compute_max_diff(output, output_ref): + return torch.mean(torch.abs(output - output_ref)) / torch.mean( + torch.abs(output_ref)) + + +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("act_order", [True, False]) +def test_fused_marlin_moe( + m: int, + n: int, + k: int, + e: int, + topk: int, + group_size: int, + act_order: bool, +): + torch.manual_seed(7) + + if topk > e: + return + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size in (k, n): + return + + quant_type = scalar_types.uint4b8 + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + for i in range(w2.shape[0]): + w2[0] = torch.eye(k, n, device="cuda", dtype=dtype) + + w_ref1_l = [] + qweight1_l = [] + scales1_l = [] + g_idx1_l = [] + sort_indices1_l = [] + + for i in range(w1.shape[0]): + test_perm = torch.randperm(k) + w_ref1, qweight1, scales1, g_idx1, sort_indices1, _ = marlin_quantize( + w1[i].transpose(1, 0), quant_type, group_size, act_order, + test_perm) + w_ref1_l.append(w_ref1) + qweight1_l.append(qweight1) + scales1_l.append(scales1) + g_idx1_l.append(g_idx1) + sort_indices1_l.append(sort_indices1) + + w_ref1 = stack_and_dev(w_ref1_l) + qweight1 = stack_and_dev(qweight1_l).contiguous() + scales1 = stack_and_dev(scales1_l) + g_idx1 = stack_and_dev(g_idx1_l) + sort_indices1 = stack_and_dev(sort_indices1_l) + + w_ref2_l = [] + qweight2_l = [] + scales2_l = [] + g_idx2_l = [] + sort_indices2_l = [] + + for i in range(w2.shape[0]): + test_perm = torch.randperm(n) + w_ref2, qweight2, scales2, g_idx2, sort_indices2, _ = marlin_quantize( + w2[i].transpose(1, 0), quant_type, group_size, act_order, + test_perm) + w_ref2_l.append(w_ref2) + qweight2_l.append(qweight2) + scales2_l.append(scales2) + g_idx2_l.append(g_idx2) + sort_indices2_l.append(sort_indices2) + + w_ref2 = stack_and_dev(w_ref2_l) + qweight2 = stack_and_dev(qweight2_l).contiguous() + scales2 = stack_and_dev(scales2_l) + g_idx2 = stack_and_dev(g_idx2_l) + sort_indices2 = stack_and_dev(sort_indices2_l) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + + topk_weights, topk_ids = fused_topk(a, score, topk, False) + + triton_output = fused_moe( + a, + w_ref1.transpose(1, 2).contiguous(), + w_ref2.transpose(1, 2).contiguous(), + score, + topk, + renormalize=False, + ) + marlin_output = fused_marlin_moe( + a, + qweight1, + qweight2, + score, + g_idx1, + g_idx2, + sort_indices1, + sort_indices2, + topk_weights, + topk_ids, + w1_scale=scales1, + w2_scale=scales2, + ) + + assert compute_max_diff(marlin_output, triton_output) < 4e-2 + + +@pytest.mark.skip("This test is here for the sake of debugging, " + "don't run it in automated tests.") +@pytest.mark.parametrize("m", [64, 512, 222, 33, 1]) +@pytest.mark.parametrize("n", [128, 2048, 256, 1024]) +@pytest.mark.parametrize("k", [128, 1024, 512]) +@pytest.mark.parametrize("e", [4, 8, 64]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) +@pytest.mark.parametrize("act_order", [True, False]) +def test_marlin_moe_mmm( + m: int, + n: int, + k: int, + e: int, + topk: int, + group_size: int, + act_order: bool, +): + if topk > e: + return + + # Filter act_order + if act_order: + if group_size == -1: + return + if group_size == k: + return + + quant_type = scalar_types.uint4b8 + dtype = torch.float16 + a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 + w = torch.randn((e, n, k), device="cuda", dtype=dtype) / 10 + + w_ref_l = [] + qweights_l = [] + scales_l = [] + g_idx_l = [] + sort_indices_l = [] + + for i in range(w.shape[0]): + test_perm = torch.randperm(k) + w_ref, qweight, scales, g_idx, sort_indices, _ = marlin_quantize( + w[i].transpose(1, 0), quant_type, group_size, act_order, test_perm) + w_ref_l.append(w_ref) + qweights_l.append(qweight) + scales_l.append(scales) + g_idx_l.append(g_idx) + sort_indices_l.append(sort_indices) + + w_ref = stack_and_dev(w_ref_l) + qweight = stack_and_dev(qweights_l).contiguous() + scales = stack_and_dev(scales_l) + g_idx = stack_and_dev(g_idx_l) + sort_indices = stack_and_dev(sort_indices_l) + + score = torch.randn((m, e), device="cuda", dtype=dtype) + marlin_output = single_marlin_moe(a, + qweight, + scales, + score, + g_idx, + sort_indices, + topk, + renormalize=False) + torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) + + assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index 709246179bfe4..f7c1d4f041c12 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -1,7 +1,10 @@ from typing import List +import pytest + import vllm from vllm.lora.request import LoRARequest +from vllm.utils import is_hip MODEL_PATH = "google/gemma-7b" @@ -10,7 +13,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: prompts = [ "Quote: Imagination is", "Quote: Be yourself;", - "Quote: So many books,", + "Quote: Painting is poetry that is seen rather than felt,", ] sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) outputs = llm.generate( @@ -28,6 +31,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm") def test_gemma_lora(gemma_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, @@ -37,7 +41,8 @@ def test_gemma_lora(gemma_lora_files): expected_lora_output = [ "more important than knowledge.\nAuthor: Albert Einstein\n", "everyone else is already taken.\nAuthor: Oscar Wilde\n", - "so little time\nAuthor: Frank Zappa\n", + "and poetry is painting that is felt rather than seen.\n" + "Author: Leonardo da Vinci\n", ] output1 = do_sample(llm, gemma_lora_files, lora_id=1) diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 2370c693e9534..133e0d4514a6d 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -7,6 +7,7 @@ import vllm from vllm.lora.request import LoRARequest +from vllm.utils import is_hip from .conftest import cleanup @@ -17,12 +18,23 @@ class ModelWithQuantization: quantization: str -MODELS: List[ModelWithQuantization] = [ - ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", - quantization="AWQ"), - ModelWithQuantization(model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", - quantization="GPTQ"), -] +MODELS: List[ModelWithQuantization] +#AWQ quantization is currently not supported in ROCm. +if is_hip(): + MODELS = [ + ModelWithQuantization( + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + quantization="GPTQ"), + ] +else: + MODELS = [ + ModelWithQuantization( + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", + quantization="AWQ"), + ModelWithQuantization( + model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ", + quantization="GPTQ"), + ] def do_sample(llm: vllm.LLM, diff --git a/tests/models/test_fp8.py b/tests/models/test_fp8.py index 4ab968c01da04..17acdb52322fd 100644 --- a/tests/models/test_fp8.py +++ b/tests/models/test_fp8.py @@ -3,116 +3,97 @@ Note: these tests will only pass on L4 GPU. """ import os -from typing import List +from typing import Optional import pytest -import torch -from transformers import AutoTokenizer +from tests.kernels.utils import override_backend_env_variable from tests.quantization.utils import is_quant_method_supported -from vllm import LLM, SamplingParams -os.environ["TOKENIZERS_PARALLELISM"] = "true" - -MAX_MODEL_LEN = 1024 - -MODELS = [ - "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV", - "meta-llama/Meta-Llama-3-8B-Instruct", -] +from ..models.utils import check_logprobs_close -EXPECTED_STRS_MAP = { - "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": { - "auto": [ - 'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both', - 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no' - ], - "fp8": [ - 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system made up of several basic components that work together to enable it to', - 'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk' - ] - }, - "meta-llama/Meta-Llama-3-8B-Instruct": { - "auto": [ - 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne', - 'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu' - ], - "fp8": [ - 'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained', - 'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ', - 'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.', - 'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne', - 'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest', - 'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The', - 'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of', - 'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu' - ] - }, -} +os.environ["TOKENIZERS_PARALLELISM"] = "true" -# This test compares against golden strings for exact match since -# there is no baseline implementation to compare against -# and is unstable w.r.t specifics of the fp8 implementation or -# the hardware being run on. -# Disabled to prevent it from breaking the build -@pytest.mark.skip( - reason= - "Prevent unstable test based on golden strings from breaking the build.") @pytest.mark.skipif(not is_quant_method_supported("fp8"), reason="fp8 is not supported on this GPU type.") -@pytest.mark.parametrize("model_name", MODELS) -@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"]) -def test_models(example_prompts, model_name, kv_cache_dtype) -> None: - model = LLM(model=model_name, - max_model_len=MAX_MODEL_LEN, - trust_remote_code=True, - enforce_eager=True, - quantization="fp8", - kv_cache_dtype=kv_cache_dtype) +@pytest.mark.parametrize( + "kv_cache_dtype,base_model,test_model,scale_path", + [ + # Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors. + ("fp8_e4m3", "meta-llama/Meta-Llama-3-8B-Instruct", + "nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV", None), + # Test FP16 checkpoint w. fp8_e5m2 kv-cache. + ("fp8_e5m2", "meta-llama/Meta-Llama-3-8B-Instruct", + "meta-llama/Meta-Llama-3-8B-Instruct", None), + # Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json. + ("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf", + "meta-llama/Llama-2-7b-chat-hf", + "./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json") + ]) +# Due to low-precision numerical divergence, we only test logprob of 4 tokens +@pytest.mark.parametrize("max_tokens", [4]) +@pytest.mark.parametrize("enforce_eager", [False, True]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"]) +# NOTE: Increasing this in this suite will fail CI because we currently cannot +# reset distributed env properly. Use a value > 1 just when you test. +@pytest.mark.parametrize("tensor_parallel_size", [1]) +# Due to low-precision numerical divergence, this test is too sensitive for +# the async postprocessor +@pytest.mark.parametrize("disable_async_output_proc", [True]) +def test_models( + vllm_runner, + example_prompts, + kv_cache_dtype: str, + base_model: str, + test_model: str, + scale_path: Optional[str], + max_tokens: int, + enforce_eager: bool, + backend: str, + tensor_parallel_size: int, + disable_async_output_proc: bool, + monkeypatch, +) -> None: + """ + Only checks log probs match to cover the discrepancy in + numerical sensitive kernels. + """ + override_backend_env_variable(monkeypatch, backend) + + MAX_MODEL_LEN = 1024 + NUM_LOG_PROBS = 8 + + with vllm_runner( + base_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype="auto", + disable_async_output_proc=disable_async_output_proc, + ) as vllm_model: + baseline_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, NUM_LOG_PROBS) - tokenizer = AutoTokenizer.from_pretrained(model_name) - formatted_prompts = [ - tokenizer.apply_chat_template([{ - "role": "user", - "content": prompt - }], - tokenize=False, - add_generation_prompt=True) - for prompt in example_prompts - ] + extra_kwargs = {} + if scale_path is not None: + extra_kwargs["quantization_param_path"] = scale_path - params = SamplingParams(max_tokens=20, temperature=0) - generations: List[str] = [] - # Note: these need to be run 1 at a time due to numerical precision, - # since the expected strs were generated this way. - for prompt in formatted_prompts: - outputs = model.generate(prompt, params) - generations.append(outputs[0].outputs[0].text) - del model + with vllm_runner( + test_model, + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=tensor_parallel_size, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + disable_async_output_proc=disable_async_output_proc, + **extra_kwargs, + ) as vllm_model: + test_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, NUM_LOG_PROBS) - print(model_name, kv_cache_dtype, generations) - expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype] - for i in range(len(example_prompts)): - generated_str = generations[i] - expected_str = expected_strs[i] - assert expected_str == generated_str, ( - f"Test{i}:\nExpected: {expected_str!r}\nvLLM: {generated_str!r}") + check_logprobs_close( + outputs_0_lst=baseline_outputs, + outputs_1_lst=test_outputs, + name_0="fp16_kv_cache", + name_1="fp8_kv_cache", + ) diff --git a/tests/models/test_granite.py b/tests/models/test_granite.py new file mode 100644 index 0000000000000..2435b5dc3ff88 --- /dev/null +++ b/tests/models/test_granite.py @@ -0,0 +1,49 @@ +"""Compare the outputs of HF and vLLM for Granite models using greedy sampling. + +Run `pytest tests/models/test_granite.py`. +""" +import importlib.metadata + +import pytest + +from .utils import check_logprobs_close + +TRANSFORMERS_VERSION = tuple( + map(int, + importlib.metadata.version("transformers").split("."))) + +MODELS = [ + "ibm/PowerLM-3b", +] + + +# GraniteForCausalLM will be in transformers >= 4.45 +@pytest.mark.skipif(TRANSFORMERS_VERSION < (4, 45), + reason="granite model test requires transformers >= 4.45") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + # TODO(sang): Sliding window should be tested separately. + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/models/test_intern_vit.py b/tests/models/test_intern_vit.py index e980446ff3570..816f846f69bae 100644 --- a/tests/models/test_intern_vit.py +++ b/tests/models/test_intern_vit.py @@ -6,8 +6,6 @@ from huggingface_hub import snapshot_download from transformers import AutoConfig, AutoModel, CLIPImageProcessor -from vllm.model_executor.models.intern_vit import InternVisionModel - from ..conftest import _ImageAssets, cleanup pytestmark = pytest.mark.vlm @@ -49,6 +47,7 @@ def run_intern_vit_test( for pixel_value in pixel_values ] + from vllm.model_executor.models.intern_vit import InternVisionModel vllm_model = InternVisionModel(config) vllm_model.load_weights(hf_model.state_dict().items()) diff --git a/tests/models/test_internvl.py b/tests/models/test_internvl.py index 243bc857c88de..fa3369dc53345 100644 --- a/tests/models/test_internvl.py +++ b/tests/models/test_internvl.py @@ -1,18 +1,16 @@ import types -from typing import List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type, Union import pytest import torch from PIL.Image import Image from transformers import AutoConfig -from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END, - IMG_START, - image_to_pixel_values) from vllm.multimodal.utils import rescale_image_size from vllm.utils import is_cpu -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) from .utils import check_logprobs_close pytestmark = pytest.mark.vlm @@ -23,6 +21,7 @@ "cherry_blossom": "<|im_start|>User\n\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501 }) +HF_MULTIIMAGE_IMAGE_PROMPT = "<|im_start|>User\nImage-1: \nImage-2: \nDescribe the two images in detail.<|im_end|>\n<|im_start|>Assistant\n" # noqa: E501 models = [ "OpenGVLab/InternVL2-1B", @@ -33,35 +32,6 @@ ] -class InternVLProcessor: - """A simple processor for InternVL2 HF model which misses a processor.""" - - def __init__(self, hf_runner: HfRunner): - self.num_image_token = hf_runner.model.num_image_token - self.tokenizer = hf_runner.tokenizer - self.dtype = hf_runner.model.dtype - - self.config = AutoConfig.from_pretrained(hf_runner.model_name) - self.vision_config = self.config.vision_config - self.use_thumbnail = self.config.use_thumbnail - self.min_num = self.config.min_dynamic_patch - self.max_num = self.config.max_dynamic_patch - self.image_size = self.vision_config.image_size - - def __call__(self, text: str, images: Image, **kwargs): - pixel_values = image_to_pixel_values(images, self.image_size, - self.min_num, self.max_num, - self.use_thumbnail).to(self.dtype) - num_patches_list = [pixel_values.shape[0]] - for num_patches in num_patches_list: - context_tokens = IMG_CONTEXT * self.num_image_token * num_patches - image_tokens = IMG_START + context_tokens + IMG_END - text = text.replace('', image_tokens, 1) - prompt = self.tokenizer(text, return_tensors="pt") - prompt.update({"pixel_values": pixel_values}) - return prompt - - # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py def generate( self, @@ -96,13 +66,13 @@ def generate( def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], - image_assets: _ImageAssets, + inputs: List[Tuple[List[str], PromptImageInput]], model: str, *, - size_factors: List[float], dtype: str, max_tokens: int, num_logprobs: int, + mm_limit: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, ): @@ -115,22 +85,56 @@ def run_test( Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it # will hurt multiprocessing backend with fork method (the default method). + class InternVLProcessor: + """A simple processor for InternVL2 which misses a processor.""" + + def __init__(self, hf_runner: HfRunner): + self.num_image_token = hf_runner.model.num_image_token + self.tokenizer = hf_runner.tokenizer + self.dtype = hf_runner.model.dtype + + self.config = AutoConfig.from_pretrained(hf_runner.model_name) + self.vision_config = self.config.vision_config + self.use_thumbnail = self.config.use_thumbnail + self.min_num = self.config.min_dynamic_patch + self.max_num = self.config.max_dynamic_patch + self.image_size = self.vision_config.image_size + + def __call__(self, text: str, images: Union[Image, List[Image]], + **kwargs): + from vllm.model_executor.models.internvl import ( + IMG_CONTEXT, IMG_END, IMG_START, image_to_pixel_values) + images = [images] if isinstance(images, Image) else images + pixel_values = [ + image_to_pixel_values(image, self.image_size, self.min_num, + self.max_num, + self.use_thumbnail).to(self.dtype) + for image in images + ] + num_patches_list = [ + pixel_value.shape[0] for pixel_value in pixel_values + ] + pixel_values = torch.cat(pixel_values, dim=0) + for num_patches in num_patches_list: + context_tokens = IMG_CONTEXT * self.num_image_token \ + * num_patches + image_tokens = IMG_START + context_tokens + IMG_END + text = text.replace('', image_tokens, 1) + prompt = self.tokenizer(text, return_tensors="pt") + prompt.update({"pixel_values": pixel_values}) + return prompt + # max_model_len should be greater than image_feature_size with vllm_runner(model, max_model_len=4096, dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, enforce_eager=True) as vllm_model: @@ -139,7 +143,7 @@ def run_test( max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] with hf_runner(model, dtype=dtype) as hf_model: @@ -157,7 +161,7 @@ def run_test( num_logprobs=num_logprobs, images=hf_images, eos_token_id=eos_token_id) - for prompts, hf_images in inputs_per_image + for prompts, hf_images in inputs ] for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, @@ -265,15 +269,64 @@ def run_awq_test( @torch.inference_mode() def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, dtype: str, max_tokens: int, num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + run_test( hf_runner, vllm_runner, - image_assets, + inputs_per_image, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + mm_limit=1, + tensor_parallel_size=1, + ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.5, 0.75, 1.0], + ], +) +@pytest.mark.parametrize("dtype", [target_dtype]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +@torch.inference_mode() +def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, + size_factors, dtype: str, max_tokens: int, + num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_case = [ + ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], + [[rescale_image_size(image, factor) for image in images] + for factor in size_factors]) + ] + + run_test( + hf_runner, + vllm_runner, + inputs_per_case, model, - size_factors=size_factors, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, + mm_limit=2, tensor_parallel_size=1, ) diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index 93634f245cee7..84ca23f6222a9 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type, overload import pytest from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, @@ -8,11 +8,14 @@ from vllm.sequence import SampleLogprobs from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) from .utils import check_logprobs_close pytestmark = pytest.mark.vlm +_LIMIT_IMAGE_PER_PROMPT = 4 + HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": "USER: \nWhat's the content of the image?\nASSISTANT:", @@ -52,6 +55,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, return hf_output_ids, hf_output_str, out_logprobs +@overload def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], @@ -64,6 +68,78 @@ def run_test( num_logprobs: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, +): + ... + + +@overload +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + sizes: List[Tuple[int, int]], + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + ... + + +def run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + image_assets: _ImageAssets, + model: str, + *, + size_factors: Optional[List[float]] = None, + sizes: Optional[List[Tuple[int, int]]] = None, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + images = [asset.pil_image for asset in image_assets] + + if size_factors is not None: + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + elif sizes is not None: + inputs_per_image = [( + [prompt for _ in sizes], + [image.resize(size) for size in sizes], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + else: + raise ValueError("You must provide either `size_factors` or `sizes`") + + _run_test(hf_runner, + vllm_runner, + inputs_per_image, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend) + + +def _run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, ): """Inference result should be the same between hf and vllm. @@ -85,13 +161,6 @@ def run_test( else: mantis_processor = None - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] - # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. # if we run HF first, the cuda initialization will be done and it @@ -100,15 +169,18 @@ def run_test( # max_model_len should be greater than image_feature_size with vllm_runner(model, dtype=dtype, + max_model_len=4096, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + enforce_eager=True, + limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT + }) as vllm_model: vllm_outputs_per_image = [ vllm_model.generate_greedy_logprobs(prompts, max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] if mantis_processor is not None: @@ -131,7 +203,7 @@ def process(hf_inputs: BatchEncoding): max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, @@ -179,3 +251,65 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, num_logprobs=num_logprobs, tensor_parallel_size=1, ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets, + model, dtype, max_tokens, + num_logprobs) -> None: + stop_sign = image_assets[0].pil_image + cherry_blossom = image_assets[1].pil_image + + inputs = [( + [ + "USER: \nDescribe 2 images.\nASSISTANT:", + "USER: \nDescribe 2 images.\nASSISTANT:", + "USER: \nDescribe 4 images.\nASSISTANT:", # noqa: E501 + "USER: \nWhat is the season?\nASSISTANT:", + ], + [ + [stop_sign, cherry_blossom], + # Images with different sizes and aspect-ratios + [ + rescale_image_size(stop_sign, 0.1), + stop_sign, + ], + [ + stop_sign, + rescale_image_size(stop_sign, 0.25), + cherry_blossom.resize((183, 488)), + cherry_blossom.resize((488, 183)) + ], + cherry_blossom, + ])] + + _run_test( + hf_runner, + vllm_runner, + inputs, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + +@pytest.mark.parametrize("model", models) +def test_context_length_too_short(vllm_runner, image_assets, model): + images = [asset.pil_image for asset in image_assets] + + with pytest.raises(ValueError, match="too long to fit into the model"): + vllm_model = vllm_runner( + model, + max_model_len=128, # LLaVA has a feature size of 576 + enforce_eager=True, + ) + + with vllm_model: + vllm_model.generate_greedy([HF_IMAGE_PROMPTS[0]], + max_tokens=1, + images=[images[0]]) diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 4965354c0016b..0741174497e32 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -41,3 +41,43 @@ def test_models( name_0="hf", name_1="vllm", ) + + +@pytest.mark.parametrize("model", MODELS[1:]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_mistral_format( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with vllm_runner( + model, + dtype=dtype, + tokenizer_mode="auto", + load_format="safetensors", + config_format="hf", + ) as hf_format_model: + hf_format_outputs = hf_format_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner( + model, + dtype=dtype, + tokenizer_mode="mistral", + load_format="mistral", + config_format="mistral", + ) as mistral_format_model: + mistral_format_outputs = mistral_format_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=hf_format_outputs, + outputs_1_lst=mistral_format_outputs, + name_0="hf", + name_1="mistral", + ) diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index e416a85b8962a..6ecbf07a08b7c 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -1,16 +1,15 @@ import os import re -from typing import List, Optional, Tuple, Type, Union +from typing import List, Optional, Tuple, Type import pytest -from PIL import Image from transformers import AutoTokenizer from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs from vllm.utils import is_cpu, is_hip -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner +from ..conftest import IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner from .utils import check_logprobs_close pytestmark = pytest.mark.vlm @@ -60,8 +59,7 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], - inputs: List[Tuple[List[str], Union[List[Image.Image], - List[List[Image.Image]]]]], + inputs: List[Tuple[List[str], PromptImageInput]], model: str, *, dtype: str, diff --git a/tests/models/test_phimoe.py b/tests/models/test_phimoe.py new file mode 100644 index 0000000000000..2fb2eecc94672 --- /dev/null +++ b/tests/models/test_phimoe.py @@ -0,0 +1,111 @@ +"""Compare the outputs of HF and vLLM for moe models using greedy sampling. + +Run `pytest tests/models/test_phimoe.py`. +""" +import pytest +import torch + +from vllm.utils import is_cpu + +from .utils import check_logprobs_close + +MODELS = [ + "microsoft/Phi-3.5-MoE-instruct", +] + + +def test_phimoe_routing_function(): + from vllm.model_executor.models.phimoe import phimoe_routing_function + test_case = { + 0: { + "hidden_states": + torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], + dtype=torch.float32, + requires_grad=False).view(4, 2), + "gating_output": + torch.tensor([0.1, 0.2, 0.3, 0.4], + dtype=torch.float32, + requires_grad=False), + "topk": + 2, + "renormalize": + False, + }, + 1: { + "hidden_states": + torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], + dtype=torch.float32, + requires_grad=False).view(4, 2), + "gating_output": + torch.tensor([0.4, 0.2, 0.3, 0.4], + dtype=torch.float32, + requires_grad=False), + "topk": + 2, + "renormalize": + False, + } + } + + ground_truth = { + 0: { + "topk_weights": + torch.tensor([1., 1.], dtype=torch.float32, requires_grad=False), + "topk_ids": + torch.tensor([3, 2], dtype=torch.long, requires_grad=False), + }, + 1: { + "topk_weights": + torch.tensor([0.5, 1.], dtype=torch.float32, requires_grad=False), + "topk_ids": + torch.tensor([0, 3], dtype=torch.long, requires_grad=False), + } + } + + for test_id in test_case: + topk_weights, topk_ids = phimoe_routing_function(**test_case[test_id]) + assert torch.allclose(topk_weights, + ground_truth[test_id]["topk_weights"]) + assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"]) + + +def get_gpu_memory(): + try: + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + gpu_memory = props.total_memory / (1024**3) + return gpu_memory + except Exception: + return 0 + + +@pytest.mark.skipif(condition=is_cpu(), + reason="This test takes a lot time to run on CPU, " + "and vllm CI's disk space is not enough for this model.") +@pytest.mark.skipif(condition=get_gpu_memory() < 100, + reason="Skip this test if GPU memory is insufficient.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/models/test_qwen.py b/tests/models/test_qwen.py index 0f974fcc1885c..05f5cbf8c3435 100644 --- a/tests/models/test_qwen.py +++ b/tests/models/test_qwen.py @@ -1,48 +1,165 @@ -from typing import Type +import pathlib +from typing import List, Optional, Type import pytest -from ..conftest import HfRunner, VllmRunner +from vllm.multimodal.utils import rescale_image_size + +from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets from .utils import check_logprobs_close -models = ["qwen/qwen-vl"] +pytestmark = pytest.mark.vlm +text_only_models = [ + "Qwen/Qwen-7B-Chat" # Has no visual component +] -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [32]) -@pytest.mark.parametrize("num_logprobs", [5]) -@pytest.mark.parametrize("model", models) -def test_text_only_qwen_model( +multimodal_models = ["Qwen/Qwen-VL"] + +HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ + "stop_sign": + "Picture 1: \nWhat's the content of the image?: ", + "cherry_blossom": + "Picture 1: \nWhat is the season?: ", +}) + + +### Tests for multimodal Qwen models +def run_test( + tmp_path: pathlib.PosixPath, hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], - example_prompts, + image_assets: _ImageAssets, model: str, *, + size_factors: List[float], dtype: str, max_tokens: int, num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, ): - # This test checks language inputs only, since the visual component - # for qwen-vl is still unsupported in VLLM. In the near-future, the - # implementation and this test will be extended to consider - # visual inputs as well. + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test is under tests/images. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding MultiModalConfig as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. + """ + images = [asset.pil_image for asset in image_assets] + + # Export the images to a tempdir and substitute it into the hf prompt; + # the contents between / will be ignored by VLLM, but the + # transformers implementation for the visual transformer parses this to + # reload it in the forward call; the contents are treated as a URL or a + # local path. + for idx, asset in enumerate(image_assets): + image_tmp_path = tmp_path / f"{asset.name}.jpg" + asset.pil_image.save(image_tmp_path) + HF_IMAGE_PROMPTS[idx] = HF_IMAGE_PROMPTS[idx].replace( + "", f"{image_tmp_path}") + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + + # NOTE: take care of the order. run vLLM first, and then run HF. + # vLLM needs a fresh new process without cuda initialization. + # if we run HF first, the cuda initialization will be done and it + # will hurt multiprocessing backend with fork method (the default method). + + # max_model_len should be greater than image_feature_size + # Qwen encodes images into a fixed content size of 256 + with vllm_runner(model, + max_model_len=300, + max_num_seqs=1, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True) as vllm_model: + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_image + ] + with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy_logprobs_limit( - example_prompts, - max_tokens, - num_logprobs=num_logprobs, + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_image + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", ) + +@pytest.mark.parametrize("model", multimodal_models) +@pytest.mark.parametrize( + "size_factors", + [ + # No image + [], + # Single-scale + [1.0], + # Single-scale, batched + [1.0, 1.0, 1.0], + # Multi-scale + [0.25, 0.5, 1.0], + ], +) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [8]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_multimodal_models(tmp_path, hf_runner, vllm_runner, image_assets, + model, size_factors, dtype, max_tokens, + num_logprobs) -> None: + run_test( + tmp_path, + hf_runner, + vllm_runner, + image_assets, + model, + size_factors=size_factors, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + +# Ensure that a text-only Qwen model can still be loaded and +# used for inference in VLLM without throwing. +@pytest.mark.parametrize("model", text_only_models) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_text_only_qwen_model_can_be_loaded_and_run( + vllm_runner: Type[VllmRunner], + example_prompts, + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, +): with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_greedy_logprobs( + vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs=num_logprobs, ) - - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) diff --git a/tests/models/test_ultravox.py b/tests/models/test_ultravox.py index 98de10aa08408..e98db9b65f484 100644 --- a/tests/models/test_ultravox.py +++ b/tests/models/test_ultravox.py @@ -1,11 +1,9 @@ from typing import List, Optional, Tuple, Type -import librosa import numpy as np import pytest from transformers import AutoModel, AutoTokenizer, BatchEncoding -from vllm.assets.audio import AudioAsset from vllm.sequence import SampleLogprobs from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE @@ -18,36 +16,32 @@ AudioTuple = Tuple[np.ndarray, int] +VLLM_PLACEHOLDER = "<|reserved_special_token_0|>" +HF_PLACEHOLDER = "<|audio|>" + @pytest.fixture(scope="session") -def audio_and_sample_rate(): - return AudioAsset("mary_had_lamb").audio_and_sample_rate +def audio_assets(): + from vllm.assets.audio import AudioAsset + return [AudioAsset("mary_had_lamb"), AudioAsset("winning_call")] -@pytest.fixture -def prompts_and_audios(audio_and_sample_rate): - tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) +@pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call")) +def audio(request): + from vllm.assets.audio import AudioAsset + return AudioAsset(request.param) - vllm_placeholder = "<|reserved_special_token_0|>" - hf_placeholder = "<|audio|>" - question = "What's in the audio?" - vllm_prompt = tokenizer.apply_chat_template( - [{ - 'role': 'user', - 'content': f"{vllm_placeholder}\n{question}" - }], - tokenize=False, - add_generation_prompt=True) - hf_prompt = tokenizer.apply_chat_template( - [{ - 'role': 'user', - 'content': f"{hf_placeholder}\n{question}" - }], - tokenize=False, - add_generation_prompt=True) +def _get_prompt(audio_count, question, placeholder): + tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) + placeholder = f"{placeholder}\n" * audio_count - return [(vllm_prompt, hf_prompt, audio_and_sample_rate)] + return tokenizer.apply_chat_template([{ + 'role': 'user', + 'content': f"{placeholder}{question}" + }], + tokenize=False, + add_generation_prompt=True) def vllm_to_hf_output(vllm_output: Tuple[List[int], str, @@ -109,6 +103,7 @@ def process(hf_inputs: BatchEncoding): dtype=dtype, postprocess_inputs=process, auto_cls=AutoModel) as hf_model: + import librosa hf_outputs_per_audio = [ hf_model.generate_greedy_logprobs_limit( @@ -134,15 +129,71 @@ def process(hf_inputs: BatchEncoding): ) +def run_multi_audio_test( + vllm_runner: Type[VllmRunner], + prompts_and_audios: List[Tuple[str, List[AudioTuple]]], + model: str, + *, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): + with vllm_runner(model, + dtype=dtype, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend, + enforce_eager=True, + limit_mm_per_prompt={ + "audio": + max((len(audio) for _, audio in prompts_and_audios)) + }) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + [prompt for prompt, _ in prompts_and_audios], + max_tokens, + num_logprobs=num_logprobs, + audios=[audios for _, audios in prompts_and_audios]) + + # The HuggingFace model doesn't support multiple audios yet, so + # just assert that some tokens were generated. + assert all(tokens for tokens, *_ in vllm_outputs) + + @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("num_logprobs", [5]) -def test_models(hf_runner, vllm_runner, prompts_and_audios, dtype: str, - max_tokens: int, num_logprobs: int) -> None: +def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int, + num_logprobs: int) -> None: + + vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER) + hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER) run_test( hf_runner, vllm_runner, - prompts_and_audios, + [(vllm_prompt, hf_prompt, audio.audio_and_sample_rate)], + MODEL_NAME, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) + + +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_with_multiple_audios(vllm_runner, audio_assets, dtype: str, + max_tokens: int, + num_logprobs: int) -> None: + + vllm_prompt = _get_prompt(len(audio_assets), + "Describe each of the audios above.", + VLLM_PLACEHOLDER) + run_multi_audio_test( + vllm_runner, + [(vllm_prompt, [audio.audio_and_sample_rate + for audio in audio_assets])], MODEL_NAME, dtype=dtype, max_tokens=max_tokens, diff --git a/tests/models/utils.py b/tests/models/utils.py index ff29a0ae81d6e..93ec03995094b 100644 --- a/tests/models/utils.py +++ b/tests/models/utils.py @@ -1,7 +1,7 @@ import warnings from typing import Dict, List, Optional, Sequence, Tuple, Union -from vllm.sequence import SampleLogprobs +from vllm.sequence import Logprob, SampleLogprobs TokensText = Tuple[List[int], str] @@ -38,34 +38,39 @@ def check_outputs_equal( float]], SampleLogprobs]]] +# Allow for tokens to be represented as str's rather than IDs +TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]], + List[Dict[str, + Logprob]]]]] + def check_logprobs_close( *, - outputs_0_lst: Sequence[TokensTextLogprobs], - outputs_1_lst: Sequence[TokensTextLogprobs], + outputs_0_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], + outputs_1_lst: Sequence[Union[TokensTextLogprobs, TextTextLogprobs]], name_0: str, name_1: str, num_outputs_0_skip_tokens: int = 0, warn_on_mismatch: bool = True, -): - """ - Compare the logprobs of two sequences generated by different models, + always_check_logprobs: bool = False, +) -> None: + """Compare the logprobs of two sequences generated by different models, which should be similar but not necessarily equal. - Arguments: - - * outputs_0_lst: First sequence to compare - * outputs_0_lst: Second sequence to compare - * name_0: sequence #0 name - * name_1: sequence #1 name - * num_outputs_0_skip_tokens: If > 0, specifies the number of initial + Args: + outputs_0_lst: First sequence to compare + outputs_0_lst: Second sequence to compare + name_0: sequence #0 name + name_1: sequence #1 name + num_outputs_0_skip_tokens: If > 0, specifies the number of initial sequence #0 tokens & logprobs to discard before comparison, i.e. all of sequence #1 will be compared to sequence #0 beginning at index num_outputs_0_skip_tokens - * warn_on_mismatch: Issue a warning if there is token-wise or text-wise + warn_on_mismatch: Issue a warning if there is token-wise or text-wise mismatch between the two sequences + always_check_logprobs: If true, check logprobs even when tokens match """ assert len(outputs_0_lst) == len(outputs_1_lst) @@ -94,8 +99,12 @@ def check_logprobs_close( for idx, (output_id_0, output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): - # If generated tokens don't match, then - if output_id_0 != output_id_1: + is_tok_mismatch = output_id_0 != output_id_1 + + # If generated tokens don't match + # or it is desired to always check logprobs, + # then + if is_tok_mismatch or always_check_logprobs: logprobs_elem_0 = logprobs_0[idx] logprobs_elem_1 = logprobs_1[idx] @@ -111,7 +120,7 @@ def check_logprobs_close( assert output_id_0 in logprobs_elem_1, fail_msg assert output_id_1 in logprobs_elem_0, fail_msg - if warn_on_mismatch: + if warn_on_mismatch and is_tok_mismatch: with warnings.catch_warnings(): # This ensures that repeated warnings are shown # in the output, not just the first occurrence diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index ad99d70d7417c..0cbe8371e235a 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -1,10 +1,12 @@ # Test the AsyncLLMEngine with multi-step-decoding -from typing import List +from typing import List, Optional import pytest -from ..utils import RemoteOpenAIServer +from ..models.utils import check_logprobs_close +from ..utils import (completions_with_server_args, get_client_text_generations, + get_client_text_logprob_generations) MODELS = [ "JackFram/llama-160m", @@ -23,22 +25,6 @@ ] -async def completions_with_server_args(prompts: List[str], model_name: str, - server_cli_args: List[str]): - - outputs = None - with RemoteOpenAIServer(model_name, server_cli_args) as server: - async with server.get_async_client() as client: - outputs = await client.completions.create(model=model_name, - prompt=prompts, - temperature=0, - stream=False, - max_tokens=5) - assert outputs is not None - - return outputs - - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize(("tp_size, pp_size"), [ (1, 1), @@ -47,10 +33,43 @@ async def completions_with_server_args(prompts: List[str], model_name: str, @pytest.mark.parametrize("eager_mode", [False, True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("num_logprobs", [None, 5]) +@pytest.mark.parametrize("is_async", [False, True]) @pytest.mark.asyncio -async def test_multi_step(example_prompts, model: str, tp_size: int, - pp_size: int, eager_mode: int, - num_scheduler_steps: int, num_prompts: int): +async def test_multi_step( + example_prompts, + model: str, + tp_size: int, + pp_size: int, + eager_mode: int, + num_scheduler_steps: int, + num_prompts: int, + is_async: bool, + num_logprobs: Optional[int], +) -> None: + """Test vLLM engine with multi-step scheduling in an OpenAI-protocol + client/server environment. + + Set up an engine with single-step scheduling as a ground-truth reference. + + Send a completions API request to both engines with the same prompts. + + Validate: + * Generated tokens match + * Generated logprobs are all very close + + Args: + example_prompts: test fixture providing example prompts + model: model under test (same for single- and multi-step engines) + tp_size: degree of tensor-parallelism + pp_size: degree of pipeline-parallelism + eager_mode + num_scheduler_steps: for multi-step scheduling, GPU-side steps per + GPU -> CPU output transfer + num_prompts: number of example prompts under test + num_logprobs: corresponds to the `logprobs` argument to the OpenAI + completions endpoint; `None` -> no logprobs + """ prompts = example_prompts if len(prompts) < num_prompts: @@ -62,9 +81,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, ms_server_args = DEFAULT_SERVER_ARGS + \ ["--num-scheduler-steps", f"{num_scheduler_steps}"] - # Disable output proc callback as its not supported - # with multi-step right now - ms_server_args += ["--disable-async-output-proc"] + if not is_async: + ms_server_args += ["--disable-async-output-proc"] + if eager_mode: ms_server_args.append("--enforce-eager") @@ -75,14 +94,36 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, str(pp_size), ] + # Spin up client/server & issue completion API requests. + # Default `max_wait_seconds` is 240 but was empirically + # was raised 3x to 720 *just for this test* due to + # observed timeouts in GHA CI ref_completions = await completions_with_server_args( - prompts, model, server_args + distributed_args) + prompts, + model, + server_args + distributed_args, + num_logprobs, + max_wait_seconds=5 * 240) test_completions = await completions_with_server_args( - prompts, model, ms_server_args + distributed_args) - - def get_text_generations(completions): - return [x.text for x in completions.choices] - - ref_generations = get_text_generations(ref_completions) - test_generations = get_text_generations(test_completions) + prompts, + model, + ms_server_args + distributed_args, + num_logprobs, + max_wait_seconds=5 * 240) + + # Assert multi-step scheduling produces identical tokens + # to single-step scheduling. + ref_generations = get_client_text_generations(ref_completions) + test_generations = get_client_text_generations(test_completions) assert ref_generations == test_generations + + # Assert multi-step scheduling produces nearly-identical logprobs + # to single-step scheduling. + ref_text_logprobs = get_client_text_logprob_generations(ref_completions) + test_text_logprobs = get_client_text_logprob_generations(test_completions) + check_logprobs_close( + outputs_0_lst=ref_text_logprobs, + outputs_1_lst=test_text_logprobs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index 36f610ba74f05..24ebb60a9cbfd 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -1,8 +1,10 @@ # Test the LLMEngine with multi-step-decoding +from typing import Optional + import pytest -from ..models.utils import check_outputs_equal +from ..models.utils import check_logprobs_close, check_outputs_equal MODELS = [ "JackFram/llama-160m", @@ -18,10 +20,45 @@ @pytest.mark.parametrize("enforce_eager", [True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) -def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str, - dtype: str, tp_size: int, max_tokens: int, - enforce_eager: int, num_scheduler_steps: int, - num_prompts: int) -> None: +@pytest.mark.parametrize("num_logprobs", [None, 5]) +def test_multi_step_llm( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + tp_size: int, + max_tokens: int, + enforce_eager: int, + num_scheduler_steps: int, + num_prompts: int, + num_logprobs: Optional[int], +) -> None: + """Test vLLM engine with multi-step scheduling via sync LLM Engine. + + Set up a HuggingFace (HF) transformers model as a ground-truth reference. + + Prompt them with the same example prompts. + + Validate: + * Generated tokens match + * Generated logprobs are all very close + + Args: + hf_runner: HF transformers model runner fixture + vllm_runner: vLLM model runner fixture + example_prompts: test fixture providing example prompts + model: model under test (same for single- and multi-step engines) + dtype: tensor datatype for engine to utilize + tp_size: degree of tensor-parallelism + max_tokens: the maximum number of tokens to generate + enforce_eager + num_scheduler_steps: for multi-step scheduling, GPU-side steps per + GPU -> CPU output transfer + num_prompts: number of example prompts under test + num_logprobs: corresponds to the `logprobs` argument to the OpenAI + completions endpoint; `None` -> 1 logprob returned. + """ prompts = example_prompts if len(prompts) < num_prompts: @@ -29,21 +66,37 @@ def test_multi_step_llm(hf_runner, vllm_runner, example_prompts, model: str, prompts = prompts[:num_prompts] assert len(prompts) == num_prompts - with vllm_runner(model, - dtype=dtype, - enforce_eager=enforce_eager, - gpu_memory_utilization=0.7, - tensor_parallel_size=tp_size, - use_v2_block_manager=True, - num_scheduler_steps=num_scheduler_steps) as vllm_model: - vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens) + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + use_v2_block_manager=True, + num_scheduler_steps=num_scheduler_steps, + ) as vllm_model: + vllm_outputs = (vllm_model.generate_greedy(prompts, max_tokens) + if num_logprobs is None else + vllm_model.generate_greedy_logprobs( + prompts, max_tokens, num_logprobs)) with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_greedy(prompts, max_tokens) - - check_outputs_equal( - outputs_0_lst=hf_outputs, - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) + hf_outputs = (hf_model.generate_greedy(prompts, max_tokens) + if num_logprobs is None else + hf_model.generate_greedy_logprobs_limit( + prompts, max_tokens, num_logprobs)) + + if num_logprobs is None: + check_outputs_equal( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + else: + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/multimodal/test_base.py b/tests/multimodal/test_base.py index f19a0f33fe067..e9562d2048f06 100644 --- a/tests/multimodal/test_base.py +++ b/tests/multimodal/test_base.py @@ -81,3 +81,15 @@ def test_multimodal_input_batch_multiple_batchable_lists(): result, {"image": torch.stack([torch.stack([a, b]), torch.stack([c, d])])}) + + +def test_multimodal_input_batch_mixed_stacking_depths(): + a = torch.rand([1, 2, 3]) + b = torch.rand([1, 3, 3]) + c = torch.rand([1, 4, 3]) + + result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}]) + assert_multimodal_inputs_equal(result, {"image": [[a, b], c.unsqueeze(0)]}) + + result = MultiModalInputs.batch([{"image": [a]}, {"image": [b, c]}]) + assert_multimodal_inputs_equal(result, {"image": [a.unsqueeze(0), [b, c]]}) diff --git a/tests/quantization/test_bitsandbytes.py b/tests/quantization/test_bitsandbytes.py index b760e9ccb6b74..3f0c6cbc051a7 100644 --- a/tests/quantization/test_bitsandbytes.py +++ b/tests/quantization/test_bitsandbytes.py @@ -2,85 +2,115 @@ Run `pytest tests/quantization/test_bitsandbytes.py`. ''' + +import gc + import pytest import torch from tests.quantization.utils import is_quant_method_supported -from vllm import SamplingParams -models_to_test = [ +models_4bit_to_test = [ ('huggyllama/llama-7b', 'quantize model inflight'), - ('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'), ] +models_pre_qaunt_4bit_to_test = [ + ('lllyasviel/omost-llama-3-8b-4bits', + 'read pre-quantized 4-bit NF4 model'), + ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed', + 'read pre-quantized 4-bit FP4 model'), +] + +models_pre_quant_8bit_to_test = [ + ('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'), +] + + +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", models_4bit_to_test) +def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + hf_model_kwargs = {"load_in_4bit": True} + validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], + model_name, hf_model_kwargs) + + +@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), + reason='bitsandbytes is not supported on this GPU type.') +@pytest.mark.parametrize("model_name, description", + models_pre_qaunt_4bit_to_test) +def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], + model_name) + @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"), reason='bitsandbytes is not supported on this GPU type.') -@pytest.mark.parametrize("model_name, description", models_to_test) -def test_load_bnb_model(vllm_runner, model_name, description) -> None: +@pytest.mark.parametrize("model_name, description", + models_pre_quant_8bit_to_test) +def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts, + model_name, description) -> None: + + validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1], + model_name) + + +def log_generated_texts(prompts, outputs, runner_name): + logged_texts = [] + for i, (_, generated_text) in enumerate(outputs): + log_entry = { + "prompt": prompts[i], + "runner_name": runner_name, + "generated_text": generated_text, + } + logged_texts.append(log_entry) + return logged_texts + + +def validate_generated_texts(hf_runner, + vllm_runner, + prompts, + model_name, + hf_model_kwargs=None): + + if hf_model_kwargs is None: + hf_model_kwargs = {} + + # Run with HF runner + with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm: + hf_outputs = llm.generate_greedy(prompts, 8) + hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner") + + # Clean up the GPU memory for the next test + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + + #Run with vLLM runner with vllm_runner(model_name, quantization='bitsandbytes', load_format='bitsandbytes', - enforce_eager=True) as llm: - model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model # noqa: E501 - - # check the weights in MLP & SelfAttention are quantized to torch.uint8 - qweight = model.model.layers[0].mlp.gate_up_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}') - - qweight = model.model.layers[0].mlp.down_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}') - - qweight = model.model.layers[0].self_attn.o_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}') - - qweight = model.model.layers[0].self_attn.qkv_proj.qweight - assert qweight.dtype == torch.uint8, ( - f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}') - - # some weights should not be quantized - weight = model.lm_head.weight - assert weight.dtype != torch.uint8, ( - 'lm_head weight dtype should not be torch.uint8') - - weight = model.model.embed_tokens.weight - assert weight.dtype != torch.uint8, ( - 'embed_tokens weight dtype should not be torch.uint8') - - weight = model.model.layers[0].input_layernorm.weight - assert weight.dtype != torch.uint8, ( - 'input_layernorm weight dtype should not be torch.uint8') - - weight = model.model.layers[0].post_attention_layernorm.weight - assert weight.dtype != torch.uint8, ( - 'input_layernorm weight dtype should not be torch.uint8') - - # check the output of the model is expected - sampling_params = SamplingParams(temperature=0.0, - logprobs=1, - prompt_logprobs=1, - max_tokens=8) - - prompts = ['That which does not kill us', 'To be or not to be,'] - expected_outputs = [ - 'That which does not kill us makes us stronger.', - 'To be or not to be, that is the question.' - ] - outputs = llm.generate(prompts, sampling_params=sampling_params) - assert len(outputs) == len(prompts) - - for index in range(len(outputs)): - # compare the first line of the output - actual_output = outputs[index][1][0].split('\n', 1)[0] - expected_output = expected_outputs[index].split('\n', 1)[0] - - assert len(actual_output) >= len(expected_output), ( - f'Actual {actual_output} should be larger than or equal to ' - f'expected {expected_output}') - actual_output = actual_output[:len(expected_output)] - - assert actual_output == expected_output, ( - f'Expected: {expected_output}, but got: {actual_output}') + enforce_eager=True, + gpu_memory_utilization=0.8) as llm: + vllm_outputs = llm.generate_greedy(prompts, 8) + vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner") + + # Clean up the GPU memory for the next test + torch.cuda.synchronize() + gc.collect() + torch.cuda.empty_cache() + + # Compare the generated strings + for hf_log, vllm_log in zip(hf_logs, vllm_logs): + hf_str = hf_log["generated_text"] + vllm_str = vllm_log["generated_text"] + prompt = hf_log["prompt"] + assert hf_str == vllm_str, (f"Model: {model_name}" + f"Mismatch between HF and vLLM outputs:\n" + f"Prompt: {prompt}\n" + f"HF Output: '{hf_str}'\n" + f"vLLM Output: '{vllm_str}'") diff --git a/tests/samplers/test_rejection_sampler.py b/tests/samplers/test_rejection_sampler.py index 3ce4a5f658198..91a9d879eb4a5 100644 --- a/tests/samplers/test_rejection_sampler.py +++ b/tests/samplers/test_rejection_sampler.py @@ -44,12 +44,16 @@ def mock_causal_accepted_tensor( ["all_tokens_accepted", "no_tokens_accepted", "some_tokens_accepted"]) @pytest.mark.parametrize("disable_bonus_tokens", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() -def test_correct_output_format(which_tokens_accepted: str, - disable_bonus_tokens: bool, seed: int, - device: str): +def test_correct_output_format(which_tokens_accepted: str, seed: int, + disable_bonus_tokens: bool, device: str, + use_flashinfer: bool): """Verify the output has correct format given predetermined accepted matrix. """ + if use_flashinfer and disable_bonus_tokens: + pytest.skip("Flashinfer rejection sampler must enable bonus token.") + set_random_seed(seed) torch.set_default_device(device) @@ -85,7 +89,8 @@ def test_correct_output_format(which_tokens_accepted: str, dtype=torch.int64) rejection_sampler = RejectionSampler( - disable_bonus_tokens=disable_bonus_tokens) + disable_bonus_tokens=disable_bonus_tokens, + use_flashinfer=use_flashinfer) rejection_sampler.init_gpu_tensors(device=device) output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access accepted, @@ -133,15 +138,20 @@ def test_correct_output_format(which_tokens_accepted: str, @pytest.mark.parametrize("vocab_size", [30_000, 50_000]) @pytest.mark.parametrize("batch_size", list(range(1, 32))) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, - device: str): + device: str, use_flashinfer: bool): torch.set_default_device(device) - rejection_sampler = RejectionSampler() + rejection_sampler = RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer) rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -161,16 +171,21 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, @pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) @pytest.mark.parametrize("n_rep", [100]) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, - frac_seeded: float, n_rep: int, - device: str): + frac_seeded: float, n_rep: int, device: str, + use_flashinfer: bool): torch.set_default_device(device) - rejection_sampler = RejectionSampler() + rejection_sampler = RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer) rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -198,23 +213,85 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int, assert torch.equal(results[j][i], results[0][i]) +@pytest.mark.parametrize("k", [1, 3, 6]) +@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) +@pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_compare_nonflashinfer_backend(k: int, vocab_size: int, + batch_size: int, device: str): + """ + Test the flashinfer and nonflashinfer backend generate + the same output metrics. + """ + torch.set_default_device(device) + torch.manual_seed(0) + draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) + bonus_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, 1), + dtype=torch.int64) + draft_token_ids = torch.randint(low=0, + high=vocab_size, + size=(batch_size, k), + dtype=torch.int64) + + num_accepted_tokens = [] + num_emitted_tokens = [] + num_draft_tokens = [] + + def get_seeded_seqs(): + return { + i: torch.Generator(device=device).manual_seed(i) + for i in range(batch_size) + } + + for use_flashinfer in [True, False]: + rejection_sampler = RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer) + rejection_sampler.init_gpu_tensors(device=device) + # We use seeded sequences to ensure the same tokens are accepted + # for both flashinfer and nonflashinfer backends. + seeded_seqs = get_seeded_seqs() + rejection_sampler(target_probs, bonus_token_ids, draft_probs, + draft_token_ids, seeded_seqs) + num_accepted_tokens.append(rejection_sampler.num_accepted_tokens) + num_emitted_tokens.append(rejection_sampler.num_emitted_tokens) + num_draft_tokens.append(rejection_sampler.num_draft_tokens) + + assert num_accepted_tokens[0] == num_accepted_tokens[1] + assert num_emitted_tokens[0] == num_emitted_tokens[1] + assert num_draft_tokens[0] == num_draft_tokens[1] + + @pytest.mark.parametrize("above_or_below_vocab_range", ["above", "below"]) @pytest.mark.parametrize("which_token_ids", ["bonus_token_ids", "draft_token_ids"]) @pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() def test_raises_when_vocab_oob(above_or_below_vocab_range: str, - which_token_ids: str, device: str): + which_token_ids: str, device: str, + use_flashinfer: bool): k = 3 batch_size = 5 vocab_size = 30_000 torch.set_default_device(device) - rejection_sampler = RejectionSampler(strict_mode=True) + rejection_sampler = RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer, + strict_mode=True) rejection_sampler.init_gpu_tensors(device=device) draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -248,9 +325,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, @pytest.mark.parametrize("draft_and_target_probs_equal", [True, False]) @pytest.mark.parametrize("seed", list(range(5))) +@pytest.mark.parametrize("use_flashinfer", [True, False]) @torch.inference_mode() def test_rejection_sampling_approximates_target_distribution( - seed: int, draft_and_target_probs_equal: bool): + seed: int, draft_and_target_probs_equal: bool, use_flashinfer: bool): """Verify rejection sampling approximates target distribution, despite sampling from a potentially distinct draft distribution. @@ -279,10 +357,10 @@ def test_rejection_sampling_approximates_target_distribution( """ torch.set_default_device("cpu") set_random_seed(seed) - helper = _CorrectnessTestHelper( vocab_size=10, - rejection_sampler=RejectionSampler(), + rejection_sampler=RejectionSampler(disable_bonus_tokens=False, + use_flashinfer=use_flashinfer), ) draft_probs, target_probs, reference_probs = helper.generate_probs_for_test( @@ -398,10 +476,10 @@ def _estimate_rejection_sampling_pdf( draft_probs = draft_probs.reshape(1, self.k, self.vocab_size).repeat( num_samples, 1, 1) - # Repeat target probs num_samples * k times. + # Repeat target probs num_samples * (k + 1) times. # Rejection sampler requires bonus token probs, but they aren't used. target_probs = target_probs.reshape(1, 1, self.vocab_size).repeat( - num_samples, self.k, 1) + num_samples, self.k + 1, 1) # Randomly sample draft token ids from draft probs. draft_token_ids = torch.multinomial(draft_probs[:, 0, :], diff --git a/tests/samplers/test_typical_acceptance_sampler.py b/tests/samplers/test_typical_acceptance_sampler.py index aa3c1d29bdb36..e81ec4a0fdf1f 100644 --- a/tests/samplers/test_typical_acceptance_sampler.py +++ b/tests/samplers/test_typical_acceptance_sampler.py @@ -79,7 +79,10 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler() typical_acceptance_sampler.init_gpu_tensors(device=device) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_with_bonus_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -89,7 +92,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int, size=(batch_size, k), dtype=torch.int64) # Verify that sampling succeeds for all cases. - typical_acceptance_sampler(target_probs, + typical_acceptance_sampler(target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -112,7 +115,10 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, torch.set_default_device(device) typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True) typical_acceptance_sampler.init_gpu_tensors(device=device) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_with_bonus_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), @@ -141,7 +147,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str, oob_token_ids[0][0] = rogue_token_id with pytest.raises(AssertionError): - typical_acceptance_sampler(target_probs, + typical_acceptance_sampler(target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -172,7 +178,10 @@ def test_uniform_target_distribution_accepts_all_tokens( typical_acceptance_sampler = get_acceptance_sampler( strict_mode=True, disable_bonus_tokens=disable_bonus_tokens) typical_acceptance_sampler.init_gpu_tensors(device=device) - target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) + target_with_bonus_probs = torch.rand(batch_size, + k + 1, + vocab_size, + dtype=torch.float32) draft_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, k), @@ -182,7 +191,7 @@ def test_uniform_target_distribution_accepts_all_tokens( size=(batch_size, 1), dtype=torch.int64) output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -229,8 +238,9 @@ def test_temperature_zero_target_distribution(seed: int, # Simulate temperature 0 probability distribution for target probabilities # and create target probabilities such that only 1 token id has # probability 1.0 - target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist( - batch_size, k, vocab_size) + target_with_bonus_probs, zero_temperature_token_ids = \ + get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size) + zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] # Populate draft_token_ids such that they exclude the token_ids # with probability = 1.0 draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, @@ -245,7 +255,7 @@ def test_temperature_zero_target_distribution(seed: int, # fallback to the greedy sampling for selecting 1 token for each sequence. # Verify the same. output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -289,8 +299,10 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, # For sequences 0 and 2 set the distribution to a temperature # zero distribution. For sequences 1 and 3 set it to a uniform # distribution. - target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( - batch_size, k, vocab_size)) + target_with_bonus_probs, zero_temperature_token_ids = \ + get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size) + zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] + target_probs = target_with_bonus_probs[:, :-1] draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, zero_temperature_token_ids) uniform_probs = torch.rand(2, k, vocab_size, dtype=torch.float32) @@ -300,7 +312,7 @@ def test_mixed_target_distribution(seed: int, disable_bonus_tokens: bool, size=(batch_size, 1), dtype=torch.int64) output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -356,15 +368,16 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, # Create a temperature zero target probability distribution and ensure # all draft token ids correspond to the tokens with 1.0 probability. # Verify that all of them are accepted. - target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( - batch_size, k, vocab_size)) + target_with_bonus_probs, zero_temperature_token_ids = \ + get_zero_temperature_prob_dist(batch_size, k + 1, vocab_size) + zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] draft_token_ids = zero_temperature_token_ids bonus_token_ids = torch.randint(low=0, high=vocab_size, size=(batch_size, 1), dtype=torch.int64) output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -384,7 +397,7 @@ def test_accept_tokens_partially(seed: int, disable_bonus_tokens: bool, draft_token_ids = torch.cat( (draft_token_ids[:, :2], draft_token_ids_to_replace[:, -3:]), dim=1) output_token_ids = typical_acceptance_sampler( - target_probs, + target_with_bonus_probs, bonus_token_ids, draft_probs=None, draft_token_ids=draft_token_ids) @@ -421,8 +434,9 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, # 0.00001. Populate draft_token_ids such that they exclude the token_ids # with probability = 1.0. Without any changes to the posterior thresholds # none of the draft tokens are accepted. - target_probs, zero_temperature_token_ids = (get_zero_temperature_prob_dist( - batch_size, k, vocab_size)) + target_probs, zero_temperature_token_ids = get_zero_temperature_prob_dist( + batch_size, k + 1, vocab_size) + zero_temperature_token_ids = zero_temperature_token_ids[:, :-1] target_probs[target_probs == 0] = 0.00001 draft_token_ids = get_draft_token_ids(batch_size, k, vocab_size, zero_temperature_token_ids) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index ada6c37d9af8d..e7a0af4377630 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -5,9 +5,10 @@ import pytest import torch +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob, - SamplerOutput, get_all_seq_ids) + get_all_seq_ids) from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.top1_proposer import Top1Proposer diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 9ae1b4bc40f0f..501d05756e01c 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ -7,8 +7,9 @@ import pytest import torch +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput, SequenceOutput +from vllm.sequence import ExecuteModelRequest, SequenceOutput from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.metrics import (AsyncMetricsCollector, SpecDecodeWorkerMetrics) @@ -229,9 +230,8 @@ def test_correctly_calls_spec_decode_sampler(k: int, batch_size: int, assert torch.equal(actual.bonus_token_ids, target_token_ids.reshape(batch_size, k + 1)[:, -1:]) - assert torch.equal( - actual.target_probs, - target_token_probs.reshape(batch_size, k + 1, -1)[:, :-1]) + assert torch.equal(actual.target_with_bonus_probs, + target_token_probs.reshape(batch_size, k + 1, -1)) assert torch.equal(actual.draft_token_ids, proposal_token_ids) assert torch.equal(actual.draft_probs, proposal_probs) diff --git a/tests/spec_decode/test_utils.py b/tests/spec_decode/test_utils.py index 06780d4b8cd01..195fce64822bd 100644 --- a/tests/spec_decode/test_utils.py +++ b/tests/spec_decode/test_utils.py @@ -4,10 +4,12 @@ import torch from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.sampler import _get_ranks from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) from vllm.sequence import SequenceGroupMetadata, get_all_seq_ids -from vllm.spec_decode.util import split_batch_by_proposal_len +from vllm.spec_decode.util import (get_sampled_token_logprobs, + split_batch_by_proposal_len) def test_get_all_seq_ids(): @@ -126,3 +128,20 @@ def mock_spec_decode_sampler(acceptance_sampler_method): return sampler else: raise ValueError(f"Invalid sampler name {acceptance_sampler_method}") + + +def test_get_sampled_token_logprobs(): + """Verify get_sampled_token_logprobs returns consistent rankings + with regular get_ranks when probabilities match exactly. + """ + logprob_tensor = torch.tensor( + [[[-.1, -.1]] * 2]) # shape (num_steps, batch_size, vocab_size) + sampled_token_tensor = torch.tensor([[1, + 0]]) # shape (num_steps, batch_size) + ranks_spec_dec, _ = get_sampled_token_logprobs(logprob_tensor, + sampled_token_tensor) + + ranks_regular = _get_ranks(logprob_tensor.reshape((2, -1)), + sampled_token_tensor.reshape(-1)) + + assert torch.equal(ranks_spec_dec.reshape(-1), ranks_regular) diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index 60b36a33d9077..9075a433eb66e 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -8,12 +8,12 @@ import torch from vllm.engine.arg_utils import EngineArgs +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.utils import set_random_seed from vllm.sampling_params import SamplingParams from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, CompletionSequenceGroupOutput, Logprob, - SamplerOutput, SequenceData, SequenceGroupMetadata, - SequenceOutput) + SequenceData, SequenceGroupMetadata, SequenceOutput) from vllm.utils import get_distributed_init_method, get_ip, get_open_port from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner diff --git a/tests/test_logger.py b/tests/test_logger.py index 29346cd0878b8..8f3d218416870 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -95,7 +95,7 @@ def test_logger_configuring_can_be_disabled(): config behavior, however mocks are used to ensure no changes in behavior or configuration occur.""" - with patch("logging.config.dictConfig") as dict_config_mock: + with patch("vllm.logger.dictConfig") as dict_config_mock: _configure_vllm_root_logger() dict_config_mock.assert_not_called() @@ -175,9 +175,9 @@ def test_custom_logging_config_is_parsed_and_used_when_provided(): logging_config_file.flush() with patch("vllm.logger.VLLM_LOGGING_CONFIG_PATH", logging_config_file.name), patch( - "logging.config.dictConfig") as dict_config_mock: + "vllm.logger.dictConfig") as dict_config_mock: _configure_vllm_root_logger() - assert dict_config_mock.called_with(valid_logging_config) + dict_config_mock.assert_called_with(valid_logging_config) @patch("vllm.logger.VLLM_CONFIGURE_LOGGING", 0) diff --git a/tests/test_sequence.py b/tests/test_sequence.py index 1ae349e808e0d..348ba7dd41d99 100644 --- a/tests/test_sequence.py +++ b/tests/test_sequence.py @@ -2,9 +2,10 @@ import pytest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, - CompletionSequenceGroupOutput, SamplerOutput, - SequenceData, SequenceOutput) + CompletionSequenceGroupOutput, SequenceData, + SequenceOutput) from .core.utils import create_dummy_prompt diff --git a/tests/test_utils.py b/tests/test_utils.py index c157be1c08f81..c7cb663068c0f 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -132,6 +132,16 @@ def parser(): return parser +@pytest.fixture +def parser_with_config(): + parser = FlexibleArgumentParser() + parser.add_argument('serve') + parser.add_argument('--config', type=str) + parser.add_argument('--port', type=int) + parser.add_argument('--tensor-parallel-size', type=int) + return parser + + def test_underscore_to_dash(parser): args = parser.parse_args(['--image_input_type', 'pixel_values']) assert args.image_input_type == 'pixel_values' @@ -176,3 +186,37 @@ def test_missing_required_argument(parser): parser.add_argument('--required-arg', required=True) with pytest.raises(SystemExit): parser.parse_args([]) + + +def test_cli_override_to_config(parser_with_config): + args = parser_with_config.parse_args([ + 'serve', '--config', './data/test_config.yaml', + '--tensor-parallel-size', '3' + ]) + assert args.tensor_parallel_size == 3 + args = parser_with_config.parse_args([ + 'serve', '--tensor-parallel-size', '3', '--config', + './data/test_config.yaml' + ]) + assert args.tensor_parallel_size == 3 + + +def test_config_args(parser_with_config): + args = parser_with_config.parse_args( + ['serve', '--config', './data/test_config.yaml']) + assert args.tensor_parallel_size == 2 + + +def test_config_file(parser_with_config): + with pytest.raises(FileNotFoundError): + parser_with_config.parse_args(['serve', '--config', 'test_config.yml']) + + with pytest.raises(ValueError): + parser_with_config.parse_args( + ['serve', '--config', './data/test_config.json']) + + with pytest.raises(ValueError): + parser_with_config.parse_args([ + 'serve', '--tensor-parallel-size', '3', '--config', '--batch-size', + '32' + ]) diff --git a/tests/tool_use/__init__.py b/tests/tool_use/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tool_use/conftest.py b/tests/tool_use/conftest.py new file mode 100644 index 0000000000000..ab6a29eba1b3f --- /dev/null +++ b/tests/tool_use/conftest.py @@ -0,0 +1,32 @@ +import pytest +import pytest_asyncio +from huggingface_hub import snapshot_download + +from tests.utils import RemoteOpenAIServer + +from .utils import ARGS, CONFIGS, ServerConfig + + +# for each server config, download the model and return the config +@pytest.fixture(scope="session", params=CONFIGS.keys()) +def server_config(request): + config = CONFIGS[request.param] + # download model and tokenizer using transformers + snapshot_download(config["model"]) + yield CONFIGS[request.param] + + +# run this for each server config +@pytest.fixture(scope="session") +def server(request, server_config: ServerConfig): + model = server_config["model"] + args_for_model = server_config["arguments"] + with RemoteOpenAIServer(model, ARGS + args_for_model, + max_wait_seconds=480) as server: + yield server + + +@pytest_asyncio.fixture +async def client(server: RemoteOpenAIServer): + async with server.get_async_client() as async_client: + yield async_client diff --git a/tests/tool_use/test_chat_completions.py b/tests/tool_use/test_chat_completions.py new file mode 100644 index 0000000000000..038ff81d2b674 --- /dev/null +++ b/tests/tool_use/test_chat_completions.py @@ -0,0 +1,143 @@ +from typing import List + +import openai +import pytest + +from .utils import MESSAGES_WITHOUT_TOOLS, WEATHER_TOOL + + +# test: make sure chat completions without tools provided work even when tools +# are enabled. This makes sure tool call chat templates work, AND that the tool +# parser stream processing doesn't change the output of the model. +@pytest.mark.asyncio +async def test_chat_completion_without_tools(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=150, + model=model_name, + logprobs=False) + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + output_text = chat_completion.choices[0].message.content + + # check to make sure we got text + assert output_text is not None + assert len(output_text) > 0 + assert stop_reason != "tool_calls" + + # check to make sure no tool calls were returned + assert (choice.message.tool_calls is None + or len(choice.message.tool_calls) == 0) + + # make the same request, streaming + stream = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=150, + model=model_name, + logprobs=False, + stream=True, + ) + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + # assemble streamed chunks + async for chunk in stream: + delta = chunk.choices[0].delta + + # make sure the role is assistant + if delta.role: + assert not role_sent + assert delta.role == 'assistant' + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + # make sure tool call chunks aren't being streamed + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + # make sure the role was sent, only 1 finish reason was sent, that chunks + # were in fact sent, and that the chunks match non-streaming + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == output_text + + +# test: conversation with tools enabled and provided that should not invoke +# tools, to make sure we can still get normal chat completion responses +# and that they won't be parsed as tools +@pytest.mark.asyncio +async def test_chat_completion_with_tools(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=150, + model=model_name, + tools=[WEATHER_TOOL], + logprobs=False) + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + output_text = chat_completion.choices[0].message.content + + # check to make sure we got text + assert output_text is not None + assert stop_reason != 'tool_calls' + assert len(output_text) > 0 + + # check to make sure no tool calls were returned + assert (choice.message.tool_calls is None + or len(choice.message.tool_calls) == 0) + + # make the same request, streaming + stream = await client.chat.completions.create( + messages=MESSAGES_WITHOUT_TOOLS, + temperature=0, + max_tokens=150, + model=model_name, + logprobs=False, + tools=[WEATHER_TOOL], + stream=True, + ) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + # assemble streamed chunks + async for chunk in stream: + delta = chunk.choices[0].delta + + # make sure the role is assistant + if delta.role: + assert delta.role == 'assistant' + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + + # make sure tool call chunks aren't being streamed + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + # make sure the role was sent, only 1 finish reason was sent, that chunks + # were in fact sent, and that the chunks match non-streaming + assert role_sent + assert finish_reason_count == 1 + assert chunk.choices[0].finish_reason == stop_reason + assert chunk.choices[0].finish_reason != 'tool_calls' + assert len(chunks) + assert "".join(chunks) == output_text diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py new file mode 100644 index 0000000000000..b03b5a2075a6c --- /dev/null +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -0,0 +1,193 @@ +import json +from typing import Dict, List, Optional + +import openai +import pytest + +from .utils import (MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, SEARCH_TOOL, + WEATHER_TOOL) + + +# test: getting the model to generate parallel tool calls (streaming/not) +# when requested. NOTE that not all models may support this, so some exclusions +# may be added in the future. e.g. llama 3.1 models are not designed to support +# parallel tool calls. +@pytest.mark.asyncio +async def test_parallel_tool_calls(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + temperature=0, + max_tokens=200, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + non_streamed_tool_calls = chat_completion.choices[0].message.tool_calls + + # make sure 2 tool calls are present + assert choice.message.role == "assistant" + assert non_streamed_tool_calls is not None + assert len(non_streamed_tool_calls) == 2 + + for tool_call in non_streamed_tool_calls: + # make sure the tool includes a function and ID + assert tool_call.type == "function" + assert tool_call.function is not None + assert isinstance(tool_call.id, str) + assert len(tool_call.id) > 16 + + # make sure the weather tool was called correctly + assert tool_call.function.name == WEATHER_TOOL["function"]["name"] + assert isinstance(tool_call.function.arguments, str) + + parsed_arguments = json.loads(tool_call.function.arguments) + assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments.get("city"), str) + assert isinstance(parsed_arguments.get("state"), str) + + assert stop_reason == "tool_calls" + + # make the same request, streaming + stream = await client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_PARALLEL_TOOLS, + temperature=0, + max_tokens=200, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + role_name: Optional[str] = None + finish_reason_count: int = 0 + + tool_call_names: List[str] = [] + tool_call_args: List[str] = [] + tool_call_idx: int = -1 + tool_call_id_count: int = 0 + + async for chunk in stream: + + # if there's a finish reason make sure it's tools + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == 'tool_calls' + + # if a role is being streamed make sure it wasn't already set to + # something else + if chunk.choices[0].delta.role: + assert not role_name or role_name == 'assistant' + role_name = 'assistant' + + # if a tool call is streamed make sure there's exactly one + # (based on the request parameters + streamed_tool_calls = chunk.choices[0].delta.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + + # make sure only one diff is present - correct even for parallel + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a new tool is being called, set up empty arguments + if tool_call.index != tool_call_idx: + tool_call_idx = tool_call.index + tool_call_args.append("") + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id: + tool_call_id_count += 1 + assert (isinstance(tool_call.id, str) + and (len(tool_call.id) > 16)) + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert isinstance(tool_call.function.name, str) + tool_call_names.append(tool_call.function.name) + + if tool_call.function.arguments: + # make sure they're a string and then add them to the list + assert isinstance(tool_call.function.arguments, str) + + tool_call_args[ + tool_call.index] += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == 'assistant' + + assert (len(non_streamed_tool_calls) == len(tool_call_names) == + len(tool_call_args)) + + for i in range(2): + assert non_streamed_tool_calls[i].function.name == tool_call_names[i] + streamed_args = json.loads(tool_call_args[i]) + non_streamed_args = json.loads( + non_streamed_tool_calls[i].function.arguments) + assert streamed_args == non_streamed_args + + +# test: providing parallel tool calls back to the model to get a response +# (streaming/not) +@pytest.mark.asyncio +async def test_parallel_tool_calls_with_results(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + temperature=0, + max_tokens=200, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content # Dallas temp in tool response + assert "78" in choice.message.content # Orlando temp in tool response + + stream = await client.chat.completions.create( + messages=MESSAGES_WITH_PARALLEL_TOOL_RESPONSE, + temperature=0, + max_tokens=200, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py new file mode 100644 index 0000000000000..c3abe9e1f5060 --- /dev/null +++ b/tests/tool_use/test_tool_calls.py @@ -0,0 +1,192 @@ +import json +from typing import Dict, List, Optional + +import openai +import pytest + +from .utils import (MESSAGES_ASKING_FOR_TOOLS, MESSAGES_WITH_TOOL_RESPONSE, + SEARCH_TOOL, WEATHER_TOOL) + + +# test: request a chat completion that should return tool calls, so we know they +# are parsable +@pytest.mark.asyncio +async def test_tool_call_and_choice(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + stop_reason = chat_completion.choices[0].finish_reason + tool_calls = chat_completion.choices[0].message.tool_calls + + # make sure a tool call is present + assert choice.message.role == 'assistant' + assert tool_calls is not None + assert len(tool_calls) == 1 + assert tool_calls[0].type == 'function' + assert tool_calls[0].function is not None + assert isinstance(tool_calls[0].id, str) + assert len(tool_calls[0].id) > 16 + + # make sure the weather tool was called (classic example) with arguments + assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"] + assert tool_calls[0].function.arguments is not None + assert isinstance(tool_calls[0].function.arguments, str) + + # make sure the arguments parse properly + parsed_arguments = json.loads(tool_calls[0].function.arguments) + assert isinstance(parsed_arguments, Dict) + assert isinstance(parsed_arguments.get("city"), str) + assert isinstance(parsed_arguments.get("state"), str) + assert parsed_arguments.get("city") == "Dallas" + assert parsed_arguments.get("state") == "TX" + + assert stop_reason == "tool_calls" + + function_name: Optional[str] = None + function_args_str: str = '' + tool_call_id: Optional[str] = None + role_name: Optional[str] = None + finish_reason_count: int = 0 + + # make the same request, streaming + stream = await client.chat.completions.create( + model=model_name, + messages=MESSAGES_ASKING_FOR_TOOLS, + temperature=0, + max_tokens=100, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + async for chunk in stream: + assert chunk.choices[0].index == 0 + + if chunk.choices[0].finish_reason: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == 'tool_calls' + + # if a role is being streamed make sure it wasn't already set to + # something else + if chunk.choices[0].delta.role: + assert not role_name or role_name == 'assistant' + role_name = 'assistant' + + # if a tool call is streamed make sure there's exactly one + # (based on the request parameters + streamed_tool_calls = chunk.choices[0].delta.tool_calls + + if streamed_tool_calls and len(streamed_tool_calls) > 0: + assert len(streamed_tool_calls) == 1 + tool_call = streamed_tool_calls[0] + + # if a tool call ID is streamed, make sure one hasn't been already + if tool_call.id: + assert not tool_call_id + tool_call_id = tool_call.id + + # if parts of the function start being streamed + if tool_call.function: + # if the function name is defined, set it. it should be streamed + # IN ENTIRETY, exactly one time. + if tool_call.function.name: + assert function_name is None + assert isinstance(tool_call.function.name, str) + function_name = tool_call.function.name + if tool_call.function.arguments: + assert isinstance(tool_call.function.arguments, str) + function_args_str += tool_call.function.arguments + + assert finish_reason_count == 1 + assert role_name == 'assistant' + assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16) + + # validate the name and arguments + assert function_name == WEATHER_TOOL["function"]["name"] + assert function_name == tool_calls[0].function.name + assert isinstance(function_args_str, str) + + # validate arguments + streamed_args = json.loads(function_args_str) + assert isinstance(streamed_args, Dict) + assert isinstance(streamed_args.get("city"), str) + assert isinstance(streamed_args.get("state"), str) + assert streamed_args.get("city") == "Dallas" + assert streamed_args.get("state") == "TX" + + # make sure everything matches non-streaming except for ID + assert function_name == tool_calls[0].function.name + assert choice.message.role == role_name + assert choice.message.tool_calls[0].function.name == function_name + + # compare streamed with non-streamed args Dict-wise, not string-wise + # because character-to-character comparison might not work e.g. the tool + # call parser adding extra spaces or something like that. we care about the + # dicts matching not byte-wise match + assert parsed_arguments == streamed_args + + +# test: providing tools and results back to model to get a non-tool response +# (streaming/not) +@pytest.mark.asyncio +async def test_tool_call_with_results(client: openai.AsyncOpenAI): + models = await client.models.list() + model_name: str = models.data[0].id + chat_completion = await client.chat.completions.create( + messages=MESSAGES_WITH_TOOL_RESPONSE, + temperature=0, + max_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False) + + choice = chat_completion.choices[0] + + assert choice.finish_reason != "tool_calls" # "stop" or "length" + assert choice.message.role == "assistant" + assert choice.message.tool_calls is None \ + or len(choice.message.tool_calls) == 0 + assert choice.message.content is not None + assert "98" in choice.message.content # the temperature from the response + + stream = await client.chat.completions.create( + messages=MESSAGES_WITH_TOOL_RESPONSE, + temperature=0, + max_tokens=100, + model=model_name, + tools=[WEATHER_TOOL, SEARCH_TOOL], + logprobs=False, + stream=True) + + chunks: List[str] = [] + finish_reason_count = 0 + role_sent: bool = False + + async for chunk in stream: + delta = chunk.choices[0].delta + + if delta.role: + assert not role_sent + assert delta.role == "assistant" + role_sent = True + + if delta.content: + chunks.append(delta.content) + + if chunk.choices[0].finish_reason is not None: + finish_reason_count += 1 + assert chunk.choices[0].finish_reason == choice.finish_reason + + assert not delta.tool_calls or len(delta.tool_calls) == 0 + + assert role_sent + assert finish_reason_count == 1 + assert len(chunks) + assert "".join(chunks) == choice.message.content diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py new file mode 100644 index 0000000000000..e447469e33410 --- /dev/null +++ b/tests/tool_use/utils.py @@ -0,0 +1,215 @@ +from typing import Dict, List + +from openai.types.chat import (ChatCompletionMessageParam, + ChatCompletionToolParam) +from typing_extensions import TypedDict + +from tests.utils import VLLM_PATH + + +class ServerConfig(TypedDict): + model: str + arguments: List[str] + + +# universal args for all models go here. also good if you need to test locally +# and change type or KV cache quantization or something. +ARGS: List[str] = ["--enable-auto-tool-choice", "--max-model-len", "8096"] + +CONFIGS: Dict[str, ServerConfig] = { + "hermes": { + "model": + "NousResearch/Hermes-3-Llama-3.1-8B", + "arguments": [ + "--tool-call-parser", "hermes", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_hermes.jinja") + ] + }, + "mistral": { + "model": + "mistralai/Mistral-7B-Instruct-v0.3", + "arguments": [ + "--tool-call-parser", "mistral", "--chat-template", + str(VLLM_PATH / "examples/tool_chat_template_mistral.jinja"), + "--ignore-patterns=\"consolidated.safetensors\"" + ] + } +} + +WEATHER_TOOL: ChatCompletionToolParam = { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": + "string", + "description": + "The city to find the weather for, " + "e.g. 'San Francisco'" + }, + "state": { + "type": + "string", + "description": + "the two-letter abbreviation for the state " + "that the city is in, e.g. 'CA' which would " + "mean 'California'" + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"] + } + } + } + } +} + +SEARCH_TOOL: ChatCompletionToolParam = { + "type": "function", + "function": { + "name": + "web_search", + "description": + "Search the internet and get a summary of the top " + "10 webpages. Should only be used if you don't know " + "the answer to a user query, and the results are likely" + "to be able to be found with a web search", + "parameters": { + "type": "object", + "properties": { + "search_term": { + "type": + "string", + "description": + "The term to use in the search. This should" + "ideally be keywords to search for, not a" + "natural-language question" + } + }, + "required": ["search_term"] + } + } +} + +MESSAGES_WITHOUT_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "system", + "content": + "You are a helpful assistant with access to tools. If a tool" + " that you have would be helpful to answer a user query, " + "call the tool. Otherwise, answer the user's query directly " + "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " + "to the user's question - just respond to it normally." +}, { + "role": + "user", + "content": + "Hi! How are you?" +}, { + "role": + "assistant", + "content": + "I'm doing great! How can I assist you?" +}, { + "role": + "user", + "content": + "Can you tell me a joke please?" +}] + +MESSAGES_ASKING_FOR_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas in Fahrenheit?" +}] + +MESSAGES_WITH_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas in Fahrenheit?" +}, { + "role": + "assistant", + "tool_calls": [{ + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}' + } + }] +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": + "The weather in Dallas is 98 degrees fahrenheit, with partly" + "cloudy skies and a low chance of rain." +}] + +MESSAGES_ASKING_FOR_PARALLEL_TOOLS: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?" +}] + +MESSAGES_WITH_PARALLEL_TOOL_RESPONSE: List[ChatCompletionMessageParam] = [{ + "role": + "user", + "content": + "What is the weather in Dallas, Texas and Orlando, Florida in " + "Fahrenheit?" +}, { + "role": + "assistant", + "tool_calls": [{ + "id": "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Dallas", "state": "TX", ' + '"unit": "fahrenheit"}' + } + }, { + "id": "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "type": "function", + "function": { + "name": + WEATHER_TOOL["function"]["name"], + "arguments": + '{"city": "Orlando", "state": "Fl", ' + '"unit": "fahrenheit"}' + } + }] +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-03e6481b146e408e9523d9c956696295", + "content": + "The weather in Dallas TX is 98 degrees fahrenheit with mostly " + "cloudy skies and a chance of rain in the evening." +}, { + "role": + "tool", + "tool_call_id": + "chatcmpl-tool-d027061e1bd21cda48bee7da829c1f5b", + "content": + "The weather in Orlando FL is 78 degrees fahrenheit with clear" + "skies." +}] diff --git a/tests/tpu/__init__.py b/tests/tpu/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 5a432fb78b3da..d8df86b2aaa14 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -5,6 +5,10 @@ import depyf +# disable custom dispatcher, let Dynamo takes over +# all the control +os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0" + temp_dir = tempfile.mkdtemp() with depyf.prepare_debug(temp_dir): cur_dir = os.path.dirname(__file__) @@ -16,19 +20,36 @@ compiled_code = sorted( glob.glob(os.path.join(temp_dir, "__transformed_code*.py"))) -full_code = glob.glob(os.path.join(temp_dir, "full_code*.py"))[0] + # we should only trigger Dynamo compilation three times: -# one for the profiling phase (and the compiled artifact will be discarded) +# one for the profiling phase without kv cache # one for the prefill phase with symbolic shapes # one for the decode phase with symbolic shapes # and later calls should not trigger Dynamo compilation again. # NOTE: it might still trigger XLA compilation. # check we have three compiled code +# this is the assumption when we use the custom dispatcher assert len(compiled_code) == 3 -# check the first compilation is discarded -with open(full_code) as f: - full_code_content = f.read() - profile_function = compiled_code[0].split(".")[0] - assert profile_function not in full_code_content +# check all the compilations are as expected +compiled_fn = sorted( + glob.glob(os.path.join(temp_dir, "__compiled_fn*Captured*.py"))) + +# the first compilation is the profiling phase, +# it should not have any kv cache +with open(compiled_fn[0]) as f: + content = f.read() + assert "kv_caches" not in content + +# the second compilation is the prefill phase, +# it should have kv cache and the flash_attention op +with open(compiled_fn[1]) as f: + content = f.read() + assert "kv_caches" in content and "torch.ops.xla.flash_attention" in content + +# the third compilation is the decode phase, +# it should have kv cache and the paged_attention op +with open(compiled_fn[2]) as f: + content = f.read() + assert "kv_caches" in content and "torch.ops.xla.paged_attention" in content diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py new file mode 100644 index 0000000000000..7f3fb595321ad --- /dev/null +++ b/tests/tpu/test_custom_dispatcher.py @@ -0,0 +1,9 @@ +from ..utils import compare_two_settings + + +def test_custom_dispatcher(): + compare_two_settings("google/gemma-2b", + arg1=["--enforce-eager"], + arg2=["--enforce-eager"], + env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"}, + env2={}) diff --git a/tests/utils.py b/tests/utils.py index de887bc8cf6fb..6e5bc05b3901a 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,14 +11,16 @@ import openai import requests +from openai.types.completion import Completion from transformers import AutoTokenizer from typing_extensions import ParamSpec +from tests.models.utils import TextTextLogprobs from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.engine.arg_utils import AsyncEngineArgs from vllm.entrypoints.openai.cli_args import make_arg_parser -from vllm.model_executor.model_loader.loader import DefaultModelLoader +from vllm.model_executor.model_loader.loader import get_model_loader from vllm.platforms import current_platform from vllm.utils import FlexibleArgumentParser, get_open_port, is_hip @@ -87,11 +89,11 @@ def __init__(self, is_local = os.path.isdir(model) if not is_local: engine_args = AsyncEngineArgs.from_cli_args(args) - engine_config = engine_args.create_engine_config() - dummy_loader = DefaultModelLoader(engine_config.load_config) - dummy_loader._prepare_weights(engine_config.model_config.model, - engine_config.model_config.revision, - fall_back_to_pt=True) + model_config = engine_args.create_model_config() + load_config = engine_args.create_load_config() + + model_loader = get_model_loader(load_config) + model_loader.download_model(model_config) env = os.environ.copy() # the current process might initialize cuda, @@ -176,7 +178,12 @@ def compare_two_settings(model: str, env2: The second set of environment variables to pass to the API server. """ - tokenizer = AutoTokenizer.from_pretrained(model) + trust_remote_code = "--trust-remote-code" + if trust_remote_code in arg1 or trust_remote_code in arg2: + tokenizer = AutoTokenizer.from_pretrained(model, + trust_remote_code=True) + else: + tokenizer = AutoTokenizer.from_pretrained(model) prompt = "Hello, my name is" token_ids = tokenizer(prompt)["input_ids"] @@ -432,3 +439,61 @@ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: f" args {args} and kwargs {kwargs}") return wrapper + + +async def completions_with_server_args( + prompts: List[str], + model_name: str, + server_cli_args: List[str], + num_logprobs: Optional[int], + max_wait_seconds: int = 240, +) -> Completion: + '''Construct a remote OpenAI server, obtain an async client to the + server & invoke the completions API to obtain completions. + + Args: + prompts: test prompts + model_name: model to spin up on the vLLM server + server_cli_args: CLI args for starting the server + num_logprobs: Number of logprobs to report (or `None`) + max_wait_seconds: timeout interval for bringing up server. + Default: 240sec + + Returns: + OpenAI Completion instance + ''' + + outputs = None + with RemoteOpenAIServer(model_name, + server_cli_args, + max_wait_seconds=max_wait_seconds) as server: + client = server.get_async_client() + outputs = await client.completions.create(model=model_name, + prompt=prompts, + temperature=0, + stream=False, + max_tokens=5, + logprobs=num_logprobs) + assert outputs is not None + + return outputs + + +def get_client_text_generations(completions: Completion) -> List[str]: + '''Extract generated tokens from the output of a + request made to an Open-AI-protocol completions endpoint. + ''' + return [x.text for x in completions.choices] + + +def get_client_text_logprob_generations( + completions: Completion) -> List[TextTextLogprobs]: + '''Operates on the output of a request made to an Open-AI-protocol + completions endpoint; obtains top-rank logprobs for each token in + each :class:`SequenceGroup` + ''' + text_generations = get_client_text_generations(completions) + text = ''.join(text_generations) + return [(text_generations, text, + (None if x.logprobs is None else x.logprobs.top_logprobs)) + for x in completions.choices] diff --git a/tests/weight_loading/models-large.txt b/tests/weight_loading/models-large.txt new file mode 100644 index 0000000000000..fe76705746766 --- /dev/null +++ b/tests/weight_loading/models-large.txt @@ -0,0 +1,3 @@ +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main +gptq_marlin, TheBloke/Mixtral-8x7B-v0.1-GPTQ, main \ No newline at end of file diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index cbe30305c14f6..a90b352a39bca 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -4,6 +4,12 @@ gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main +gptq, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main +gptq, TheBloke/Llama-2-7B-GPTQ, main +gptq, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main +gptq, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True +gptq, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True +gptq, TechxGenus/gemma-1.1-2b-it-GPTQ, main compressed-tensors, nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change, main compressed-tensors, nm-testing/tinyllama-oneshot-w8-channel-a8-tensor, main compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2, main @@ -13,8 +19,7 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main -compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main +compressed-tensors, nm-testing/TinyLlama-1.1B-Chat-v1.0-actorder-group, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main diff --git a/tests/weight_loading/test_weight_loading.py b/tests/weight_loading/test_weight_loading.py index c13313df93f66..d8bca05e204c0 100644 --- a/tests/weight_loading/test_weight_loading.py +++ b/tests/weight_loading/test_weight_loading.py @@ -1,5 +1,7 @@ import os +import torch + MAX_MODEL_LEN = 1024 MODEL_NAME = os.environ.get("MODEL_NAME", "robertgshaw2/zephyr-7b-beta-channelwise-gptq") @@ -8,9 +10,12 @@ def test_weight_loading(vllm_runner): + """ + Test parameter weight loading with tp>1. + """ with vllm_runner(model_name=MODEL_NAME, revision=REVISION, - dtype="auto", + dtype=torch.half if QUANTIZATION == "gptq" else "auto", quantization=QUANTIZATION, max_model_len=MAX_MODEL_LEN, tensor_parallel_size=2) as model: diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index ae90af563c0cf..151cdbee8eb04 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -4,6 +4,7 @@ import torch +import vllm.envs as envs from vllm._core_ext import ScalarType from vllm.logger import init_logger from vllm.platforms import current_platform @@ -177,12 +178,20 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, zeros: torch.Tensor, split_k_iters: int, thx: int, thy: int) -> torch.Tensor: + if envs.VLLM_USE_TRITON_AWQ: + from vllm.model_executor.layers.quantization.awq_triton import ( + awq_dequantize_triton) + return awq_dequantize_triton(qweight, scales, zeros) return torch.ops._C.awq_dequantize(qweight, scales, zeros, split_k_iters, thx, thy) def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: + if envs.VLLM_USE_TRITON_AWQ: + from vllm.model_executor.layers.quantization.awq_triton import ( + awq_gemm_triton) + return awq_gemm_triton(input, qweight, qzeros, scales, split_k_iters) return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) @@ -200,12 +209,6 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) -# squeezellm -def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, - lookup_table: torch.Tensor) -> None: - torch.ops._C.squeezellm_gemm(vec, mat, mul, lookup_table) - - # marlin def marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, workspace: torch.Tensor, size_m: int, @@ -491,6 +494,36 @@ def ggml_mul_mat_a8( return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row) +# mamba +def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor, + bias_: Optional[torch.Tensor], + seq_idx_: Optional[torch.Tensor], + initial_states_: Optional[torch.Tensor], + final_states_out_: Optional[torch.Tensor], + silu_activation: bool) -> torch.Tensor: + return torch.ops._C.causal_conv1d_fwd(x, weight, bias_, seq_idx_, + initial_states_, final_states_out_, + silu_activation) + + +def causal_conv1d_update(x: torch.Tensor, conv_state: torch.Tensor, + weight: torch.Tensor, bias_: Optional[torch.Tensor], + silu_activation: bool) -> torch.Tensor: + return torch.ops._C.causal_conv1d_update(x, conv_state, weight, bias_, + silu_activation) + + +def selective_scan_fwd(u: torch.Tensor, delta: torch.Tensor, A: torch.Tensor, + B: torch.Tensor, C: torch.Tensor, + D_: Optional[torch.Tensor], z_: Optional[torch.Tensor], + delta_bias_: Optional[torch.Tensor], + delta_softplus: bool, index_: Optional[torch.Tensor], + x: Optional[torch.Tensor]) -> List[torch.Tensor]: + return torch.ops._C.selective_scan_fwd(u, delta, A, B, C, D_, z_, + delta_bias_, delta_softplus, index_, + x) + + # moe def moe_align_block_size(topk_ids: torch.Tensor, num_experts: int, block_size: int, sorted_token_ids: torch.Tensor, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index a8d76b79ff204..7aec8203eb1e5 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -83,6 +83,15 @@ def copy_blocks( def get_supported_head_sizes() -> List[int]: return [64, 128, 256] + @staticmethod + def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype: + if kv_cache_dtype in ("fp8", "fp8_e4m3"): + return torch.float8_e4m3fn + elif kv_cache_dtype == "fp8_e5m2": + return torch.float8_e5m2 + else: + raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}") + class FlashInferState(AttentionState): @@ -177,8 +186,12 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): self._graph_decode_workspace_buffer, _indptr_buffer, self._graph_indices_buffer, _last_page_len_buffer, "NHD", use_tensor_cores) - kv_cache_dtype = get_kv_cache_torch_dtype( - self.runner.kv_cache_dtype, self.runner.model_config.dtype) + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) paged_kv_indptr_tensor_host = torch.arange(0, batch_size + 1, @@ -211,6 +224,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): query_start_loc=query_start_loc_host, device=self.runner.device, data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, use_cuda_graph=True, decode_wrapper=self._graph_decode_wrapper, prefill_wrapper=None) @@ -279,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata): page_size: Optional[int] = None # The data type of the paged kv cache data_type: torch.dtype = None + # The data type of the query + q_data_type: torch.dtype = None device: torch.device = torch.device("cuda") is_profile_run: bool = False @@ -340,7 +356,10 @@ def begin_forward(self): self.page_size, # Disable flashinfer's pos encoding and use vllm's rope. pos_encoding_mode="NONE", - data_type=self.data_type) + # kv-cache data type. + data_type=self.data_type, + # query data type. + q_data_type=self.q_data_type) def asdict_zerocopy(self, skip_fields: Optional[Set[str]] = None @@ -366,7 +385,8 @@ def prefill_metadata(self) -> Optional["FlashInferMetadata"]: def decode_metadata(self) -> Optional["FlashInferMetadata"]: # Currently chunked prefill is not supported if self.num_prefills > 0: - assert self.num_decode_tokens == 0 + assert self.num_decode_tokens == 0, ( + "Chunked prefill is not supported with flashinfer yet.") return None return self @@ -576,8 +596,13 @@ def build(self, seq_lens: List[int], query_lens: List[int], paged_kv_indptr_tensor = None paged_kv_last_page_len_tensor = None - kv_cache_dtype = get_kv_cache_torch_dtype( - self.runner.kv_cache_dtype, self.runner.model_config.dtype) + if self.runner.kv_cache_dtype.startswith("fp8"): + kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.runner.kv_cache_dtype) + else: + kv_cache_dtype = get_kv_cache_torch_dtype( + self.runner.kv_cache_dtype, self.runner.model_config.dtype) + return FlashInferMetadata( num_prefills=self.num_prefills, slot_mapping=slot_mapping_tensor, @@ -598,6 +623,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], query_start_loc=query_start_loc, device=device, data_type=kv_cache_dtype, + q_data_type=self.runner.model_config.dtype, use_cuda_graph=use_captured_graph, is_profile_run=self.is_profile_run) @@ -661,7 +687,6 @@ def forward( if attn_metadata.num_decode_tokens > 0: assert attn_metadata.num_prefill_tokens == 0, ( "Chunked prefill is not supported with flashinfer yet.") - if kv_cache is not None: # Use the same reshape and cache kernel as flash attention. ops.reshape_and_cache_flash( @@ -674,6 +699,12 @@ def forward( k_scale, v_scale, ) + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if self.kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + self.kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) query = query.contiguous( ) # Flashinfer requires query to be contiguous @@ -711,5 +742,7 @@ def forward( query, kv_cache, sm_scale=self.scale, - logits_soft_cap=self.logits_soft_cap) + logits_soft_cap=self.logits_soft_cap, + k_scale=k_scale, + v_scale=v_scale) return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index ac03b6d8b1ead..83fdef16ef5cb 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -123,7 +123,13 @@ def __init__( raise NotImplementedError("TPU version must be 4 or higher.") self.megacore_mode = None - tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower() + tpu_env = torch_xla.tpu.get_tpu_env() + tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None) + or tpu_env.get("TYPE", None) + or tpu_env.get("TPU_ACCELERATOR_TYPE", None)) + assert tpu_type is not None + tpu_type = tpu_type.lower() + if "lite" not in tpu_type: if self.num_kv_heads % 2 == 0: self.megacore_mode = "kv_head" diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 54558fc2d7e53..855586d4e5961 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -226,6 +226,10 @@ def which_attn_to_use( elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"): logger.info( "Cannot use FlashAttention-2 backend for FP8 KV cache.") + logger.warning( + "Please use FlashInfer backend with FP8 KV Cache for " + "better performance by setting environment variable " + "VLLM_ATTENTION_BACKEND=FLASHINFER") selected_backend = _Backend.XFORMERS elif block_size % 16 != 0: logger.info( diff --git a/vllm/compilation/__init__.py b/vllm/compilation/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py new file mode 100644 index 0000000000000..e923bd36ccc08 --- /dev/null +++ b/vllm/compilation/wrapper.py @@ -0,0 +1,81 @@ +import os +import sys +from abc import abstractmethod +from contextlib import contextmanager +from types import CodeType +from typing import Callable, List + +import torch + +import vllm.envs as envs + + +class TorchCompileWrapperWithCustomDispatcher: + """ + A wrapper class for torch.compile, with a custom dispatch logic. + Subclasses should: + 1. Implement the forward method + 2. Implement the dispatch logic in the __call__ method + It can use `self.compiled_codes` to access the compiled bytecode, + and `with self.dispatch_to_code(index):` to dispatch to + the compiled code. + 3. Implement the `__init__` method to determine how to call + `torch.compile` over the forward method. + """ + + def __init__(self, compiled_callable: Callable): + self.compiled_callable = compiled_callable + self.original_code_object = self.__class__.forward.__code__ + self.compiled_codes: List[CodeType] = [] + torch._dynamo.convert_frame.register_bytecode_hook(self.bytecode_hook) + + # read the env var to determine whether to use the custom dispatcher + # subclasses can use this to switch between the custom dispatcher + # and the default Dynamo guard mechanism. + self.use_custom_dispatcher: bool = \ + envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER + + def __call__(self, *args, **kwargs): + """Implement the dispatch logic here, beyond the torch.compile level. + NOTE: this function can have additional arguments beyond the forward + method, for directly dispatching to the compiled code. + """ + return self.compiled_callable(*args, **kwargs) + + @abstractmethod + def forward(self, *args, **kwargs): + ... + + def bytecode_hook(self, old_code: CodeType, new_code: CodeType): + """Hook to save the compiled bytecode for direct execution.""" + if old_code is not self.original_code_object: + return + # code borrowed from https://github.com/thuml/depyf/blob/f4ad79fadee27ea113b4c75202db1eb1a11c0dbc/depyf/explain/enable_debugging.py#L25 + frame = sys._getframe() + while True: + frame = frame.f_back + code_name = frame.f_code.co_name + file_name = frame.f_code.co_filename.split(os.path.sep)[-1] + if code_name == "_compile" and file_name == "convert_frame.py": + break + frame = frame.f_locals["frame"] + assert frame.f_code == old_code + + if frame.f_locals["self"] is not self: + return + + self.compiled_codes.append(new_code) + + @contextmanager + def dispatch_to_code(self, index: int): + """Context manager to dispatch to the compiled code. + Why does this work? Because Dynamo guarantees that the compiled + bytecode has exactly the same arguments, cell variables, and free + variables as the original code. Therefore we can directly switch + the code object in the function and call it. + + See https://dev-discuss.pytorch.org/t/what-is-the-relationship-requirement-among-original-bytecode-transformed-bytecode-and-bytecode-returned-by-hooks-in-dynamo/1693/7 for more details. + """ # noqa + self.__class__.forward.__code__ = self.compiled_codes[index] + yield + self.__class__.forward.__code__ = self.original_code_object diff --git a/vllm/config.py b/vllm/config.py index 4e014e43d849a..8f5e02e35f28d 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,8 +1,8 @@ import enum import json from dataclasses import dataclass, field, fields -from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple, - Type, Union) +from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping, + Optional, Tuple, Type, Union) import torch from transformers import PretrainedConfig @@ -13,7 +13,7 @@ from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback -from vllm.transformers_utils.config import (get_config, +from vllm.transformers_utils.config import (ConfigFormat, get_config, get_hf_image_processor_config, get_hf_text_config) from vllm.utils import (STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, GiB_bytes, @@ -32,20 +32,23 @@ logger = init_logger(__name__) _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 +_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 _PP_SUPPORTED_MODELS = [ - "AquilaModel", "AquilaForCausalLM", + "AquilaModel", "DeepseekV2ForCausalLM", + "GPT2LMHeadModel", + "InternLM2ForCausalLM", "InternLMForCausalLM", + "InternVLChatModel", "JAISLMHeadModel", "LlamaForCausalLM", "LLaMAForCausalLM", "MistralForCausalLM", - "Phi3ForCausalLM", - "GPT2LMHeadModel", "MixtralForCausalLM", "NemotronForCausalLM", + "Phi3ForCausalLM", "Qwen2ForCausalLM", "Qwen2MoeForCausalLM", "QWenLMHeadModel", @@ -114,35 +117,41 @@ class ModelConfig: the model name will be the same as `model`. limit_mm_per_prompt: Maximum number of data instances per modality per prompt. Only applicable for multimodal models. + override_neuron_config: Initialize non default neuron config or + override default neuron config that are specific to Neuron devices, + this argument will be used to configure the neuron config that + can not be gathered from the vllm arguments. + config_format: The config format which shall be loaded. + Defaults to 'auto' which defaults to 'hf'. """ - def __init__( - self, - model: str, - tokenizer: str, - tokenizer_mode: str, - trust_remote_code: bool, - dtype: Union[str, torch.dtype], - seed: int, - revision: Optional[str] = None, - code_revision: Optional[str] = None, - rope_scaling: Optional[dict] = None, - rope_theta: Optional[float] = None, - tokenizer_revision: Optional[str] = None, - max_model_len: Optional[int] = None, - spec_target_max_model_len: Optional[int] = None, - quantization: Optional[str] = None, - quantization_param_path: Optional[str] = None, - enforce_eager: Optional[bool] = None, - max_context_len_to_capture: Optional[int] = None, - max_seq_len_to_capture: Optional[int] = None, - max_logprobs: int = 20, - disable_sliding_window: bool = False, - skip_tokenizer_init: bool = False, - served_model_name: Optional[Union[str, List[str]]] = None, - limit_mm_per_prompt: Optional[Mapping[str, int]] = None, - use_async_output_proc: bool = True, - ) -> None: + def __init__(self, + model: str, + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + spec_target_max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: Optional[bool] = None, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 20, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, + override_neuron_config: Optional[Dict[str, Any]] = None, + config_format: ConfigFormat = ConfigFormat.AUTO) -> None: self.model = model self.tokenizer = tokenizer self.tokenizer_mode = tokenizer_mode @@ -169,7 +178,8 @@ def __init__( self.skip_tokenizer_init = skip_tokenizer_init self.hf_config = get_config(self.model, trust_remote_code, revision, - code_revision, rope_scaling, rope_theta) + code_revision, rope_scaling, rope_theta, + config_format) self.hf_text_config = get_hf_text_config(self.hf_config) self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) @@ -226,6 +236,9 @@ def __init__( limit_mm_per_prompt) if not self.skip_tokenizer_init: self._verify_tokenizer_mode() + + self.override_neuron_config = override_neuron_config if is_neuron( + ) else None self._verify_embedding_mode() self._verify_quantization() self._verify_cuda_graph() @@ -267,13 +280,14 @@ def _parse_quant_hf_config(self): def _verify_quantization(self) -> None: supported_quantization = [*QUANTIZATION_METHODS] - rocm_supported_quantization = ["gptq", "squeezellm", "fp8"] + rocm_supported_quantization = ["awq", "gptq", "fp8"] optimized_quantization_methods = [ "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", "fbgemm_fp8", "compressed_tensors", "compressed-tensors", "experts_int8" ] tpu_supported_quantization = ["tpu_int8"] + neuron_supported_quantization = ["neuron_quant"] if self.quantization is not None: self.quantization = self.quantization.lower() @@ -322,6 +336,17 @@ def _verify_quantization(self) -> None: "%s quantization is not fully " "optimized yet. The speed can be slower than " "non-quantized models.", self.quantization) + if (self.quantization == "awq" and is_hip() + and not envs.VLLM_USE_TRITON_AWQ): + logger.warning( + "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ" + " is not set, enabling VLLM_USE_TRITON_AWQ.") + envs.VLLM_USE_TRITON_AWQ = True + if is_neuron( + ) and self.quantization not in neuron_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not " + f"supported in Neuron Backend.") def _verify_cuda_graph(self) -> None: if self.max_seq_len_to_capture is None: @@ -341,10 +366,10 @@ def verify_async_output_proc(self, parallel_config, speculative_config, self.use_async_output_proc = False return - if device_config.device_type != "cuda": + if device_config.device_type not in ("cuda", "tpu"): logger.warning( - "Async output processing is only supported for CUDA." - " Disabling it for other platforms.") + "Async output processing is only supported for CUDA or TPU. " + "Disabling it for other platforms.") self.use_async_output_proc = False return @@ -399,6 +424,8 @@ def verify_with_parallel_config( raise ValueError( "BitAndBytes quantization with TP or PP is not supported yet.") + # Remove the constraint after the bitsandbytes issue is fixed: + # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308 if self.quantization == "bitsandbytes" and self.enforce_eager is False: logger.warning("CUDA graph is not supported on BitAndBytes yet, " "fallback to the eager mode.") @@ -563,6 +590,10 @@ def is_embedding_model(self) -> bool: """Extract the embedding model flag.""" return self.embedding_mode + @property + def is_multimodal_model(self) -> bool: + return self.multimodal_config is not None + class CacheConfig: """Configuration for the KV cache. @@ -718,6 +749,7 @@ class LoadFormat(str, enum.Enum): SHARDED_STATE = "sharded_state" GGUF = "gguf" BITSANDBYTES = "bitsandbytes" + MISTRAL = "mistral" @dataclass @@ -939,25 +971,36 @@ def __init__(self, num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, - embedding_mode: Optional[bool] = False, + embedding_mode: bool = False, + is_multimodal_model: bool = False, preemption_mode: Optional[str] = None, num_scheduler_steps: int = 1, send_delta_data: bool = False) -> None: - if max_num_batched_tokens is not None: - self.max_num_batched_tokens = max_num_batched_tokens - else: + if max_num_batched_tokens is None: if enable_chunked_prefill: # It is the values that have the best balance between ITL # and TTFT on A100. Note it is not optimized for throughput. - self.max_num_batched_tokens = 512 - elif embedding_mode: - # For embedding, choose specific value for higher throughput - self.max_num_batched_tokens = max( - max_model_len, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS) + max_num_batched_tokens = 512 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. - self.max_num_batched_tokens = max(max_model_len, 2048) + max_num_batched_tokens = max(max_model_len, 2048) + + if embedding_mode: + # For embedding, choose specific value for higher throughput + max_num_batched_tokens = max( + max_num_batched_tokens, + _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + if is_multimodal_model: + # The value needs to be at least the number of multimodal tokens + max_num_batched_tokens = max( + max_num_batched_tokens, + _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, + ) + + self.max_num_batched_tokens = max_num_batched_tokens + if enable_chunked_prefill: logger.info( "Chunked prefill is enabled with max_num_batched_tokens=%d.", @@ -1498,7 +1541,7 @@ def verify_with_model_config(self, model_config: ModelConfig): if model_config.quantization and model_config.quantization not in [ "awq", "gptq" ]: - # TODO support marlin and squeezellm + # TODO support marlin logger.warning("%s quantization is not tested with LoRA yet.", model_config.quantization) diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 666723313c829..24ab9eb66194d 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -681,14 +681,20 @@ def access_all_blocks_in_seq( for block in block_table: block.last_accessed = access_time - def compute_full_blocks_in_seq(self, seq: Sequence): + def compute_full_blocks_in_seq(self, seq: Sequence, token_chunk_size: int): if seq.seq_id not in self.block_tables: return - max_full_block = seq.get_len() // self.block_size - 1 + + # When chunked prefill is enabled, the computed full blocks + # should be calculated based on the number of computed tokens. + max_computed_tokens = (seq.data.get_num_computed_tokens() + + token_chunk_size) + computed_full_blocks = max_computed_tokens // self.block_size + block_table = self.block_tables[seq.seq_id] - if max_full_block == -1: + if computed_full_blocks == 0: return - for i in reversed(range(max_full_block)): + for i in reversed(range(computed_full_blocks)): if block_table[i].computed: break block_table[i].computed = True @@ -718,10 +724,11 @@ def get_common_computed_block_ids( ids_list = [self.get_all_computed_blocks(seq) for seq in seqs] return commonprefix([ids for ids in ids_list if ids != []]) - def mark_blocks_as_computed(self, seq_group: SequenceGroup): + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): if self.enable_caching: for seq in seq_group.get_seqs(): - self.compute_full_blocks_in_seq(seq) + self.compute_full_blocks_in_seq(seq, token_chunk_size) def get_prefix_cache_hit_rate(self, device: Device) -> float: if device == Device.GPU: diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index 7d2db43cb4602..b06385b062e83 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -290,7 +290,8 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float): self._last_access_blocks_tracker.update_last_access( seq.seq_id, now) - def mark_blocks_as_computed(self, seq_group: SequenceGroup): + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): # If prefix caching is enabled, mark immutable blocks as computed # right after they have been scheduled (for prefill). This assumes # the scheduler is synchronous so blocks are actually computed when diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index f16f66e99e7f8..c47d7d8dfb075 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -80,7 +80,8 @@ def get_common_computed_block_ids(self, seq_group: List[Sequence]) -> List[int]: return [] - def mark_blocks_as_computed(self, seq_group: SequenceGroup): + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): pass def get_prefix_cache_hit_rate(self, device: Device) -> float: diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index becd0d2e7f849..96f8dd851b2f4 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -115,7 +115,8 @@ def get_common_computed_block_ids( pass @abstractmethod - def mark_blocks_as_computed(self, seq_group: SequenceGroup): + def mark_blocks_as_computed(self, seq_group: SequenceGroup, + token_chunk_size: int): pass @abstractmethod diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f0ebb49c9eae9..5c191ef60b785 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -537,13 +537,6 @@ def _schedule_running( preempted: List[SequenceGroup] = ret.preempted swapped_out: List[SequenceGroup] = ret.swapped_out - # NOTE(woosuk): Preemption happens only when there is no available slot - # to keep all the sequence groups in the RUNNING state. - - # Store original running requests for the case of async + preemption - if self.use_async_output_proc: - orig_running = self.running.copy() - running_queue = self.running assert len(self._async_stopped) == 0 while running_queue: @@ -552,6 +545,7 @@ def _schedule_running( seq_group, SequenceStatus.RUNNING, enable_chunking, budget) if num_running_tokens == 0: + # No budget => Stop break running_queue.popleft() @@ -565,19 +559,11 @@ def _schedule_running( self._async_stopped.append(seq_group) continue - # With async postprocessor, when preemption kicks in, we need - # first to drain the async postprocessor, so that all async - # block_table freeing is applied before the preemption freeing - # is applied. - if self.use_async_output_proc and not self._can_append_slots( - seq_group): - tmp = self.running - self.running = orig_running - assert self.output_proc_callback is not None - self.output_proc_callback() - self.running = tmp - - while not True: #TODO #self._can_append_slots(seq_group): + # NOTE(woosuk): Preemption happens only when there is no available + # slot to keep all the sequence groups in the RUNNING state. + # FIXME(hyunjun): vLLM does not management LPU resources. + # But it should be modified to support GPU, too. + while not True: #self._can_append_slots(seq_group): budget.subtract_num_batched_tokens(seq_group.request_id, num_running_tokens) num_running_seqs = seq_group.get_max_num_running_seqs() @@ -588,26 +574,46 @@ def _schedule_running( and seq_group.lora_int_id in curr_loras): curr_loras.remove(seq_group.lora_int_id) + # Determine victim sequence + cont_loop = True if running_queue: - # Preempt the lowest-priority sequence groups. + # Preempt the lowest-priority sequence group. victim_seq_group = running_queue.pop() + else: + # No other sequence group can be preempted. + # Preempt the current sequence group. + # Note: This is also where we stop this loop + # (since there is nothing else to preempt) + victim_seq_group = seq_group + cont_loop = False + + # With async postprocessor, before preempting a sequence + # we need to ensure it has no pending async postprocessor + do_preempt = True + if self.use_async_output_proc: + assert self.output_proc_callback is not None + self.output_proc_callback( + request_id=victim_seq_group.request_id) + + # It may be that the async pending "victim_seq_group" + # becomes finished, in which case we simply free it. + if victim_seq_group.is_finished(): + self._free_finished_seq_group(victim_seq_group) + do_preempt = False + + # Do preemption + if do_preempt: preempted_mode = self._preempt(victim_seq_group, blocks_to_swap_out) if preempted_mode == PreemptionMode.RECOMPUTE: preempted.append(victim_seq_group) else: swapped_out.append(victim_seq_group) - else: - # No other sequence groups can be preempted. - # Preempt the current sequence group. - preempted_mode = self._preempt(seq_group, - blocks_to_swap_out) - if preempted_mode == PreemptionMode.RECOMPUTE: - preempted.append(seq_group) - else: - swapped_out.append(seq_group) + + if not cont_loop: break else: + # FIXME(hyunjun): vLLM does not management LPU resources. #self._append_slots(seq_group, blocks_to_copy) is_prefill = seq_group.is_prefill() scheduled_seq_group: ScheduledSequenceGroup = \ @@ -1026,16 +1032,21 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: # Update waiting requests. self.waiting.extendleft(running_scheduled.preempted) + # Update new running requests. - self.running.extend([s.seq_group for s in prefills.seq_groups]) - self.running.extend( - [s.seq_group for s in running_scheduled.decode_seq_groups]) - self.running.extend( - [s.seq_group for s in running_scheduled.prefill_seq_groups]) + # By default, vLLM scheduler prioritizes prefills. + # Once chunked prefill is enabled, + # the policy is changed to prioritize decode requests. self.running.extend( [s.seq_group for s in swapped_in.decode_seq_groups]) self.running.extend( [s.seq_group for s in swapped_in.prefill_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.decode_seq_groups]) + self.running.extend( + [s.seq_group for s in running_scheduled.prefill_seq_groups]) + self.running.extend([s.seq_group for s in prefills.seq_groups]) + # Update swapped requests. self.swapped.extend(running_scheduled.swapped_out) return SchedulerOutputs( @@ -1106,10 +1117,7 @@ def schedule( if not self.cache_config.enable_prefix_caching: common_computed_block_nums = [] - # TODO: Combine multi-step and async postprocessor - allow_async_output_proc: bool = ( - self.use_async_output_proc - and not self.scheduler_config.is_multi_step) + allow_async_output_proc: bool = self.use_async_output_proc # Create input data structures. seq_group_metadata_list: List[SequenceGroupMetadata] = [] @@ -1225,7 +1233,8 @@ def schedule( # will crash the vLLM instance / will not retry. for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: self.block_manager.mark_blocks_as_computed( - scheduled_seq_group.seq_group) + scheduled_seq_group.seq_group, + scheduled_seq_group.token_chunk_size) self._seq_group_metadata_cache[self.next_cache_id].reset() @@ -1260,22 +1269,26 @@ def _free_finished_seqs(self, seq_group: SequenceGroup) -> None: if seq.is_finished(): self.free_seq(seq) + def _free_finished_seq_group(self, seq_group: SequenceGroup) -> None: + if seq_group.is_finished(): + # Free cross-attention block table, if it exists + self._free_seq_group_cross_attn_blocks(seq_group) + + # Add the finished requests to the finished requests list. + # This list will be used to update the Mamba cache in the + # next step. + self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + def free_finished_seq_groups(self) -> None: remaining: Deque[SequenceGroup] = deque() for seq_group in self.running: - if seq_group.is_finished(): - # Free cross-attention block table, if it exists - self._free_seq_group_cross_attn_blocks(seq_group) - # Add the finished requests to the finished requests list. - # This list will be used to update the Mamba cache in the - # next step. - self._finished_requests_ids.append(seq_group.request_id) - else: + self._free_finished_seq_group(seq_group) + if not seq_group.is_finished(): remaining.append(seq_group) - # Free finished seqs - self._free_finished_seqs(seq_group) - self.running = remaining # Handle async stopped sequence groups @@ -1456,10 +1469,27 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, for seq in seqs: num_new_tokens += seq.get_num_new_tokens() assert num_new_tokens > 0 - # Chunk if a running request cannot fit in. - # If number of seq > 1, it means it is doing beam search in a - # decode phase. Do not chunk in that case. + # Chunk if a running request cannot fit in the given budget. + # If number of seq > 1, it means it is doing beam search + # in a decode phase. Do not chunk. if enable_chunking and len(seqs) == 1: - num_new_tokens = min(num_new_tokens, - budget.remaining_token_budget()) + remaining_token_budget = budget.remaining_token_budget() + if self.cache_config.enable_prefix_caching: + # When prefix caching is enabled, we always allocate + # the number of new tokens that is dividable by the block size + # to avoid partial block matching. + block_size = self.cache_config.block_size + reminder = budget.token_budget % block_size + if reminder != 0: + raise ValueError("When enabling chunked prefill and " + "prefix caching, max_num_batched_tokens " + "(chunk size) must be dividable by " + "block size, but got chunk_size " + f"({budget.token_budget}) % block_size " + f"({block_size}) = {reminder}") + if remaining_token_budget < num_new_tokens: + num_new_tokens = (remaining_token_budget // + block_size) * block_size + else: + num_new_tokens = min(num_new_tokens, remaining_token_budget) return num_new_tokens diff --git a/vllm/distributed/device_communicators/tpu_communicator.py b/vllm/distributed/device_communicators/tpu_communicator.py index 81a141e86206a..765a0f9cb1c87 100644 --- a/vllm/distributed/device_communicators/tpu_communicator.py +++ b/vllm/distributed/device_communicators/tpu_communicator.py @@ -1,3 +1,5 @@ +import os + import torch import torch.distributed as dist from torch.distributed import ProcessGroup @@ -5,11 +7,12 @@ from vllm.platforms import current_platform if current_platform.is_tpu(): - import ray import torch_xla.core.xla_model as xm import torch_xla.runtime as xr from torch_xla._internal import pjrt + from vllm.executor import ray_utils + class TpuCommunicator: @@ -24,9 +27,29 @@ def __init__(self, group: ProcessGroup): # be simply calculated as follows. global_rank = dist.get_rank(group) global_world_size = dist.get_world_size(group) - num_nodes = len(ray.nodes()) + + # Calculate how many TPU nodes are in the current deployment. This + # is the Ray placement group if it is deployed with Ray. Default + # to the number of TPU nodes in the Ray cluster. The number of TPU + # nodes is computed by the total number of TPUs divided by the + # number of TPU accelerators per node, to account for clusters + # with both CPUs and TPUs. + num_nodes = ray_utils.get_num_tpu_nodes() + num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group() + if num_nodes_in_pg > 0: + num_nodes = num_nodes_in_pg + local_world_size = global_world_size // num_nodes local_rank = global_rank % local_world_size + + # Ensure environment variables are set for multihost deployments. + # On GKE, this is needed for libtpu and TPU driver to know which TPU + # chip is actually visible. Otherwise the TPU driver will fail to + # initialize because the number of devices would be different from + # the number of visible worker addresses. + os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank) + os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank) + pjrt.initialize_multiprocess(local_rank, local_world_size) xr._init_world_size_ordinal() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b4c25ea0abf66..1c5f3c120857d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -2,20 +2,21 @@ import dataclasses import json from dataclasses import dataclass -from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Tuple, Type, - Union) +from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, + Type, Union) import torch import vllm.envs as envs -from vllm.config import (CacheConfig, DecodingConfig, DeviceConfig, - EngineConfig, LoadConfig, LoadFormat, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, +from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig, + DeviceConfig, EngineConfig, LoadConfig, LoadFormat, + LoRAConfig, ModelConfig, ObservabilityConfig, + ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig, TokenizerPoolConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import FlexibleArgumentParser if TYPE_CHECKING: @@ -64,6 +65,7 @@ class EngineArgs: trust_remote_code: bool = False download_dir: Optional[str] = None load_format: str = 'auto' + config_format: str = 'auto' dtype: str = 'auto' kv_cache_dtype: str = 'auto' quantization_param_path: Optional[str] = None @@ -148,6 +150,7 @@ class EngineArgs: otlp_traces_endpoint: Optional[str] = None collect_detailed_traces: Optional[str] = None disable_async_output_proc: bool = False + override_neuron_config: Optional[Dict[str, Any]] = None def __post_init__(self): if self.tokenizer is None: @@ -232,6 +235,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'section for more information.\n' '* "bitsandbytes" will load the weights using bitsandbytes ' 'quantization.\n') + parser.add_argument( + '--config-format', + default=EngineArgs.config_format, + choices=[f.value for f in ConfigFormat], + help='The format of the model config to load.\n\n' + '* "auto" will try to load the config in hf format ' + 'if available else it will try to load in mistral format ') parser.add_argument( '--dtype', type=str, @@ -741,6 +751,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.disable_async_output_proc, help="Disable async output processing. This may result in " "lower performance.") + parser.add_argument( + '--override-neuron-config', + type=lambda configs: { + str(key): value + for key, value in + (config.split(':') for config in configs.split(',')) + }, + default=None, + help="override or set neuron device configuration.") + return parser @classmethod @@ -751,9 +771,46 @@ def from_cli_args(cls, args: argparse.Namespace): engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) return engine_args + def create_model_config(self) -> ModelConfig: + return ModelConfig( + model=self.model, + tokenizer=self.tokenizer, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + quantization=self.quantization, + quantization_param_path=self.quantization_param_path, + enforce_eager=self.enforce_eager, + max_context_len_to_capture=self.max_context_len_to_capture, + max_seq_len_to_capture=self.max_seq_len_to_capture, + max_logprobs=self.max_logprobs, + disable_sliding_window=self.disable_sliding_window, + skip_tokenizer_init=self.skip_tokenizer_init, + served_model_name=self.served_model_name, + limit_mm_per_prompt=self.limit_mm_per_prompt, + use_async_output_proc=not self.disable_async_output_proc, + override_neuron_config=self.override_neuron_config, + config_format=self.config_format, + ) + + def create_load_config(self) -> LoadConfig: + return LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + ) + def create_engine_config(self) -> EngineConfig: # gguf file needs a specific model loader and doesn't use hf_repo - if self.model.endswith(".gguf"): + if check_gguf_file(self.model): self.quantization = self.load_format = "gguf" # bitsandbytes quantization needs a specific model loader @@ -777,31 +834,8 @@ def create_engine_config(self) -> EngineConfig: f", but got {self.cpu_offload_gb}") device_config = DeviceConfig(device=self.device) - model_config = ModelConfig( - model=self.model, - tokenizer=self.tokenizer, - tokenizer_mode=self.tokenizer_mode, - trust_remote_code=self.trust_remote_code, - dtype=self.dtype, - seed=self.seed, - revision=self.revision, - code_revision=self.code_revision, - rope_scaling=self.rope_scaling, - rope_theta=self.rope_theta, - tokenizer_revision=self.tokenizer_revision, - max_model_len=self.max_model_len, - quantization=self.quantization, - quantization_param_path=self.quantization_param_path, - enforce_eager=self.enforce_eager, - max_context_len_to_capture=self.max_context_len_to_capture, - max_seq_len_to_capture=self.max_seq_len_to_capture, - max_logprobs=self.max_logprobs, - disable_sliding_window=self.disable_sliding_window, - skip_tokenizer_init=self.skip_tokenizer_init, - served_model_name=self.served_model_name, - limit_mm_per_prompt=self.limit_mm_per_prompt, - use_async_output_proc=not self.disable_async_output_proc, - ) + model_config = self.create_model_config() + cache_config = CacheConfig( block_size=self.block_size if self.device != "neuron" else self.max_model_len, # neuron needs block_size = max_model_len @@ -921,6 +955,7 @@ def create_engine_config(self) -> EngineConfig: delay_factor=self.scheduler_delay_factor, enable_chunked_prefill=self.enable_chunked_prefill, embedding_mode=model_config.embedding_mode, + is_multimodal_model=model_config.is_multimodal_model, preemption_mode=self.preemption_mode, num_scheduler_steps=self.num_scheduler_steps, send_delta_data=(envs.VLLM_USE_RAY_SPMD_WORKER @@ -943,12 +978,7 @@ def create_engine_config(self) -> EngineConfig: self.model_loader_extra_config[ "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path - load_config = LoadConfig( - load_format=self.load_format, - download_dir=self.download_dir, - model_loader_extra_config=self.model_loader_extra_config, - ignore_patterns=self.ignore_patterns, - ) + load_config = self.create_load_config() prompt_adapter_config = PromptAdapterConfig( max_prompt_adapters=self.max_prompt_adapters, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 0946234ba262b..99afc58af3550 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -22,11 +22,12 @@ from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger, print_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import print_warning_once @@ -281,24 +282,25 @@ async def step_async( ctx = self.scheduler_contexts[virtual_engine] + # Clear outputs for each new scheduler iteration + ctx.request_outputs.clear() + # skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. if not self._has_remaining_steps(seq_group_metadata_list): - # Clear outputs on scheduler iteration start - ctx.request_outputs.clear() - + # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() - # If current scheduler iteration has no async postprocessor, - # then we need first to drain the pending async postprocessor - # before moving forward + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs + + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=True) + self._process_model_outputs(ctx=ctx) if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): @@ -311,9 +313,6 @@ async def step_async( assert seq_group_metadata_list is not None assert scheduler_outputs is not None - assert not (self.scheduler_config.is_multi_step and \ - allow_async_output_proc) - if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -339,22 +338,21 @@ async def step_async( last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.async_callback = self.async_callback[ + execute_model_req.async_callback = self.async_callbacks[ virtual_engine] # Execute the model. - output = await self.model_executor.execute_model_async( + outputs = await self.model_executor.execute_model_async( execute_model_req) + # we need to do this here so that last step's sampled_token_ids can # be passed to the next iteration for PP. if self.scheduler_config.is_multi_step: - self._update_cached_scheduler_output(virtual_engine, output) + self._update_cached_scheduler_output(virtual_engine, outputs) else: if len(ctx.output_queue) > 0: - assert not self.scheduler_config.is_multi_step - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=True) - output = [] + self._process_model_outputs(ctx=ctx) + outputs = [] # Finish the current step for all the sequence groups. if self.scheduler_config.is_multi_step: @@ -362,42 +360,42 @@ async def step_async( seq_group.finish_step() if not self._has_remaining_steps(seq_group_metadata_list): - # clear the cache if we have finished all the steps + # Clear the cache if we have finished all the steps if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - # Cache results in engine - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True) - if output and allow_async_output_proc: + if outputs and allow_async_output_proc: assert len( - output - ) == 1, "Multi step decoding does not work with async output processing." # noqa: E501 + outputs + ) == 1, "Async postprocessor expects only a single output set" self._advance_to_next_step( - output[0], seq_group_metadata_list, + outputs[0], seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups) if not allow_async_output_proc: - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=False) + self._process_model_outputs(ctx=ctx) # Log stats. - self.do_log_stats(scheduler_outputs, output) + self.do_log_stats(scheduler_outputs, outputs) # Tracing self.do_tracing(scheduler_outputs) else: - ctx.request_outputs = [] + # Multi-step case + return ctx.request_outputs if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) if len(ctx.output_queue) > 0: - assert not self.scheduler_config.is_multi_step - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=True) + self._process_model_outputs(ctx=ctx) assert len(ctx.output_queue) == 0 return ctx.request_outputs @@ -616,6 +614,17 @@ def __init__(self, self.log_requests = log_requests self.engine = self._init_engine(*args, **kwargs) + # This ensures quick processing of request outputs + # so the append to asyncio queues is not delayed, + # especially for multi-step. + # + # TODO: Currently, disabled for engine_use_ray, ask + # Cody/Will/Woosuk about this case. + self.use_process_request_outputs_callback = not self.engine_use_ray + if self.use_process_request_outputs_callback: + self.engine.process_request_outputs_callback = \ + self.process_request_outputs + if self.engine_use_ray: print_warning_once( "DEPRECATED. `--engine-use-ray` is deprecated and will " @@ -862,13 +871,27 @@ async def engine_step(self, virtual_engine: int) -> bool: request_outputs = await self.engine.step_async(virtual_engine) # Put the outputs into the corresponding streams. - finished = True + # If used as a callback, then already invoked inside + # LLMEngine's _process_model_outputs + if not self.use_process_request_outputs_callback: + all_finished = self.process_request_outputs(request_outputs) + else: + # For callback case, we only need to detect when all + # requests are finished + all_finished = all(request_output.finished + for request_output in request_outputs) + + return not all_finished + + def process_request_outputs(self, request_outputs) -> bool: + # Put the outputs into the corresponding streams. + all_finished = True for request_output in request_outputs: self._request_tracker.process_request_output( request_output, verbose=self.log_requests) - finished = finished and request_output.finished + all_finished = all_finished and request_output.finished - return not finished + return all_finished async def _engine_abort(self, request_ids: Iterable[str]): if self.engine_use_ray: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a8c572495d33b..db9b5a18cc507 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2,9 +2,9 @@ import time from collections import deque from contextlib import contextmanager -from dataclasses import dataclass, field +from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List, - Mapping, Optional) + Mapping, NamedTuple, Optional) from typing import Sequence as GenericSequence from typing import Set, Tuple, Type, Union @@ -33,6 +33,7 @@ from vllm.inputs.parse import is_explicit_encoder_decoder_prompt from vllm.logger import init_logger, print_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalDataDict from vllm.outputs import (EmbeddingRequestOutput, RequestOutput, RequestOutputFactory) @@ -40,8 +41,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - SamplerOutput, Sequence, SequenceGroup, - SequenceGroupMetadata, SequenceStatus) + Sequence, SequenceGroup, SequenceGroupMetadata, + SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -89,15 +90,36 @@ class SchedulerOutputState: last_output: Optional[SamplerOutput] = None -@dataclass +class OutputData(NamedTuple): + outputs: List[SamplerOutput] + seq_group_metadata_list: List[SequenceGroupMetadata] + scheduler_outputs: SchedulerOutputs + is_async: bool + is_last_step: bool + skip: List[int] + + class SchedulerContext: - output_queue: Deque[Tuple[List[SamplerOutput], List[SequenceGroupMetadata], - SchedulerOutputs]] = field( - default_factory=lambda: deque()) - request_outputs: List[Union[RequestOutput, - EmbeddingRequestOutput]] = field( - default_factory=lambda: []) + def __init__(self): + self.output_queue: Deque[OutputData] = deque() + self.request_outputs: List[Union[RequestOutput, + EmbeddingRequestOutput]] = [] + self.seq_group_metadata_list: Optional[ + List[SequenceGroupMetadata]] = None + self.scheduler_outputs: Optional[SchedulerOutputs] = None + + def append_output(self, outputs: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduler_outputs: SchedulerOutputs, is_async: bool, + is_last_step: bool): + self.output_queue.append( + OutputData(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=is_async, + is_last_step=is_last_step, + skip=[])) class LLMEngine: @@ -211,6 +233,7 @@ def __init__( "Initializing an LLM engine (v%s) with config: " "model=%r, speculative_config=%r, tokenizer=%r, " "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, " "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " @@ -229,6 +252,7 @@ def __init__( model_config.skip_tokenizer_init, model_config.tokenizer_mode, model_config.revision, + model_config.override_neuron_config, model_config.rope_scaling, model_config.rope_theta, model_config.tokenizer_revision, @@ -355,6 +379,26 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # different process. self.tokenizer.ping() + self.cached_scheduler_outputs = [ + SchedulerOutputState() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + self.scheduler_contexts = [ + SchedulerContext() + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + + self.async_callbacks = [ + functools.partial(self._process_model_outputs, + ctx=self.scheduler_contexts[v_id]) + for v_id in range(self.parallel_config.pipeline_parallel_size) + ] + + # Currently used by AsyncLLMEngine to ensure quick append + # of request outputs to asyncio queues + self.process_request_outputs_callback = None + # Create the scheduler. # NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. @@ -362,9 +406,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: Scheduler( scheduler_config, cache_config, lora_config, parallel_config.pipeline_parallel_size, - functools.partial(self._process_model_outputs, - virtual_engine=v_id, - is_async=True) + self.async_callbacks[v_id] if model_config.use_async_output_proc else None) for v_id in range(parallel_config.pipeline_parallel_size) ] @@ -415,23 +457,6 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: ), )) - self.cached_scheduler_outputs = [ - SchedulerOutputState() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - self.scheduler_contexts = [ - SchedulerContext() - for _ in range(self.parallel_config.pipeline_parallel_size) - ] - - self.async_callback = [ - functools.partial(self._process_model_outputs, - virtual_engine=v_id, - is_async=True) - for v_id in range(self.parallel_config.pipeline_parallel_size) - ] - def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -1243,28 +1268,30 @@ def _process_sequence_group_outputs( return - def _process_model_outputs(self, virtual_engine: int, - is_async: bool) -> None: - """Apply the model output to the sequences in the scheduled seq groups. + def _process_model_outputs(self, + ctx: SchedulerContext, + request_id: Optional[str] = None) -> None: + """Apply the model output to the sequences in the scheduled seq groups + and return responses. - virtual_engine: The engine id to operate on - is_async: Indicates whether this postprocessor runs in - parallel with the GPU forward pass and is processing - tokens from the previous step. If this is true, then - no tokens need to be appended since it is already done - externally (before the next schedule() call) + ctx: The virtual engine context to work on + request_id: If provided, then only this request is going to be processed - Returns RequestOutputs that can be returned to the client. """ now = time.time() - ctx: SchedulerContext = self.scheduler_contexts[virtual_engine] - if len(ctx.output_queue) == 0: return None - (outputs, seq_group_metadata_list, - scheduler_outputs) = ctx.output_queue.popleft() + # Get pending async postprocessor + if request_id: + # When we process only one request, no pop is required + # (since later we will process all of the rest) + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, skip) = ctx.output_queue[0] + else: + (outputs, seq_group_metadata_list, scheduler_outputs, is_async, + is_last_step, skip) = ctx.output_queue.popleft() # Sanity check assert len(seq_group_metadata_list) == len( @@ -1278,8 +1305,30 @@ def _process_model_outputs(self, virtual_engine: int, else: outputs_by_sequence_group = outputs + # Determine the requests we need to operate on + if request_id: + indices = [] + for i, seq_group_meta in enumerate(seq_group_metadata_list): + if seq_group_meta.request_id == request_id: + assert i not in skip # Cannot be called twice + indices.append(i) + break + + # If the request_id was not found, then it means that + # this is a new request that has no pending async + # postprocessor + if not indices: + return + else: + indices = range(len(seq_group_metadata_list)) # type: ignore + finished_before: List[int] = [] - for i, seq_group_meta in enumerate(seq_group_metadata_list): + finished_now: List[int] = [] + for i in indices: + if i in skip: + continue + + seq_group_meta = seq_group_metadata_list[i] scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] seq_group = scheduled_seq_group.seq_group @@ -1316,24 +1365,57 @@ def _process_model_outputs(self, virtual_engine: int, if self.model_config.embedding_mode: self._process_sequence_group_outputs(seq_group, output) - continue - - self.output_processor.process_prompt_logprob(seq_group, output) - if seq_group_meta.do_sample: - self.output_processor.process_outputs(seq_group, output, - is_async) + else: + self.output_processor.process_prompt_logprob(seq_group, output) + if seq_group_meta.do_sample: + self.output_processor.process_outputs( + seq_group, output, is_async) - # Free the finished sequence groups. - for scheduler in self.scheduler: - scheduler.free_finished_seq_groups() + if seq_group.is_finished(): + finished_now.append(i) - # Create the outputs. - for i, _ in enumerate(seq_group_metadata_list): + # Generate outputs for the requests that finished this iteration + for i in finished_now: scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - if i in finished_before: + seq_group = scheduled_seq_group.seq_group + seq_group.maybe_set_first_token_time(now) + request_output = RequestOutputFactory.create(seq_group) + ctx.request_outputs.append(request_output) + + # When we process a single request, we skip it for the next time, + # and invoke the request output callback (if there was final output) + if request_id: + assert len(indices) == 1 + skip.append(indices[0]) + + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + # Free currently finished requests + if finished_now: + for scheduler in self.scheduler: + scheduler.free_finished_seq_groups() + + # For multi-step, do not create outputs each iteration + if not is_last_step: + # Immediately process request outputs here (if callback is given) + if (finished_now + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + return + + # Create the outputs + for i in indices: + if i in skip or i in finished_before or i in finished_now: continue # Avoids double processing + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) if (seq_group.is_finished() @@ -1345,6 +1427,15 @@ def _process_model_outputs(self, virtual_engine: int, request_output = RequestOutputFactory.create(seq_group) ctx.request_outputs.append(request_output) + # Immediately process request outputs here (if callback is given) + if (ctx.request_outputs + and self.process_request_outputs_callback is not None): + self.process_request_outputs_callback(ctx.request_outputs) + ctx.request_outputs.clear() + + # For async case, we need to record the stats here. + # For non-async case, the stats are done in the + # LLMEngine/AsyncLLMEngine directly if is_async: # Log stats. self.do_log_stats(scheduler_outputs, outputs, finished_before) @@ -1440,7 +1531,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: "as performance will be severely degraded otherwise.") # For llm_engine, there is no pipeline parallel support, so the engine - # used is always 0 + # used is always 0. virtual_engine = 0 # These are cached outputs from previous iterations. None if on first @@ -1452,23 +1543,24 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: ctx = self.scheduler_contexts[virtual_engine] + # Clear outputs for each new scheduler iteration + ctx.request_outputs.clear() + # Skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. if not self._has_remaining_steps(seq_group_metadata_list): - - # Clear outputs on scheduler iteration start - ctx.request_outputs.clear() - # Schedule iteration (seq_group_metadata_list, scheduler_outputs, allow_async_output_proc ) = self.scheduler[virtual_engine].schedule() + ctx.seq_group_metadata_list = seq_group_metadata_list + ctx.scheduler_outputs = scheduler_outputs + # Maybe switch from async mode to sync mode if not allow_async_output_proc and len(ctx.output_queue) > 0: - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=True) + self._process_model_outputs(ctx=ctx) if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): @@ -1481,9 +1573,6 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: assert seq_group_metadata_list is not None assert scheduler_outputs is not None - assert not (self.scheduler_config.is_multi_step and \ - allow_async_output_proc) - if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -1508,25 +1597,23 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: last_sampled_token_ids=last_sampled_token_ids) if allow_async_output_proc: - execute_model_req.async_callback = self.async_callback[ + execute_model_req.async_callback = self.async_callbacks[ virtual_engine] - output = self.model_executor.execute_model( + outputs = self.model_executor.execute_model( execute_model_req=execute_model_req) # We need to do this here so that last step's sampled_token_ids can # be passed to the next iteration for PP. if self.scheduler_config.is_multi_step: - self._update_cached_scheduler_output(virtual_engine, output) + self._update_cached_scheduler_output(virtual_engine, outputs) else: # Nothing scheduled => If there is pending async postprocessor, # then finish it here. if len(ctx.output_queue) > 0: - assert not self.scheduler_config.is_multi_step - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=True) + self._process_model_outputs(ctx=ctx) # No outputs in this case - output = [] + outputs = [] # Finish the current step for all the sequence groups. if self.scheduler_config.is_multi_step: @@ -1539,38 +1626,37 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: self.cached_scheduler_outputs[0] = SchedulerOutputState() # Add results to the output_queue - # (for async or non-async postprocessing) - ctx.output_queue.append( - (output, seq_group_metadata_list, scheduler_outputs)) + ctx.append_output(outputs=outputs, + seq_group_metadata_list=seq_group_metadata_list, + scheduler_outputs=scheduler_outputs, + is_async=allow_async_output_proc, + is_last_step=True) - if output and allow_async_output_proc: - assert len(output) == 1, ("Multi step decoding does not work " - "with async output processing.") + if outputs and allow_async_output_proc: + assert len(outputs) == 1, ( + "Async postprocessor expects only a single output set") self._advance_to_next_step( - output[0], seq_group_metadata_list, + outputs[0], seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups) # Check if need to run the usual non-async path if not allow_async_output_proc: - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=False) + self._process_model_outputs(ctx=ctx) # Log stats. - self.do_log_stats(scheduler_outputs, output) + self.do_log_stats(scheduler_outputs, outputs) # Tracing self.do_tracing(scheduler_outputs) else: # Multi-step case - ctx.request_outputs = [] + return ctx.request_outputs if not self.has_unfinished_requests(): # Drain async postprocessor (if exists) if len(ctx.output_queue) > 0: - assert not self.scheduler_config.is_multi_step - self._process_model_outputs(virtual_engine=virtual_engine, - is_async=True) + self._process_model_outputs(ctx=ctx) assert len(ctx.output_queue) == 0 # Stop the execute model loop in parallel workers until there are @@ -1640,11 +1726,19 @@ def _get_last_sampled_token_ids( return None def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None: + if not self.log_stats: + raise RuntimeError( + "Stat logging is disabled. Set `disable_log_stats=False` " + "argument to enable.") if logger_name in self.stat_loggers: raise KeyError(f"Logger with name {logger_name} already exists.") self.stat_loggers[logger_name] = logger def remove_logger(self, logger_name: str) -> None: + if not self.log_stats: + raise RuntimeError( + "Stat logging is disabled. Set `disable_log_stats=False` " + "argument to enable.") if logger_name not in self.stat_loggers: raise KeyError(f"Logger with name {logger_name} does not exist.") del self.stat_loggers[logger_name] @@ -1873,6 +1967,12 @@ def check_health(self) -> None: self.tokenizer.check_health() self.model_executor.check_health() + def start_profile(self) -> None: + self.model_executor.start_profile() + + def stop_profile(self) -> None: + self.model_executor.stop_profile() + def is_tracing_enabled(self) -> bool: return self.tracer is not None @@ -1952,7 +2052,26 @@ def is_embedding_model(self): def _validate_model_inputs(self, inputs: Union[LLMInputs, EncoderDecoderLLMInputs]): - prompt_key = "encoder_prompt_token_ids" \ - if self.is_encoder_decoder_model() else "prompt_token_ids" - if not inputs.get(prompt_key): + if self.is_encoder_decoder_model(): + prompt_ids = inputs.get("encoder_prompt_token_ids") + else: + prompt_ids = inputs.get("prompt_token_ids") + + if prompt_ids is None or len(prompt_ids) == 0: raise ValueError("Prompt cannot be empty") + + if self.model_config.is_multimodal_model: + max_prompt_len = self.model_config.max_model_len + + if len(prompt_ids) > max_prompt_len: + raise ValueError( + f"The prompt (total length {len(prompt_ids)}) is too long " + f"to fit into the model (context length {max_prompt_len}). " + "Make sure that `max_model_len` is no smaller than the " + "number of text tokens plus multimodal tokens. For image " + "inputs, the number of image tokens depends on the number " + "of images, and possibly their aspect ratios as well.") + + # TODO: Find out how many placeholder tokens are there so we can + # check that chunked prefill does not truncate them + # max_batch_len = self.scheduler_config.max_num_batched_tokens diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 49a33ded5fcaa..c73db765fc3b5 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -4,6 +4,8 @@ from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.interfaces import ( SequenceGroupOutputProcessor) +from vllm.engine.output_processor.single_step import ( + single_step_process_prompt_logprob) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger from vllm.sampling_params import SamplingParams @@ -46,9 +48,16 @@ def __init__( def process_prompt_logprob(self, seq_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: - # TODO(sang): Prompt logprob currently not implemented in multi step - # workers. - self._log_prompt_logprob_unsupported_warning_once() + """Process prompt logprobs associated with each step of a multi-step- + scheduled computation. + + Args: + seq_group: the outputs are associated with this :class:`SequenceGroup` + outputs: the :class:`SequenceGroupOutput`s for all scheduler steps + """ + for output in outputs: + # Concatenate single-step prompt logprob processing results. + single_step_process_prompt_logprob(self, seq_group, output) @staticmethod @functools.lru_cache() @@ -76,29 +85,54 @@ def process_outputs(self, no tokens need to be appended since it is already done externally (before the next schedule() call) """ - # TODO: Add support for async if necessary - assert not is_async - + # Sequences can be in RUNNING or FINISHED_ABORTED state + # once scheduled, as a sequence is moved to FINSIHED_ABORTED + # if a client disconnects from the api server. seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) + if seqs is None: + seqs = sequence_group.get_seqs( + status=SequenceStatus.FINISHED_ABORTED) - assert seqs, "expected running sequences" + assert seqs, "Expected RUNNING or FINISHED_ABORTED sequences" assert len(seqs) == 1, ( "Beam search not supported in multi-step decoding.") seq = seqs[0] - # Since there's only one sequence per sequence group, we can take the - # first sample. - samples = [output.samples[0] for output in outputs] - - # -1 means the output token is not valid (eg. due to spec decode - # rejecting tokens). - valid_samples = [ - sample for sample in samples if sample.output_token != -1 - ] - assert valid_samples - - self._process_seq_outputs(seq, valid_samples, - sequence_group.sampling_params) + if is_async: + # Async case: We process tokens one by one. Here, we know the token + # was already appended, so we only need to do the rest of the + # postprocessor: Detokenization + stopping logic + self._process_decode_and_stop(seq, sequence_group.sampling_params) + else: + # Standard multi-step case + + # Since there's only one sequence per sequence group, + # we can take the first sample. + samples = [output.samples[0] for output in outputs] + + # -1 means the output token is not valid (eg. due to spec decode + # rejecting tokens). + valid_samples = [ + sample for sample in samples if sample.output_token != -1 + ] + assert valid_samples + + self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) + + def _process_decode_and_stop(self, seq: Sequence, + sampling_params: SamplingParams) -> None: + new_char_count = 0 + if sampling_params.detokenize: + new_char_count = self.detokenizer.decode_sequence_inplace( + seq, sampling_params) + + # TODO(sang): Support lora. + self.stop_checker.maybe_stop_sequence( + seq, + new_char_count=new_char_count, + sampling_params=sampling_params, + ) def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput], @@ -136,16 +170,7 @@ def _process_seq_outputs(self, seq: Sequence, logprobs=output_logprob, ) - new_char_count = 0 - if sampling_params.detokenize: - new_char_count = self.detokenizer.decode_sequence_inplace( - seq, sampling_params) + self._process_decode_and_stop(seq, sampling_params) - # TODO(sang): Support lora. - self.stop_checker.maybe_stop_sequence( - seq, - new_char_count=new_char_count, - sampling_params=sampling_params, - ) if seq.is_finished(): break diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 4b0c3f37a5e21..e288aa0c4aafd 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -15,6 +15,44 @@ logger = init_logger(__name__) +def single_step_process_prompt_logprob( + sg_output_proc: SequenceGroupOutputProcessor, seq_group: SequenceGroup, + output: SequenceGroupOutput) -> None: + """Process prompt logprobs associated with the :class:`SequenceGroupOutput` + for a given step. + + Do nothing if the output has no prompt logprobs. + + Account for the fact that transformers do not compute first-token logprobs. + + Args: + sg_output_proc: :class:`SequenceGroupOutputProcessor` instance + seq_group: the output is associated with this :class:`SequenceGroup` + output: the :class:`SequenceGroupOutput` for a single scheduler step + """ + prompt_logprobs = output.prompt_logprobs + + # If this is the first (or only) "chunk" of the prefill, we need + # to prepend None to the list of prompt logprobs. The reason for this + # is that for N prompt tokens, the Sampler will generate N-1 total + # prompt logprobs during prefill since the token at idx 0 will not + # have a logprob associated with it. + if prompt_logprobs is not None: + if not seq_group.prompt_logprobs: + prompt_logprobs = [None] + prompt_logprobs + seq_group.prompt_logprobs = [] + + assert hasattr(sg_output_proc, 'detokenizer') + if (seq_group.sampling_params.detokenize + and sg_output_proc.detokenizer): + sg_output_proc.detokenizer.decode_prompt_logprobs_inplace( + seq_group, + prompt_logprobs, + position_offset=len(seq_group.prompt_logprobs)) + + seq_group.prompt_logprobs.extend(prompt_logprobs) + + class SingleStepOutputProcessor(SequenceGroupOutputProcessor): """SequenceGroupOutputProcessor which handles "output processing" logic, which happens after the model returns generated token ids and before @@ -60,33 +98,22 @@ def process_outputs(self, sequence_group: SequenceGroup, def process_prompt_logprob(self, seq_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: + """Process prompt logprobs associated with one step of a single-step- + scheduled computation. + + Args: + seq_group: the output is associated with this :class:`SequenceGroup` + output: the :class:`SequenceGroupOutput` for a single scheduler step + """ assert len(outputs) == 1, ("Single step should only has 1 output.") output = outputs[0] - prompt_logprobs = output.prompt_logprobs - - # If this is the first (or only) "chunk" of the prefill, we need - # to prepend None to the list of prompt logprobs. The reason for this - # is that for N prompt tokens, the Sampler will generate N-1 total - # prompt logprobs during prefill since the token at idx 0 will not - # have a logprob associated with it. - if prompt_logprobs is not None: - if not seq_group.prompt_logprobs: - prompt_logprobs = [None] + prompt_logprobs - seq_group.prompt_logprobs = [] - - if seq_group.sampling_params.detokenize and self.detokenizer: - self.detokenizer.decode_prompt_logprobs_inplace( - seq_group, - prompt_logprobs, - position_offset=len(seq_group.prompt_logprobs)) - - seq_group.prompt_logprobs.extend(prompt_logprobs) + single_step_process_prompt_logprob(self, seq_group, output) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput, is_async: bool) -> None: sampling_params = seq_group.sampling_params - if sampling_params.n == 1 and not sampling_params.use_beam_search: + if sampling_params.best_of == 1 and not sampling_params.use_beam_search: # only have one output sample sample = outputs.samples[0] # only have one sequence diff --git a/vllm/engine/output_processor/util.py b/vllm/engine/output_processor/util.py index 57cc33d911183..76782888031e3 100644 --- a/vllm/engine/output_processor/util.py +++ b/vllm/engine/output_processor/util.py @@ -2,7 +2,8 @@ from typing import Sequence as GenericSequence from typing import Union -from vllm.sequence import PoolerOutput, SamplerOutput, SequenceGroupOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import PoolerOutput, SequenceGroupOutput def create_output_by_sequence_group( diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index 1deb75167bc72..34ae79f5fa8df 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -5,11 +5,11 @@ from vllm.core.scheduler import SchedulerOutputs from vllm.inputs.data import PromptInputs from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.outputs import EmbeddingRequestOutput, RequestOutput from vllm.pooling_params import PoolingParams from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams -from vllm.sequence import SamplerOutput from vllm.transformers_utils.tokenizer import AnyTokenizer diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index c5368ac3bf026..f9f9536a7c160 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -1,28 +1,36 @@ +import asyncio import codecs -from dataclasses import dataclass -from functools import lru_cache +import json +from abc import ABC, abstractmethod +from collections import defaultdict +from functools import lru_cache, partial from pathlib import Path -from typing import (Any, Awaitable, Iterable, List, Literal, Optional, Tuple, - Union) +from typing import (Any, Awaitable, Dict, Generic, Iterable, List, Literal, + Mapping, Optional, Tuple, TypeVar, Union, cast) # yapf conflicts with isort for this block # yapf: disable -from openai.types.chat import ChatCompletionContentPartImageParam +from openai.types.chat import (ChatCompletionAssistantMessageParam, + ChatCompletionContentPartImageParam) from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam) -from openai.types.chat import ChatCompletionContentPartTextParam +from openai.types.chat import (ChatCompletionContentPartRefusalParam, + ChatCompletionContentPartTextParam) from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessageParam) +from openai.types.chat import (ChatCompletionMessageToolCallParam, + ChatCompletionToolMessageParam) # yapf: enable # pydantic needs the TypedDict from typing_extensions -from pydantic import ConfigDict, TypeAdapter +from pydantic import ConfigDict from typing_extensions import Required, TypeAlias, TypedDict from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import (async_get_and_parse_audio, - async_get_and_parse_image) + async_get_and_parse_image, + get_and_parse_audio, get_and_parse_image) from vllm.transformers_utils.tokenizer import AnyTokenizer logger = init_logger(__name__) @@ -51,7 +59,8 @@ class CustomChatCompletionContentPartParam(TypedDict, total=False): ChatCompletionContentPartParam: TypeAlias = Union[ OpenAIChatCompletionContentPartParam, ChatCompletionContentPartAudioParam, - CustomChatCompletionContentPartParam, ] + ChatCompletionContentPartRefusalParam, + CustomChatCompletionContentPartParam] class CustomChatCompletionMessageParam(TypedDict, total=False): @@ -69,21 +78,219 @@ class CustomChatCompletionMessageParam(TypedDict, total=False): same role. """ + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + ChatCompletionMessageParam = Union[OpenAIChatCompletionMessageParam, CustomChatCompletionMessageParam] # TODO: Make fields ReadOnly once mypy supports it -class ConversationMessage(TypedDict): - role: str - content: str +class ConversationMessage(TypedDict, total=False): + role: Required[str] + """The role of the message's author.""" + + content: Optional[str] + """The contents of the message""" + + tool_call_id: Optional[str] + """Tool call that this message is responding to.""" + + name: Optional[str] + """The name of the function to call""" + + tool_calls: Optional[Iterable[ChatCompletionMessageToolCallParam]] + """The tool calls generated by the model, such as function calls.""" + + +ModalityStr = Literal["image", "audio"] +_T = TypeVar("_T") + + +class BaseMultiModalItemTracker(ABC, Generic[_T]): + """ + Tracks multi-modal items in a given request and ensures that the number + of multi-modal items in a given request does not exceed the configured + maximum per prompt. + """ + + def __init__(self, model_config: ModelConfig, tokenizer: AnyTokenizer): + super().__init__() + + self._model_config = model_config + self._tokenizer = tokenizer + self._allowed_items = (model_config.multimodal_config.limit_per_prompt + if model_config.multimodal_config else {}) + self._consumed_items = {k: 0 for k in self._allowed_items} + + self._items: List[_T] = [] + + @staticmethod + @lru_cache(maxsize=None) + def _cached_token_str(tokenizer: AnyTokenizer, token_index: int) -> str: + return tokenizer.decode(token_index) + + def _placeholder_str(self, modality: ModalityStr, + current_count: int) -> Optional[str]: + # TODO: Let user specify how to insert image tokens into prompt + # (similar to chat template) + hf_config = self._model_config.hf_config + model_type = hf_config.model_type + + if modality == "image": + if model_type == "phi3_v": + # Workaround since this token is not defined in the tokenizer + return f"<|image_{current_count}|>" + if model_type == "minicpmv": + return "(./)" + if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"): + # These models do not use image tokens in the prompt + return None + if model_type == "qwen": + return f"Picture {current_count}: " + if model_type.startswith("llava"): + return self._cached_token_str(self._tokenizer, + hf_config.image_token_index) + if model_type in ("chameleon", "internvl_chat"): + return "" + + raise TypeError(f"Unknown model type: {model_type}") + elif modality == "audio": + if model_type == "ultravox": + return "<|reserved_special_token_0|>" + raise TypeError(f"Unknown model type: {model_type}") + else: + raise TypeError(f"Unknown modality: {modality}") + + @staticmethod + def _combine(items: List[MultiModalDataDict]) -> MultiModalDataDict: + mm_lists: Mapping[str, List[object]] = defaultdict(list) + + # Merge all the multi-modal items + for single_mm_data in items: + for mm_key, mm_item in single_mm_data.items(): + if isinstance(mm_item, list): + mm_lists[mm_key].extend(mm_item) + else: + mm_lists[mm_key].append(mm_item) + + # Unpack any single item lists for models that don't expect multiple. + return { + mm_key: mm_list[0] if len(mm_list) == 1 else mm_list + for mm_key, mm_list in mm_lists.items() + } + + def add(self, modality: ModalityStr, item: _T) -> Optional[str]: + """ + Add a multi-modal item to the current prompt and returns the + placeholder string to use, if any. + """ + allowed_count = self._allowed_items.get(modality, 1) + current_count = self._consumed_items.get(modality, 0) + 1 + if current_count > allowed_count: + raise ValueError( + f"At most {allowed_count} {modality}(s) may be provided in " + "one request.") + + self._consumed_items[modality] = current_count + self._items.append(item) + + return self._placeholder_str(modality, current_count) + + @abstractmethod + def create_parser(self) -> "BaseMultiModalContentParser": + raise NotImplementedError + + +class MultiModalItemTracker(BaseMultiModalItemTracker[MultiModalDataDict]): + + def all_mm_data(self) -> Optional[MultiModalDataDict]: + return self._combine(self._items) if self._items else None + + def create_parser(self) -> "BaseMultiModalContentParser": + return MultiModalContentParser(self) + + +class AsyncMultiModalItemTracker( + BaseMultiModalItemTracker[Awaitable[MultiModalDataDict]]): + + async def all_mm_data(self) -> Optional[MultiModalDataDict]: + if self._items: + items = await asyncio.gather(*self._items) + return self._combine(items) + + return None + + def create_parser(self) -> "BaseMultiModalContentParser": + return AsyncMultiModalContentParser(self) + + +class BaseMultiModalContentParser(ABC): + + def __init__(self) -> None: + super().__init__() + + # multimodal placeholder_string : count + self._placeholder_counts: Dict[str, int] = defaultdict(lambda: 0) + + def _add_placeholder(self, placeholder: Optional[str]): + if placeholder: + self._placeholder_counts[placeholder] += 1 + + def mm_placeholder_counts(self) -> Dict[str, int]: + return dict(self._placeholder_counts) + + @abstractmethod + def parse_image(self, image_url: str) -> None: + raise NotImplementedError + + @abstractmethod + def parse_audio(self, audio_url: str) -> None: + raise NotImplementedError + +class MultiModalContentParser(BaseMultiModalContentParser): -@dataclass(frozen=True) -class ChatMessageParseResult: - messages: List[ConversationMessage] - mm_futures: List[Awaitable[MultiModalDataDict]] + def __init__(self, tracker: MultiModalItemTracker) -> None: + super().__init__() + + self._tracker = tracker + + def parse_image(self, image_url: str) -> None: + image = get_and_parse_image(image_url) + + placeholder = self._tracker.add("image", image) + self._add_placeholder(placeholder) + + def parse_audio(self, audio_url: str) -> None: + audio = get_and_parse_audio(audio_url) + + placeholder = self._tracker.add("audio", audio) + self._add_placeholder(placeholder) + + +class AsyncMultiModalContentParser(BaseMultiModalContentParser): + + def __init__(self, tracker: AsyncMultiModalItemTracker) -> None: + super().__init__() + + self._tracker = tracker + + def parse_image(self, image_url: str) -> None: + image_coro = async_get_and_parse_image(image_url) + + placeholder = self._tracker.add("image", image_coro) + self._add_placeholder(placeholder) + + def parse_audio(self, audio_url: str) -> None: + audio_coro = async_get_and_parse_audio(audio_url) + + placeholder = self._tracker.add("audio", audio_coro) + self._add_placeholder(placeholder) def load_chat_template( @@ -112,152 +319,150 @@ def load_chat_template( return resolved_chat_template -@lru_cache(maxsize=None) -def _mm_token_str(model_config: ModelConfig, tokenizer: AnyTokenizer, - modality: Literal["image", "audio"]) -> Optional[str]: - # TODO: Let user specify how to insert image tokens into prompt - # (similar to chat template) - model_type = model_config.hf_config.model_type - if modality == "image": - if model_type == "phi3_v": - # Workaround since this token is not defined in the tokenizer - return "<|image_1|>" - if model_type == "minicpmv": - return "(./)" - if model_type in ("blip-2", "chatglm", "fuyu", "paligemma"): - # These models do not use image tokens in the prompt - return None - if model_type.startswith("llava"): - return tokenizer.decode(model_config.hf_config.image_token_index) - if model_type in ("chameleon", "internvl_chat"): - return "" - - raise TypeError(f"Unknown model type: {model_type}") - elif modality == "audio": - if model_type == "ultravox": - return "<|reserved_special_token_0|>" - raise TypeError(f"Unknown model type: {model_type}") - else: - raise TypeError(f"Unknown modality: {modality}") - - # TODO: Let user specify how to insert multimodal tokens into prompt # (similar to chat template) -def _get_full_multimodal_text_prompt(placeholder_token_str: str, +def _get_full_multimodal_text_prompt(placeholder_counts: Dict[str, int], text_prompt: str) -> str: - """Combine multimodal prompts for a multimodal language model""" + """Combine multimodal prompts for a multimodal language model.""" + + # Look through the text prompt to check for missing placeholders + missing_placeholders: List[str] = [] + for placeholder in placeholder_counts: - # NOTE: For now we assume all model architectures use the same - # placeholder + text prompt format. This may change in the future. - return f"{placeholder_token_str}\n{text_prompt}" + # For any existing placeholder in the text prompt, we leave it as is + placeholder_counts[placeholder] -= text_prompt.count(placeholder) + if placeholder_counts[placeholder] < 0: + raise ValueError( + f"Found more '{placeholder}' placeholders in input prompt than " + "actual multimodal data items.") -_TextParser = TypeAdapter(ChatCompletionContentPartTextParam) -_ImageParser = TypeAdapter(ChatCompletionContentPartImageParam) -_AudioParser = TypeAdapter(ChatCompletionContentPartAudioParam) + missing_placeholders.extend([placeholder] * + placeholder_counts[placeholder]) + + # NOTE: For now we always add missing placeholders at the front of + # the prompt. This may change to be customizable in the future. + return "\n".join(missing_placeholders + [text_prompt]) + + +# No need to validate using Pydantic again +_TextParser = partial(cast, ChatCompletionContentPartTextParam) +_ImageParser = partial(cast, ChatCompletionContentPartImageParam) +_AudioParser = partial(cast, ChatCompletionContentPartAudioParam) +_RefusalParser = partial(cast, ChatCompletionContentPartRefusalParam) def _parse_chat_message_content_parts( role: str, parts: Iterable[ChatCompletionContentPartParam], - model_config: ModelConfig, - tokenizer: AnyTokenizer, -) -> ChatMessageParseResult: + mm_tracker: BaseMultiModalItemTracker, +) -> List[ConversationMessage]: texts: List[str] = [] - mm_futures: List[Awaitable[MultiModalDataDict]] = [] - modality: Literal["image", "audio"] = "image" + + mm_parser = mm_tracker.create_parser() for part in parts: part_type = part["type"] if part_type == "text": - text = _TextParser.validate_python(part)["text"] + text = _TextParser(part)["text"] texts.append(text) elif part_type == "image_url": - modality = "image" - if len(mm_futures) > 0: - raise NotImplementedError( - "Multiple multimodal inputs is currently not supported.") - - image_url = _ImageParser.validate_python(part)["image_url"] + image_url = _ImageParser(part)["image_url"] if image_url.get("detail", "auto") != "auto": logger.warning( "'image_url.detail' is currently not supported and " "will be ignored.") - image_future = async_get_and_parse_image(image_url["url"]) - mm_futures.append(image_future) + mm_parser.parse_image(image_url["url"]) elif part_type == "audio_url": - modality = "audio" - if len(mm_futures) > 0: - raise NotImplementedError( - "Multiple multimodal inputs is currently not supported.") - - audio_url = _AudioParser.validate_python(part)["audio_url"] - audio_future = async_get_and_parse_audio(audio_url["url"]) - mm_futures.append(audio_future) + audio_url = _AudioParser(part)["audio_url"] + + mm_parser.parse_audio(audio_url["url"]) else: raise NotImplementedError(f"Unknown part type: {part_type}") text_prompt = "\n".join(texts) + mm_placeholder_counts = mm_parser.mm_placeholder_counts() + if mm_placeholder_counts: + text_prompt = _get_full_multimodal_text_prompt(mm_placeholder_counts, + text_prompt) - if mm_futures: - placeholder_token_str = _mm_token_str(model_config, tokenizer, - modality) - if placeholder_token_str is not None: - if placeholder_token_str in text_prompt: - logger.warning( - "Detected multi-modal token string in the text prompt. " - "Skipping prompt formatting.") - else: - text_prompt = _get_full_multimodal_text_prompt( - placeholder_token_str=placeholder_token_str, - text_prompt=text_prompt, - ) + return [ConversationMessage(role=role, content=text_prompt)] - messages = [ConversationMessage(role=role, content=text_prompt)] - return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) +# No need to validate using Pydantic again +_AssistantParser = partial(cast, ChatCompletionAssistantMessageParam) +_ToolParser = partial(cast, ChatCompletionToolMessageParam) def _parse_chat_message_content( message: ChatCompletionMessageParam, - model_config: ModelConfig, - tokenizer: AnyTokenizer, -) -> ChatMessageParseResult: + mm_tracker: BaseMultiModalItemTracker, +) -> List[ConversationMessage]: role = message["role"] content = message.get("content") if content is None: - return ChatMessageParseResult(messages=[], mm_futures=[]) - if isinstance(content, str): - messages = [ConversationMessage(role=role, content=content)] - return ChatMessageParseResult(messages=messages, mm_futures=[]) + content = [] + elif isinstance(content, str): + content = [ + ChatCompletionContentPartTextParam(type="text", text=content) + ] - return _parse_chat_message_content_parts( + result = _parse_chat_message_content_parts( role, content, # type: ignore - model_config, - tokenizer, + mm_tracker, ) + for result_msg in result: + if role == 'assistant': + parsed_msg = _AssistantParser(message) + + if "tool_calls" in parsed_msg: + result_msg["tool_calls"] = list(parsed_msg["tool_calls"]) + elif role == "tool": + parsed_msg = _ToolParser(message) + if "tool_call_id" in parsed_msg: + result_msg["tool_call_id"] = parsed_msg["tool_call_id"] + + if "name" in message and isinstance(message["name"], str): + result_msg["name"] = message["name"] + + return result + def parse_chat_messages( messages: List[ChatCompletionMessageParam], model_config: ModelConfig, tokenizer: AnyTokenizer, -) -> Tuple[List[ConversationMessage], List[Awaitable[MultiModalDataDict]]]: +) -> Tuple[List[ConversationMessage], Optional[MultiModalDataDict]]: conversation: List[ConversationMessage] = [] - mm_futures: List[Awaitable[MultiModalDataDict]] = [] + mm_tracker = MultiModalItemTracker(model_config, tokenizer) for msg in messages: - parse_result = _parse_chat_message_content(msg, model_config, - tokenizer) + sub_messages = _parse_chat_message_content(msg, mm_tracker) - conversation.extend(parse_result.messages) - mm_futures.extend(parse_result.mm_futures) + conversation.extend(sub_messages) - return conversation, mm_futures + return conversation, mm_tracker.all_mm_data() + + +def parse_chat_messages_futures( + messages: List[ChatCompletionMessageParam], + model_config: ModelConfig, + tokenizer: AnyTokenizer, +) -> Tuple[List[ConversationMessage], Awaitable[Optional[MultiModalDataDict]]]: + conversation: List[ConversationMessage] = [] + mm_tracker = AsyncMultiModalItemTracker(model_config, tokenizer) + + for msg in messages: + sub_messages = _parse_chat_message_content(msg, mm_tracker) + + conversation.extend(sub_messages) + + return conversation, mm_tracker.all_mm_data() def apply_chat_template( @@ -274,6 +479,20 @@ def apply_chat_template( "allowed, so you must provide a chat template if the tokenizer " "does not define one.") + # per the Transformers docs & maintainers, tool call arguments in + # assistant-role messages with tool_calls need to be dicts not JSON str - + # this is how tool-use chat templates will expect them moving forwards + # so, for messages that have tool_calls, parse the string (which we get + # from openAI format) to dict + for message in conversation: + if (message["role"] == "assistant" and "tool_calls" in message + and isinstance(message["tool_calls"], list)): + + for i in range(len(message["tool_calls"])): + args: str = message["tool_calls"][i]["function"]["arguments"] + parsed_args: Dict = json.loads(args) + message["tool_calls"][i]["function"]["arguments"] = parsed_args + prompt = tokenizer.apply_chat_template( conversation=conversation, chat_template=chat_template, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 0edd4bfaecd6a..1e4432eaaa665 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -23,7 +23,7 @@ get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter, deprecate_kwargs +from vllm.utils import Counter, deprecate_kwargs, is_list_of logger = init_logger(__name__) @@ -55,7 +55,7 @@ class LLM: However, if the `torch_dtype` in the config is `float32`, we will use `float16` instead. quantization: The method used to quantize the model weights. Currently, - we support "awq", "gptq", "squeezellm", and "fp8" (experimental). + we support "awq", "gptq", and "fp8" (experimental). If None, we first check the `quantization_config` attribute in the model config file. If that is None, we assume the model weights are not quantized and use `dtype` to determine the data type of @@ -358,15 +358,18 @@ def chat( add_generation_prompt: bool = True, ) -> List[RequestOutput]: """ - Generates responses for chat messages. + Generate responses for a chat conversation. - Converts the messages to prompts using the tokenizer and calls - the :meth:`generate` method to generate the responses. + The chat conversation is converted into a text prompt using the + tokenizer and calls the :meth:`generate` method to generate the + responses. + + Multi-modal inputs can be passed in the same way you would pass them + to the OpenAI API. Args: - messages: A list of messages to generate responses for. Each - message is a list of dictionaries with 'role' and 'content' - keys. + messages: A single conversation represented as a list of messages. + Each message is a dictionary with 'role' and 'content' keys. sampling_params: The sampling parameters for text generation. If None, we use the default sampling parameters. When it is a single value, it is applied to every prompt. When it @@ -387,21 +390,25 @@ def chat( tokenizer = self.get_tokenizer() model_config = self.llm_engine.get_model_config() - conversations, _ = parse_chat_messages(messages, model_config, - tokenizer) + conversation, mm_data = parse_chat_messages(messages, model_config, + tokenizer) prompt = apply_chat_template( tokenizer, - conversations, + conversation, chat_template=chat_template, - add_generation_prompt=add_generation_prompt) + add_generation_prompt=add_generation_prompt, + ) inputs: PromptInputs - if isinstance(prompt, list) and isinstance(prompt[0], int): + if is_list_of(prompt, int): inputs = TokensPrompt(prompt_token_ids=prompt) else: inputs = TextPrompt(prompt=prompt) + if mm_data is not None: + inputs["multi_modal_data"] = mm_data + return self.generate( inputs, sampling_params=sampling_params, @@ -553,6 +560,12 @@ def encode( outputs = self._run_engine(use_tqdm=use_tqdm) return LLMEngine.validate_outputs(outputs, EmbeddingRequestOutput) + def start_profile(self) -> None: + self.llm_engine.start_profile() + + def stop_profile(self) -> None: + self.llm_engine.stop_profile() + # LEGACY def _convert_v1_inputs( self, diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 8e8371ef1559a..d8704d5e24964 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -35,11 +35,13 @@ DetokenizeResponse, EmbeddingRequest, EmbeddingResponse, ErrorResponse, + LoadLoraAdapterRequest, TokenizeRequest, - TokenizeResponse) -# yapf: enable + TokenizeResponse, + UnloadLoraAdapterRequest) from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient from vllm.entrypoints.openai.rpc.server import run_rpc_server +# yapf: enable from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding @@ -67,7 +69,7 @@ def model_is_embedding(model_name: str, trust_remote_code: bool, - quantization: str) -> bool: + quantization: Optional[str]) -> bool: return ModelConfig(model=model_name, tokenizer=model_name, tokenizer_mode="auto", @@ -96,13 +98,6 @@ async def _force_log(): @asynccontextmanager async def build_async_engine_client( args: Namespace) -> AsyncIterator[Optional[AsyncEngineClient]]: - """ - Create AsyncEngineClient, either: - - in-process using the AsyncLLMEngine Directly - - multiprocess using AsyncLLMEngine RPC - - Returns the Client or None if the creation failed. - """ # Context manager to handle async_engine_client lifecycle # Ensures everything is shutdown and cleaned up on error/exit @@ -112,14 +107,37 @@ async def build_async_engine_client( # Backend itself still global for the silly lil' health handler global async_engine_client + async with build_async_engine_client_from_engine_args( + engine_args, args.disable_frontend_multiprocessing) as engine: + + async_engine_client = engine # type: ignore[assignment] + yield engine + + +@asynccontextmanager +async def build_async_engine_client_from_engine_args( + engine_args: AsyncEngineArgs, + disable_frontend_multiprocessing: bool = False, +) -> AsyncIterator[Optional[AsyncEngineClient]]: + """ + Create AsyncEngineClient, either: + - in-process using the AsyncLLMEngine Directly + - multiprocess using AsyncLLMEngine RPC + + Returns the Client or None if the creation failed. + """ + # If manually triggered or embedding model, use AsyncLLMEngine in process. # TODO: support embedding model via RPC. - if (model_is_embedding(args.model, args.trust_remote_code, - args.quantization) - or args.disable_frontend_multiprocessing): - async_engine_client = AsyncLLMEngine.from_engine_args( + if (model_is_embedding(engine_args.model, engine_args.trust_remote_code, + engine_args.quantization) + or disable_frontend_multiprocessing): + engine_client = AsyncLLMEngine.from_engine_args( engine_args, usage_context=UsageContext.OPENAI_API_SERVER) - yield async_engine_client + try: + yield engine_client + finally: + engine_client.shutdown_background_loop() return # Otherwise, use the multiprocessing AsyncLLMEngine. @@ -148,7 +166,6 @@ async def build_async_engine_client( # NOTE: Actually, this is not true yet. We still need to support # embedding models via RPC (see TODO above) rpc_client = AsyncEngineRPCClient(rpc_path) - async_engine_client = rpc_client # type: ignore # Start RPCServer in separate process (holds the AsyncLLMEngine). context = multiprocessing.get_context("spawn") @@ -174,7 +191,7 @@ async def build_async_engine_client( yield None return - yield async_engine_client + yield rpc_client # type: ignore[misc] finally: # Ensure rpc server process was terminated rpc_server_process.terminate() @@ -218,7 +235,7 @@ def mount_metrics(app: FastAPI): metrics_route = Mount("/metrics", make_asgi_app()) # Workaround for 307 Redirect for /metrics - metrics_route.path_regex = re.compile('^/metrics(?P.*)$') + metrics_route.path_regex = re.compile("^/metrics(?P.*)$") app.routes.append(metrics_route) @@ -268,11 +285,14 @@ async def show_version(): @router.post("/v1/chat/completions") async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): + generator = await openai_serving_chat.create_chat_completion( request, raw_request) + if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) + elif isinstance(generator, ChatCompletionResponse): return JSONResponse(content=generator.model_dump()) @@ -325,6 +345,40 @@ async def stop_profile(): return Response(status_code=200) +if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING: + logger.warning( + "Lora dynamic loading & unloading is enabled in the API server. " + "This should ONLY be used for local development!") + + @router.post("/v1/load_lora_adapter") + async def load_lora_adapter(request: LoadLoraAdapterRequest): + response = await openai_serving_chat.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + response = await openai_serving_completion.load_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + @router.post("/v1/unload_lora_adapter") + async def unload_lora_adapter(request: UnloadLoraAdapterRequest): + response = await openai_serving_chat.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + response = await openai_serving_completion.unload_lora_adapter(request) + if isinstance(response, ErrorResponse): + return JSONResponse(content=response.model_dump(), + status_code=response.code) + + return Response(status_code=200, content=response) + + def build_app(args: Namespace) -> FastAPI: app = FastAPI(lifespan=lifespan) app.include_router(router) @@ -407,7 +461,8 @@ async def init_app( request_logger=request_logger, chat_template=args.chat_template, return_tokens_as_token_ids=args.return_tokens_as_token_ids, - ) + enable_auto_tools=args.enable_auto_tool_choice, + tool_parser=args.tool_call_parser) openai_serving_completion = OpenAIServingCompletion( async_engine_client, model_config, diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 94742838b421c..7ccee0b6b55b7 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -163,6 +163,24 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: help="If specified, will run the OpenAI frontend server in the same " "process as the model serving engine.") + parser.add_argument( + "--enable-auto-tool-choice", + action="store_true", + default=False, + help= + "Enable auto tool choice for supported models. Use --tool-call-parser" + "to specify which parser to use") + + parser.add_argument( + "--tool-call-parser", + type=str, + choices=["mistral", "hermes"], + default=None, + help= + "Select the tool call parser depending on the model that you're using." + " This is used to parse the model-generated tool call into OpenAI API " + "format. Required for --enable-auto-tool-choice.") + parser = AsyncEngineArgs.add_cli_args(parser) parser.add_argument('--max-log-len', diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0954b81595ef5..374196044b7e8 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -5,8 +5,9 @@ from typing import Any, Dict, List, Literal, Optional, Union import torch +from openai.types.chat import ChatCompletionContentPartParam from pydantic import BaseModel, ConfigDict, Field, model_validator -from typing_extensions import Annotated +from typing_extensions import Annotated, Required, TypedDict from vllm.entrypoints.chat_utils import ChatCompletionMessageParam from vllm.entrypoints.openai.logits_processors import get_logits_processors @@ -35,6 +36,26 @@ assert _LONG_INFO.max == _MOCK_LONG_INFO.max +class CustomChatCompletionMessageParam(TypedDict, total=False): + """Enables custom roles in the Chat Completion API.""" + role: Required[str] + """The role of the message's author.""" + + content: Union[str, List[ChatCompletionContentPartParam]] + """The contents of the message.""" + + name: str + """An optional name for the participant. + + Provides the model information to differentiate between participants of the + same role. + """ + + tool_call_id: Optional[str] + + tool_calls: Optional[List[dict]] + + class OpenAIBaseModel(BaseModel): # OpenAI API does not allow extra fields model_config = ConfigDict(extra="forbid") @@ -145,8 +166,11 @@ class ChatCompletionRequest(OpenAIBaseModel): temperature: Optional[float] = 0.7 top_p: Optional[float] = 1.0 tools: Optional[List[ChatCompletionToolsParam]] = None - tool_choice: Optional[Union[Literal["none"], + tool_choice: Optional[Union[Literal["none"], Literal["auto"], ChatCompletionNamedToolChoiceParam]] = "none" + + # NOTE this will be ignored by VLLM -- the model determines the behavior + parallel_tool_calls: Optional[bool] = False user: Optional[str] = None # doc: begin-chat-completion-sampling-params @@ -328,6 +352,9 @@ def check_logprobs(cls, data): @model_validator(mode="before") @classmethod def check_guided_decoding_count(cls, data): + if isinstance(data, ValueError): + raise data + guide_count = sum([ "guided_json" in data and data["guided_json"] is not None, "guided_regex" in data and data["guided_regex"] is not None, @@ -339,21 +366,61 @@ def check_guided_decoding_count(cls, data): "You can only use one kind of guided decoding " "('guided_json', 'guided_regex' or 'guided_choice').") # you can only either use guided decoding or tools, not both - if guide_count > 1 and "tool_choice" in data and data[ - "tool_choice"] != "none": + if guide_count > 1 and data.get("tool_choice", + "none") not in ("none", "auto"): raise ValueError( "You can only either use guided decoding or tools, not both.") return data @model_validator(mode="before") @classmethod - def check_tool_choice(cls, data): - if "tool_choice" in data and data["tool_choice"] != "none": - if not isinstance(data["tool_choice"], dict): - raise ValueError("Currently only named tools are supported.") + def check_tool_usage(cls, data): + + # if "tool_choice" is not specified but tools are provided, + # default to "auto" tool_choice + if "tool_choice" not in data and "tools" in data: + data["tool_choice"] = "auto" + + # if "tool_choice" is specified -- validation + if "tool_choice" in data: + + # ensure that if "tool choice" is specified, tools are present if "tools" not in data or data["tools"] is None: raise ValueError( "When using `tool_choice`, `tools` must be set.") + + # make sure that tool choice is either a named tool + # OR that it's set to "auto" + if data["tool_choice"] != "auto" and not isinstance( + data["tool_choice"], dict): + raise ValueError( + "`tool_choice` must either be a named tool or \"auto\". " + "`tool_choice=\"none\" is not supported.") + + # ensure that if "tool_choice" is specified as an object, + # it matches a valid tool + if isinstance(data["tool_choice"], dict): + valid_tool = False + specified_function = data["tool_choice"]["function"] + if not specified_function: + raise ValueError( + "Incorrectly formatted `tool_choice`. Should be like " + "`{\"type\": \"function\"," + " \"function\": {\"name\": \"my_function\"}}`") + specified_function_name = specified_function["name"] + if not specified_function_name: + raise ValueError( + "Incorrectly formatted `tool_choice`. Should be like " + "`{\"type\": \"function\", " + "\"function\": {\"name\": \"my_function\"}}`") + for tool in data["tools"]: + if tool["function"]["name"] == specified_function_name: + valid_tool = True + break + if not valid_tool: + raise ValueError( + "The tool specified in `tool_choice` does not match any" + " of the specified `tools`") return data @@ -413,7 +480,7 @@ class CompletionRequest(OpenAIBaseModel): ) guided_json: Optional[Union[str, dict, BaseModel]] = Field( default=None, - description=("If specified, the output will follow the JSON schema."), + description="If specified, the output will follow the JSON schema.", ) guided_regex: Optional[str] = Field( default=None, @@ -633,9 +700,34 @@ class ToolCall(OpenAIBaseModel): function: FunctionCall +class DeltaFunctionCall(BaseModel): + name: Optional[str] = None + arguments: Optional[str] = None + + +# a tool call delta where everything is optional +class DeltaToolCall(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"chatcmpl-tool-{random_uuid()}") + type: Literal["function"] = "function" + index: int + function: Optional[DeltaFunctionCall] = None + + +class ExtractedToolCallInformation(BaseModel): + # indicate if tools were called + tools_called: bool + + # extracted tool calls + tool_calls: List[ToolCall] + + # content - per OpenAI spec, content AND tool calls can be returned rarely + # But some models will do this intentionally + content: Optional[str] = None + + class ChatMessage(OpenAIBaseModel): role: str - content: str + content: Optional[str] = None tool_calls: List[ToolCall] = Field(default_factory=list) @@ -657,7 +749,9 @@ class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage logprobs: Optional[ChatCompletionLogProbs] = None - finish_reason: Optional[str] = None + # per OpenAI spec this is the default + finish_reason: Optional[str] = "stop" + # not part of the OpenAI spec but included in vLLM for legacy reasons stop_reason: Optional[Union[int, str]] = None @@ -674,7 +768,7 @@ class ChatCompletionResponse(OpenAIBaseModel): class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None content: Optional[str] = None - tool_calls: List[ToolCall] = Field(default_factory=list) + tool_calls: List[DeltaToolCall] = Field(default_factory=list) class ChatCompletionResponseStreamChoice(OpenAIBaseModel): @@ -777,3 +871,13 @@ class DetokenizeRequest(OpenAIBaseModel): class DetokenizeResponse(OpenAIBaseModel): prompt: str + + +class LoadLoraAdapterRequest(BaseModel): + lora_name: str + lora_path: str + + +class UnloadLoraAdapterRequest(BaseModel): + lora_name: str + lora_int_id: Optional[int] = Field(default=None) diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index a472e12e8ca48..9b88db746be5c 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,11 +1,14 @@ import asyncio +import pickle from contextlib import contextmanager, suppress -from typing import Any, AsyncGenerator, Mapping, Optional +from typing import Any, AsyncGenerator, Iterator, Mapping, Optional from uuid import uuid4 import cloudpickle import zmq import zmq.asyncio +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig) @@ -115,18 +118,21 @@ def __init__(self, rpc_path: str): self.context.set(zmq.constants.MAX_SOCKETS, socket_limit) # IPC connection to RPC Server (uses unix sockets). - self.to_rpc_server = self.context.socket(zmq.constants.DEALER) + self.to_rpc_server: Socket = self.context.socket(zmq.constants.DEALER) self.to_rpc_server.set_hwm(VLLM_RPC_ZMQ_HWM) self.to_rpc_server.bind(rpc_path) # In process proxy to RPC Server (uses memory-based messaging). - self.from_api_server = self.context.socket(zmq.constants.ROUTER) + self.from_api_server: Socket = self.context.socket( + zmq.constants.ROUTER) self.from_api_server.set_hwm(VLLM_RPC_ZMQ_HWM) self.from_api_server.bind(INPROC_PROXY_PATH) # Asyncio background task for the proxy. - self.proxy_task = asyncio.create_task( + self.proxy_in_task = asyncio.create_task( self.run_proxy(self.from_api_server, self.to_rpc_server)) + self.proxy_out_task = asyncio.create_task( + self.run_proxy(self.to_rpc_server, self.from_api_server)) # Since we open 1 inproc socket per request, we have a hard cap on # the number of requests that can run in vLLM w. frontend @@ -136,20 +142,11 @@ def __init__(self, rpc_path: str): # 1 for generate(), 1 for abort(), do_log_stats(), check_health() self.limit_concurrency = socket_limit // 2 - 2 - async def run_proxy(self, socket_from, socket_to): + async def run_proxy(self, socket_from: Socket, socket_to: Socket): """Background task that runs a proxy""" - poller = zmq.asyncio.Poller() - poller.register(socket_from, zmq.constants.POLLIN) - poller.register(socket_to, zmq.constants.POLLIN) while True: - events_lst = await poller.poll() - events = dict(events_lst) - if socket_from in events: - identity, msg = await socket_from.recv_multipart() - await socket_to.send_multipart([identity, msg]) - if socket_to in events: - identity, msg = await socket_to.recv_multipart() - await socket_from.send_multipart([identity, msg]) + frames = await socket_from.recv_multipart(copy=False) + await socket_to.send_multipart(frames, copy=False) async def setup(self): """Setup the client before it starts sending server requests.""" @@ -180,7 +177,7 @@ def close(self): self.context.destroy() @contextmanager - def to_proxy_socket(self): + def to_proxy_socket(self) -> Iterator[Socket]: # Connect to the RPCServer via the proxy. # Raise a sensible error if the client was already closed. @@ -208,7 +205,8 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, with self.to_proxy_socket() as socket: # Ping RPCServer with a request. - await socket.send_multipart([cloudpickle.dumps(request)]) + await socket.send_multipart((cloudpickle.dumps(request), ), + copy=False) # Make sure the server responds if await socket.poll(timeout=self._data_timeout) == 0: @@ -216,7 +214,9 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, f"{self._data_timeout} ms") # Await the data from the Server. - data = cloudpickle.loads(await socket.recv()) + frame = await socket.recv(copy=False) + assert isinstance(frame, Frame) + data = pickle.loads(frame.buffer) if isinstance(data, Exception): # Re-raise exceptions returned by the server @@ -234,23 +234,23 @@ async def _send_get_data_rpc_request(self, request: RPCUtilityRequest, return data - async def _send_one_way_rpc_request( - self, - request: RPC_REQUEST_TYPE, - error_message: str, - socket: Optional[zmq.asyncio.Socket] = None): + async def _send_one_way_rpc_request(self, + request: RPC_REQUEST_TYPE, + error_message: str, + socket: Optional[Socket] = None): """Send one-way RPC request to trigger an action.""" - async def do_rpc_call(socket: zmq.asyncio.Socket, - request: RPC_REQUEST_TYPE): + async def do_rpc_call(socket: Socket, request: RPC_REQUEST_TYPE): - await socket.send_multipart([cloudpickle.dumps(request)]) + await socket.send_multipart((cloudpickle.dumps(request), )) if await socket.poll(timeout=self._data_timeout) == 0: raise TimeoutError("Server didn't reply within " f"{self._data_timeout} ms") - return cloudpickle.loads(await socket.recv()) + frame = await socket.recv(copy=False) + assert isinstance(frame, Frame) + return pickle.loads(frame.buffer) # Make a new socket connection. if socket is None: @@ -386,21 +386,20 @@ async def generate( try: with self.to_proxy_socket() as socket: # Send RPCGenerateRequest to the RPCServer. - await socket.send_multipart([ - cloudpickle.dumps( - RPCGenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request)) - ]) + await socket.send_multipart((cloudpickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)), )) # Stream back the results from the RPC Server. while not finished: - message = await socket.recv() - request_output = cloudpickle.loads(message) + message = await socket.recv(copy=False) + assert isinstance(message, Frame) + request_output = pickle.loads(message.buffer) if isinstance(request_output, Exception): # On exception, check if the server is still healthy @@ -424,9 +423,7 @@ async def generate( if not finished and not self._errored: await self.abort(request_id) - async def check_health(self, - socket: Optional[zmq.asyncio.Socket] = None - ) -> None: + async def check_health(self, socket: Optional[Socket] = None) -> None: """Raise if unhealthy""" await self._send_one_way_rpc_request( @@ -451,4 +448,4 @@ async def stop_profile(self) -> None: await self._send_one_way_rpc_request( request=RPCUtilityRequest.STOP_PROFILE, - error_message="RPCRequest STOP_PROFILE failed.") \ No newline at end of file + error_message="RPCRequest STOP_PROFILE failed.") diff --git a/vllm/entrypoints/openai/rpc/server.py b/vllm/entrypoints/openai/rpc/server.py index 738d12bbef051..bebc2faedb680 100644 --- a/vllm/entrypoints/openai/rpc/server.py +++ b/vllm/entrypoints/openai/rpc/server.py @@ -1,4 +1,5 @@ import asyncio +import pickle import signal from typing import Any, Coroutine, Union @@ -7,6 +8,8 @@ import zmq import zmq.asyncio from typing_extensions import Never +from zmq import Frame # type: ignore[attr-defined] +from zmq.asyncio import Socket from vllm import AsyncEngineArgs, AsyncLLMEngine from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, @@ -35,7 +38,7 @@ def __init__(self, async_engine_args: AsyncEngineArgs, self.context = zmq.asyncio.Context() # Init socket. - self.socket = self.context.socket(zmq.constants.DEALER) + self.socket: Socket = self.context.socket(zmq.constants.DEALER) self.socket.set_hwm(VLLM_RPC_ZMQ_HWM) self.socket.connect(rpc_path) @@ -63,30 +66,31 @@ async def get_config(self, identity, request): else: raise ValueError("Unknown Config Request: %s", request) - await self.socket.send_multipart( - [identity, cloudpickle.dumps(config)]) + await self.socket.send_multipart((identity, pickle.dumps(config)), + copy=False) except Exception as e: - await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + await self.socket.send_multipart((identity, pickle.dumps(e)), + copy=False) async def is_tracing_enabled(self, identity): """Send the is_tracing_enabled flag""" tracing_flag = await self.engine.is_tracing_enabled() await self.socket.send_multipart( - [identity, cloudpickle.dumps(tracing_flag)]) + (identity, pickle.dumps(tracing_flag))) async def do_log_stats(self, identity): """Log stats and confirm success.""" await self.engine.do_log_stats() await self.socket.send_multipart( - [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)]) + (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) async def is_server_ready(self, identity): """Notify the client that we are ready.""" await self.socket.send_multipart( - [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)]) + (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) async def abort(self, identity, request: RPCAbortRequest): """Abort request and notify the client of success.""" @@ -96,7 +100,7 @@ async def abort(self, identity, request: RPCAbortRequest): result: Union[str, Exception] = VLLM_RPC_SUCCESS_STR except Exception as e: result = e - await self.socket.send_multipart([identity, cloudpickle.dumps(result)]) + await self.socket.send_multipart((identity, pickle.dumps(result))) async def generate(self, identity, generate_request: RPCGenerateRequest): try: @@ -110,45 +114,47 @@ async def generate(self, identity, generate_request: RPCGenerateRequest): async for request_output in results_generator: await self.socket.send_multipart( - [identity, cloudpickle.dumps(request_output)]) + (identity, pickle.dumps(request_output)), copy=False) except Exception as e: - await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + await self.socket.send_multipart((identity, pickle.dumps(e)), + copy=False) async def check_health(self, identity): try: await self.engine.check_health() await self.socket.send_multipart( - [identity, cloudpickle.dumps(VLLM_RPC_SUCCESS_STR)]) + (identity, pickle.dumps(VLLM_RPC_SUCCESS_STR))) except Exception as e: - await self.socket.send_multipart([identity, cloudpickle.dumps(e)]) + await self.socket.send_multipart((identity, pickle.dumps(e)), + copy=False) async def start_profile(self, identity): logger.info("Starting profiler...") await self.engine.start_profile() logger.info("Profiler started.") - await self.socket.send_multipart([ + await self.socket.send_multipart(( identity, - cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), - ]) + pickle.dumps(VLLM_RPC_SUCCESS_STR), + )) async def stop_profile(self, identity): logger.info("Stopping profiler...") await self.engine.stop_profile() logger.info("Profiler stopped.") - await self.socket.send_multipart([ + await self.socket.send_multipart(( identity, - cloudpickle.dumps(VLLM_RPC_SUCCESS_STR), - ]) + pickle.dumps(VLLM_RPC_SUCCESS_STR), + )) def _make_handler_coro(self, identity, - message) -> Coroutine[Any, Any, Never]: + message: Frame) -> Coroutine[Any, Any, Never]: """Route the zmq message to the handler coroutine.""" - request = cloudpickle.loads(message) + request = cloudpickle.loads(message.buffer) if isinstance(request, RPCGenerateRequest): return self.generate(identity, request) @@ -189,7 +195,7 @@ async def run_server_loop(self): running_tasks = set() while True: # Wait for a request. - identity, message = await self.socket.recv_multipart() + identity, message = await self.socket.recv_multipart(copy=False) # Process the request async. task = asyncio.create_task( diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 32bbade256973..278be8cd11a12 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -1,9 +1,11 @@ import asyncio from io import StringIO -from typing import Awaitable, Callable, List +from typing import Awaitable, Callable, List, Optional import aiohttp +import torch from prometheus_client import start_http_server +from tqdm import tqdm from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.engine.async_llm_engine import AsyncLLMEngine @@ -78,6 +80,38 @@ def parse_args(): return parser.parse_args() +# explicitly use pure text format, with a newline at the end +# this makes it impossible to see the animation in the progress bar +# but will avoid messing up with ray or multiprocessing, which wraps +# each line of output with some prefix. +_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501 + + +class BatchProgressTracker: + + def __init__(self): + self._total = 0 + self._pbar: Optional[tqdm] = None + + def submitted(self): + self._total += 1 + + def completed(self): + if self._pbar: + self._pbar.update() + + def pbar(self) -> tqdm: + enable_tqdm = not torch.distributed.is_initialized( + ) or torch.distributed.get_rank() == 0 + self._pbar = tqdm(total=self._total, + unit="req", + desc="Running batch", + mininterval=5, + disable=not enable_tqdm, + bar_format=_BAR_FORMAT) + return self._pbar + + async def read_file(path_or_url: str) -> str: if path_or_url.startswith("http://") or path_or_url.startswith("https://"): async with aiohttp.ClientSession() as session, \ @@ -102,7 +136,8 @@ async def write_file(path_or_url: str, data: str) -> None: async def run_request(serving_engine_func: Callable, - request: BatchRequestInput) -> BatchRequestOutput: + request: BatchRequestInput, + tracker: BatchProgressTracker) -> BatchRequestOutput: response = await serving_engine_func(request.body) if isinstance(response, (ChatCompletionResponse, EmbeddingResponse)): @@ -125,6 +160,7 @@ async def run_request(serving_engine_func: Callable, else: raise ValueError("Request must not be sent in stream mode") + tracker.completed() return batch_output @@ -164,6 +200,9 @@ async def main(args): request_logger=request_logger, ) + tracker = BatchProgressTracker() + logger.info("Reading batch from %s...", args.input_file) + # Submit all requests in the file to the engine "concurrently". response_futures: List[Awaitable[BatchRequestOutput]] = [] for request_json in (await read_file(args.input_file)).strip().split("\n"): @@ -178,16 +217,19 @@ async def main(args): if request.url == "/v1/chat/completions": response_futures.append( run_request(openai_serving_chat.create_chat_completion, - request)) + request, tracker)) + tracker.submitted() elif request.url == "/v1/embeddings": response_futures.append( - run_request(openai_serving_embedding.create_embedding, - request)) + run_request(openai_serving_embedding.create_embedding, request, + tracker)) + tracker.submitted() else: raise ValueError("Only /v1/chat/completions and /v1/embeddings are" "supported in the batch endpoint.") - responses = await asyncio.gather(*response_futures) + with tracker.pbar(): + responses = await asyncio.gather(*response_futures) output_buffer = StringIO() for response in responses: diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d31ac4995fe2f..8ed81e9c88cb2 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,6 +1,8 @@ import asyncio +import json import time -from typing import AsyncGenerator, AsyncIterator, Dict, Final, List, Optional +from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, Final, List, + Optional) from typing import Sequence as GenericSequence from typing import Union @@ -11,23 +13,25 @@ from vllm.entrypoints.chat_utils import (ConversationMessage, apply_chat_template, load_chat_template, - parse_chat_messages) + parse_chat_messages_futures) from vllm.entrypoints.logger import RequestLogger from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, - FunctionCall, ToolCall, UsageInfo) + ChatCompletionStreamResponse, ChatMessage, DeltaFunctionCall, DeltaMessage, + DeltaToolCall, ErrorResponse, FunctionCall, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, PromptAdapterPath, TextTokensPrompt) +from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser, + MistralToolParser, + ToolParser) from vllm.inputs import TokensPrompt from vllm.logger import init_logger -from vllm.multimodal import MultiModalDataDict -from vllm.outputs import RequestOutput +from vllm.outputs import CompletionOutput, RequestOutput from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) @@ -39,19 +43,19 @@ class OpenAIServingChat(OpenAIServing): - def __init__( - self, - async_engine_client: AsyncEngineClient, - model_config: ModelConfig, - served_model_names: List[str], - response_role: str, - *, - lora_modules: Optional[List[LoRAModulePath]], - prompt_adapters: Optional[List[PromptAdapterPath]], - request_logger: Optional[RequestLogger], - chat_template: Optional[str], - return_tokens_as_token_ids: bool = False, - ): + def __init__(self, + async_engine_client: AsyncEngineClient, + model_config: ModelConfig, + served_model_names: List[str], + response_role: str, + *, + lora_modules: Optional[List[LoRAModulePath]], + prompt_adapters: Optional[List[PromptAdapterPath]], + request_logger: Optional[RequestLogger], + chat_template: Optional[str], + return_tokens_as_token_ids: bool = False, + enable_auto_tools: bool = False, + tool_parser: Optional[str] = None): super().__init__(async_engine_client=async_engine_client, model_config=model_config, served_model_names=served_model_names, @@ -61,10 +65,27 @@ def __init__( return_tokens_as_token_ids=return_tokens_as_token_ids) self.response_role = response_role - - # If this is None we use the tokenizer's default chat template + self.use_tool_use_model_template = False self.chat_template = load_chat_template(chat_template) + # set up tool use + self.enable_auto_tools: bool = enable_auto_tools + if self.enable_auto_tools: + logger.info( + "\"auto\" tool choice has been enabled please note that while" + " the parallel_tool_calls client option is preset for " + "compatibility reasons, it will be ignored.") + + self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None + if self.enable_auto_tools: + if tool_parser == "mistral": + self.tool_parser = MistralToolParser + elif tool_parser == "hermes": + self.tool_parser = Hermes2ProToolParser + else: + raise TypeError("Error: --enable-auto-tool-choice requires " + "--tool-call-parser") + async def create_chat_completion( self, request: ChatCompletionRequest, @@ -77,11 +98,10 @@ async def create_chat_completion( for the API specification. This API mimics the OpenAI ChatCompletion API. - NOTE: Currently we do not support the following feature: - - function_call (Users should implement this by themselves) """ error_check_ret = await self._check_model(request) if error_check_ret is not None: + logger.error("Error with model %s", error_check_ret) return error_check_ret try: @@ -94,7 +114,7 @@ async def create_chat_completion( tokenizer = await self.async_engine_client.get_tokenizer( lora_request) - conversation, mm_futures = parse_chat_messages( + conversation, mm_data_future = parse_chat_messages_futures( request.messages, model_config, tokenizer) tool_dicts = None if request.tools is None else [ @@ -114,18 +134,26 @@ async def create_chat_completion( logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) - mm_data: Optional[MultiModalDataDict] = None try: - if len(mm_futures): - # since we support only single mm data currently - assert len( - mm_futures - ) == 1, "Multiple 'image_url' input is currently not supported." - mm_data = await mm_futures[0] + mm_data = await mm_data_future except Exception as e: logger.error("Error in loading multi-modal data: %s", e) return self.create_error_response(str(e)) + # validation for OpenAI tools + # tool_choice = "required" is not supported + if request.tool_choice == "required": + return self.create_error_response( + "tool_choice = \"required\" is not supported!") + + # "auto" tools requires --enable-auto-tool-choice + # and --tool-call-parser + if request.tool_choice == "auto" and not ( + self.enable_auto_tools and self.tool_parser is not None): + return self.create_error_response( + "\"auto\" tool choice requires " + "--enable-auto-tool-choice and --tool-call-parser to be set") + request_id = f"chat-{random_uuid()}" try: guided_decode_logits_processor = ( @@ -194,6 +222,7 @@ async def create_chat_completion( if request.stream: return self.chat_completion_stream_generator( request, result_generator, request_id, conversation, tokenizer) + try: return await self.chat_completion_full_generator( request, result_generator, request_id, conversation, tokenizer) @@ -226,6 +255,9 @@ async def chat_completion_stream_generator( previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices + tool_parser: Optional[ToolParser] = self.tool_parser( + tokenizer) if self.tool_parser else None + try: async for res in result_generator: # We need to do it here, because if there are exceptions in @@ -235,10 +267,17 @@ async def chat_completion_stream_generator( # Send first response for each request.n (index) with # the role role = self.get_chat_request_role(request) + + # NOTE num_choices defaults to 1 so this usually executes + # once per request for i in range(num_choices): + choice_data = ChatCompletionResponseStreamChoice( index=i, - delta=DeltaMessage(role=role), + delta=DeltaMessage( + role=role, + content="", + ), logprobs=None, finish_reason=None) chunk = ChatCompletionStreamResponse( @@ -247,14 +286,18 @@ async def chat_completion_stream_generator( created=created_time, choices=[choice_data], model=model_name) + + # if usage should be included if (request.stream_options and request.stream_options.include_usage): - if (request.stream_options.continuous_usage_stats): + # if continuous usage stats are requested, add it + if request.stream_options.continuous_usage_stats: prompt_tokens = len(res.prompt_token_ids) usage = UsageInfo(prompt_tokens=prompt_tokens, completion_tokens=0, total_tokens=prompt_tokens) chunk.usage = usage + # otherwise don't else: chunk.usage = None @@ -264,7 +307,7 @@ async def chat_completion_stream_generator( # Send response to echo the input portion of the # last message if request.echo: - last_msg_content = "" + last_msg_content: Optional[str] = "" if conversation and conversation[-1].get( "content") and conversation[-1].get( "role") == role: @@ -305,6 +348,7 @@ async def chat_completion_stream_generator( first_iteration = False for output in res.outputs: + i = output.index if finish_reason_sent[i]: @@ -327,20 +371,50 @@ async def chat_completion_stream_generator( logprobs = None delta_text = output.text[len(previous_texts[i]):] - previous_texts[i] = output.text - previous_num_tokens[i] = len(output.token_ids) + delta_message: Optional[DeltaMessage] = None - if request.tool_choice and type( - request.tool_choice - ) is ChatCompletionNamedToolChoiceParam: + # handle streaming deltas for tools with named tool_choice + if (request.tool_choice and type(request.tool_choice) is + ChatCompletionNamedToolChoiceParam): delta_message = DeltaMessage(tool_calls=[ - ToolCall(function=FunctionCall( + DeltaToolCall(function=DeltaFunctionCall( name=request.tool_choice.function.name, - arguments=delta_text)) + arguments=delta_text), + index=i) ]) + + # handle streaming deltas for tools with "auto" tool choice + elif (self._should_stream_with_auto_tool_parsing(request) + and tool_parser): + delta_message = ( + tool_parser.extract_tool_calls_streaming( + previous_text=previous_texts[i], + current_text=output.text, + delta_text=delta_text, + previous_token_ids= \ + output.token_ids[ + :-1 * len(delta_token_ids) + ], + current_token_ids=output.token_ids, + delta_token_ids=delta_token_ids + ) + ) + + # handle streaming just a content delta else: delta_message = DeltaMessage(content=delta_text) + # set the previous values for the next iteration + previous_texts[i] = output.text + previous_num_tokens[i] = len(output.token_ids) + + # if the message delta is None (e.g. because it was a + # "control token" for tool calls or the parser otherwise + # wasn't ready to send a token, then + # get the next token without streaming a chunk + if delta_message is None: + continue + if output.finish_reason is None: # Send token-by-token response for each request.n @@ -355,6 +429,8 @@ async def chat_completion_stream_generator( created=created_time, choices=[choice_data], model=model_name) + + # handle usage stats if requested & if continuous if (request.stream_options and request.stream_options.include_usage): if (request.stream_options.continuous_usage_stats): @@ -372,14 +448,55 @@ async def chat_completion_stream_generator( data = chunk.model_dump_json(exclude_unset=True) yield f"data: {data}\n\n" + + # if the model is finished generating else: + # check to make sure we haven't "forgotten" to stream + # any tokens that were generated but previously + # matched by partial json parsing + # only happens if we are NOT using guided decoding + if tool_parser: + index = len( + tool_parser.prev_tool_call_arr) - 1 if len( + tool_parser.prev_tool_call_arr) > 0 else 0 + else: + index = 0 + + if self._should_check_for_unstreamed_tool_arg_tokens( + delta_message, output) and tool_parser: + # get the expected call based on partial JSON + # parsing which "autocompletes" the JSON + expected_call = json.dumps( + tool_parser.prev_tool_call_arr[index].get( + "arguments", {})) + + # get what we've streamed so for for arguments + # for the current tool + actual_call = tool_parser.streamed_args_for_tool[ + index] + + # check to see if there's anything left to stream + remaining_call = expected_call.replace( + actual_call, "", 1) + + # set that as a delta message + delta_message = DeltaMessage(tool_calls=[ + DeltaToolCall(index=index, + function=DeltaFunctionCall( + arguments=remaining_call). + model_dump(exclude_none=True)) + ]) + # Send the finish response for each request.n only once prompt_tokens = len(res.prompt_token_ids) choice_data = ChatCompletionResponseStreamChoice( index=i, delta=delta_message, logprobs=logprobs, - finish_reason=output.finish_reason, + finish_reason=output.finish_reason + if not (tool_parser + and len(tool_parser.prev_tool_call_arr)) + else "tool_calls", stop_reason=output.stop_reason) chunk = ChatCompletionStreamResponse( id=request_id, @@ -405,6 +522,8 @@ async def chat_completion_stream_generator( yield f"data: {data}\n\n" finish_reason_sent[i] = True + # once the final token is handled, if stream_options.include_usage + # is sent, send the usage if (request.stream_options and request.stream_options.include_usage): final_usage = UsageInfo( @@ -426,6 +545,7 @@ async def chat_completion_stream_generator( except ValueError as e: # TODO: Use a vllm-specific Validation Error + logger.error("error in chat completion stream generator: %s", e) data = self.create_streaming_error_response(str(e)) yield f"data: {data}\n\n" # Send the final done message after all response.n are finished @@ -470,8 +590,21 @@ async def chat_completion_full_generator( else: logprobs = None - if request.tool_choice and type( + # by default, tools are not used. + tools_called = False + + # if auto tools are not enabled, and a named tool choice using + # outlines is not being used + if not (self.enable_auto_tools + or not self.tool_parser) and not isinstance( + request.tool_choice, + ChatCompletionNamedToolChoiceParam): + message = ChatMessage(role=role, content=output.text) + + # if the request uses tools and specified a tool choice + elif request.tool_choice and type( request.tool_choice) is ChatCompletionNamedToolChoiceParam: + message = ChatMessage( role=role, content="", @@ -480,14 +613,47 @@ async def chat_completion_full_generator( name=request.tool_choice.function.name, arguments=output.text)) ]) + tools_called = True + + # if the request doesn't use tool choice + # OR specifies to not use a tool elif not request.tool_choice or request.tool_choice == "none": + + message = ChatMessage(role=role, content=output.text) + + # handle when there are tools and tool choice is auto + elif request.tools and ( + request.tool_choice == "auto" + or request.tool_choice is None) and self.enable_auto_tools \ + and self.tool_parser: + + tool_parser = self.tool_parser(tokenizer) + tool_call_info = tool_parser.extract_tool_calls(output.text) + tools_called = tool_call_info.tools_called + if tool_call_info.tools_called: + message = ChatMessage(role=role, + content=tool_call_info.content, + tool_calls=tool_call_info.tool_calls) + + else: + # FOR NOW make it a chat message; we will have to detect + # the type to make it later. + message = ChatMessage(role=role, content=output.text) + + # undetermined case that is still important to handle + else: + logger.error( + "Error in chat_completion_full_generator - cannot determine" + " if tools should be extracted. Returning a standard chat " + "completion.") message = ChatMessage(role=role, content=output.text) choice_data = ChatCompletionResponseChoice( index=output.index, message=message, logprobs=logprobs, - finish_reason=output.finish_reason, + finish_reason="tool_calls" if tools_called else + output.finish_reason if output.finish_reason else "stop", stop_reason=output.stop_reason) choices.append(choice_data) @@ -495,10 +661,11 @@ async def chat_completion_full_generator( last_msg_content = "" if conversation and conversation[-1].get( "content") and conversation[-1].get("role") == role: - last_msg_content = conversation[-1]["content"] + last_msg_content = conversation[-1]["content"] or "" for choice in choices: - full_message = last_msg_content + choice.message.content + full_message = last_msg_content + (choice.message.content + or "") choice.message.content = full_message num_prompt_tokens = len(final_res.prompt_token_ids) @@ -581,3 +748,38 @@ def _create_chat_logprobs( )) return ChatCompletionLogProbs(content=logprobs_content) + + def _should_stream_with_auto_tool_parsing(self, + request: ChatCompletionRequest): + """ + Utility function to check if streamed tokens should go through the tool + call parser that was configured. + + We only want to do this IF user-provided tools are set, a tool parser + is configured, "auto" tool choice is enabled, and the request's tool + choice field indicates that "auto" tool choice should be used. + """ + return (request.tools and self.tool_parser and self.enable_auto_tools + and request.tool_choice in ['auto', None]) + + def _should_check_for_unstreamed_tool_arg_tokens( + self, + delta_message: Optional[DeltaMessage], + output: CompletionOutput, + ) -> bool: + """ + Check to see if we should check for unstreamed tool arguments tokens. + This is only applicable when auto tool parsing is enabled, the delta + is a tool call with arguments. + """ + + # yapf: disable + return bool( + # if there is a delta message that includes tool calls which + # include a function that has arguments + self.enable_auto_tools and self.tool_parser and delta_message + and delta_message.tool_calls and delta_message.tool_calls[0] + and delta_message.tool_calls[0].function + and delta_message.tool_calls[0].function.arguments is not None + and output.finish_reason is not None + ) diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 26e91e7cc94dd..ac74527441cd9 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -16,11 +16,13 @@ CompletionRequest, DetokenizeRequest, EmbeddingRequest, ErrorResponse, + LoadLoraAdapterRequest, ModelCard, ModelList, ModelPermission, TokenizeChatRequest, TokenizeCompletionRequest, - TokenizeRequest) + TokenizeRequest, + UnloadLoraAdapterRequest) # yapf: enable from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger @@ -32,6 +34,7 @@ from vllm.sampling_params import LogitsProcessor, SamplingParams from vllm.sequence import Logprob from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import AtomicCounter logger = init_logger(__name__) @@ -78,6 +81,7 @@ def __init__( self.served_model_names = served_model_names + self.lora_id_counter = AtomicCounter(0) self.lora_requests = [] if lora_modules is not None: self.lora_requests = [ @@ -403,3 +407,76 @@ def _get_decoded_token(logprob: Logprob, if logprob.decoded_token is not None: return logprob.decoded_token return tokenizer.decode(token_id) + + async def _check_load_lora_adapter_request( + self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]: + # Check if both 'lora_name' and 'lora_path' are provided + if not request.lora_name or not request.lora_path: + return self.create_error_response( + message="Both 'lora_name' and 'lora_path' must be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name already exists + if any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return self.create_error_response( + message= + f"The lora adapter '{request.lora_name}' has already been" + "loaded.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + return None + + async def _check_unload_lora_adapter_request( + self, + request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]: + # Check if either 'lora_name' or 'lora_int_id' is provided + if not request.lora_name and not request.lora_int_id: + return self.create_error_response( + message= + "either 'lora_name' and 'lora_int_id' needs to be provided.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + # Check if the lora adapter with the given name exists + if not any(lora_request.lora_name == request.lora_name + for lora_request in self.lora_requests): + return self.create_error_response( + message= + f"The lora adapter '{request.lora_name}' cannot be found.", + err_type="InvalidUserInput", + status_code=HTTPStatus.BAD_REQUEST) + + return None + + async def load_lora_adapter( + self, + request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_load_lora_adapter_request(request) + if error_check_ret is not None: + return error_check_ret + + lora_name, lora_path = request.lora_name, request.lora_path + unique_id = self.lora_id_counter.inc(1) + self.lora_requests.append( + LoRARequest(lora_name=lora_name, + lora_int_id=unique_id, + lora_path=lora_path)) + return f"Success: LoRA adapter '{lora_name}' added successfully." + + async def unload_lora_adapter( + self, + request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]: + error_check_ret = await self._check_unload_lora_adapter_request(request + ) + if error_check_ret is not None: + return error_check_ret + + lora_name = request.lora_name + self.lora_requests = [ + lora_request for lora_request in self.lora_requests + if lora_request.lora_name != lora_name + ] + return f"Success: LoRA adapter '{lora_name}' removed successfully." diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 1aeabb7a7d729..69a5ad5b62cfa 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -4,7 +4,7 @@ from vllm.engine.protocol import AsyncEngineClient from vllm.entrypoints.chat_utils import (apply_chat_template, load_chat_template, - parse_chat_messages) + parse_chat_messages_futures) from vllm.entrypoints.logger import RequestLogger # yapf conflicts with isort for this block # yapf: disable @@ -43,7 +43,11 @@ def __init__( request_logger=request_logger) # If this is None we use the tokenizer's default chat template - self.chat_template = load_chat_template(chat_template) + # the list of commonly-used chat template names for HF named templates + hf_chat_templates: List[str] = ['default', 'tool_use'] + self.chat_template = chat_template \ + if chat_template in hf_chat_templates \ + else load_chat_template(chat_template) async def create_tokenize( self, @@ -65,10 +69,11 @@ async def create_tokenize( if isinstance(request, TokenizeChatRequest): model_config = self.model_config - conversation, mm_futures = parse_chat_messages( + conversation, mm_data_future = parse_chat_messages_futures( request.messages, model_config, tokenizer) - if mm_futures: + mm_data = await mm_data_future + if mm_data: logger.warning( "Multi-modal inputs are ignored during tokenization") diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py new file mode 100644 index 0000000000000..5d5d53784fedf --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -0,0 +1,5 @@ +from .abstract_tool_parser import ToolParser +from .hermes_tool_parser import Hermes2ProToolParser +from .mistral_tool_parser import MistralToolParser + +__all__ = ["ToolParser", "Hermes2ProToolParser", "MistralToolParser"] \ No newline at end of file diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py new file mode 100644 index 0000000000000..873f615d43257 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -0,0 +1,57 @@ +from typing import Dict, List, Sequence, Union + +from vllm.entrypoints.openai.protocol import (DeltaMessage, + ExtractedToolCallInformation) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer + +logger = init_logger(__name__) + + +class ToolParser: + """ + Abstract ToolParser class that should not be used directly. Provided + properties and methods should be used in + derived classes. + """ + + def __init__(self, tokenizer: AnyTokenizer): + self.prev_tool_call_arr: List[Dict] = [] + # the index of the tool call that is currently being parsed + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: List[str] = [] + + self.model_tokenizer = tokenizer + + def extract_tool_calls(self, + model_output: str) -> ExtractedToolCallInformation: + """ + Static method that should be implemented for extracting tool calls from + a complete model-generated string. + Used for non-streaming responses where we have the entire model response + available before sending to the client. + Static because it's stateless. + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls has not been implemented!") + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + """ + Instance method that should be implemented for extracting tool calls + from an incomplete response; for use when handling tool calls and + streaming. Has to be an instance method because it requires state - + the current tokens/diffs, but also the information about what has + previously been parsed and extracted (see constructor) + """ + raise NotImplementedError( + "AbstractToolParser.extract_tool_calls_streaming has not been " + "implemented!") diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py new file mode 100644 index 0000000000000..bde9b47ce60d5 --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -0,0 +1,334 @@ +import json +import re +from typing import Dict, List, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +class Hermes2ProToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + logger.error( + "Detected Mistral tokenizer when using a Hermes model") + self.model_tokenizer = self.model_tokenizer.tokenizer + + self.current_tool_name_sent: bool = False + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + + self.tool_call_start_token: str = "" + self.tool_call_end_token: str = "" + + self.tool_call_regex = re.compile( + r"(.*?)|(.*)", re.DOTALL) + self.scratch_pad_regex = re.compile( + r"(.*?)", re.DOTALL) + + if not self.model_tokenizer: + raise ValueError( + "The model tokenizer must be passed to the ToolParser " + "constructor during construction.") + self.tool_call_start_token_id: int = self.model_tokenizer.vocab[ + self.tool_call_start_token] + self.tool_call_end_token_id: int = self.model_tokenizer.vocab[ + self.tool_call_end_token] + if not self.tool_call_start_token_id or not self.tool_call_end_token_id: + raise RuntimeError( + "Hermes 2 Pro Tool parser could not locate tool call start/end " + "tokens in the tokenizer!") + + def extract_tool_calls(self, + model_output: str) -> ExtractedToolCallInformation: + + # sanity check; avoid unnecessary processing + if self.tool_call_start_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + else: + + try: + # there are two possible captures - between tags, or between a + # tag and end-of-string so the result of + # findall is an array of tuples where one is a function call and + # the other is None + function_call_tuples = ( + self.tool_call_regex.findall(model_output)) + + # load the JSON, and then use it to build the Function and + # Tool Call + raw_function_calls = [ + json.loads(match[0] if match[0] else match[1]) + for match in function_call_tuples + ] + tool_calls = [ + ToolCall( + type="function", + function=FunctionCall( + name=function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(function_call["arguments"]))) + for function_call in raw_function_calls + ] + + content = model_output[:model_output. + find(self.tool_call_start_token)] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if content else None) + + except Exception as e: + logger.error("Error in extracting tool call from response %s", + e) + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + + logger.debug("delta_text: %s", delta_text) + logger.debug("delta_token_ids: %s", delta_token_ids) + # check to see if we should be streaming a tool call - is there a + if self.tool_call_start_token_id not in current_token_ids: + logger.debug("No tool call tokens found!") + return DeltaMessage(content=delta_text) + + try: + + # figure out where we are in the parsing by counting tool call + # start & end tags + prev_tool_start_count = previous_token_ids.count( + self.tool_call_start_token_id) + prev_tool_end_count = previous_token_ids.count( + self.tool_call_end_token_id) + cur_tool_start_count = current_token_ids.count( + self.tool_call_start_token_id) + cur_tool_end_count = current_token_ids.count( + self.tool_call_end_token_id) + + # case: if we're generating text, OR rounding out a tool call + if (cur_tool_start_count == cur_tool_end_count + and prev_tool_end_count == cur_tool_end_count): + logger.debug("Generating text content! skipping tool parsing.") + if delta_text != self.tool_call_end_token: + return DeltaMessage(content=delta_text) + + # case: if tool open & close tag counts don't match, we're doing + # imaginary "else" block here + # something with tools with this diff. + # flags for partial JSON parting. exported constants from + # "Allow" are handled via BIT MASK + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + # case -- we're starting a new tool call + if (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count > prev_tool_start_count): + if len(delta_token_ids) > 1: + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + else: + tool_call_portion = None + delta = None + + text_portion = None + + # set cursors and state appropriately + self.current_tool_id += 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("Starting on a new tool %s", self.current_tool_id) + + # case -- we're updating an existing tool call + elif (cur_tool_start_count > cur_tool_end_count + and cur_tool_start_count == prev_tool_start_count): + + # get the portion of the text that's the tool call + tool_call_portion = current_text.split( + self.tool_call_start_token)[-1] + text_portion = None + + # case -- the current tool call is being closed. + elif (cur_tool_start_count == cur_tool_end_count + and cur_tool_end_count > prev_tool_end_count): + diff = self.prev_tool_call_arr[self.current_tool_id].get( + "arguments") + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[self.current_tool_id], "") + logger.debug( + "Finishing tool and found diff that had not " + "been streamed yet: %s", diff) + self.streamed_args_for_tool[self.current_tool_id] \ + += diff + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + + # case -- otherwise we're just generating text + else: + text = delta_text.replace(self.tool_call_start_token, "") + text = text.replace(self.tool_call_end_token, "") + delta = DeltaMessage(tool_calls=[], content=text) + return delta + + try: + + current_tool_call = partial_json_parser.loads( + tool_call_portion or "{}", + flags) if tool_call_portion else None + logger.debug("Parsed tool call %s", current_tool_call) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # case - we haven't sent the tool name yet. If it's available, send + # it. otherwise, wait until it's available. + if not self.current_tool_name_sent: + function_name: Union[str, None] = current_tool_call.get("name") + if function_name: + self.current_tool_name_sent = True + return DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + else: + return None + # case -- otherwise, send the tool call delta + + # if the tool call portion is None, send the delta as text + if tool_call_portion is None: + # if there's text but not tool calls, send that - + # otherwise None to skip chunk + delta = DeltaMessage(content=delta_text) \ + if text_portion is not None else None + return delta + + # now, the nitty-gritty of tool calls + # now we have the portion to parse as tool call. + + logger.debug("Trying to parse current tool call with ID %s", + self.current_tool_id) + + # if we're starting a new tool call, push an empty object in as + # a placeholder for the arguments + if len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + + # main logic for tool parsing here - compare prev. partially-parsed + # JSON to the current partially-parsed JSON + prev_arguments = ( + self.prev_tool_call_arr[self.current_tool_id].get("arguments")) + cur_arguments = current_tool_call.get("arguments") + + logger.debug("diffing old arguments: %s", prev_arguments) + logger.debug("against new ones: %s", cur_arguments) + + # case -- no arguments have been created yet. skip sending a delta. + if not cur_arguments and not prev_arguments: + logger.debug("Skipping text %s - no arguments", delta_text) + delta = None + + # case -- prev arguments are defined, but non are now. + # probably impossible, but not a fatal error - just keep going + elif not cur_arguments and prev_arguments: + logger.error("should be impossible to have arguments reset " + "mid-call. skipping streaming anything.") + delta = None + + # case -- we now have the first info about arguments available from + # autocompleting the JSON + elif cur_arguments and not prev_arguments: + + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in %s", delta_text, + cur_arguments_json) + + # get the location where previous args differ from current + args_delta_start_loc = cur_arguments_json.index(delta_text) \ + + len(delta_text) + + # use that to find the actual delta + arguments_delta = cur_arguments_json[:args_delta_start_loc] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += arguments_delta + + # last case -- we have an update to existing arguments. + elif cur_arguments and prev_arguments: + + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for diff between\n%s", cur_args_json) + logger.debug("and\n%s", prev_args_json) + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got argument diff %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[self.current_tool_id] \ + += argument_diff + + # handle saving the state for the current tool into + # the "prev" list for use in diffing for the next iteration + if self.current_tool_id == len(self.prev_tool_call_arr) - 1: + self.prev_tool_call_arr[self.current_tool_id] = \ + current_tool_call + else: + self.prev_tool_call_arr.append(current_tool_call) + + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + return None # do not stream a delta. skip this token ID. diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py new file mode 100644 index 0000000000000..4b0e1c91df97c --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -0,0 +1,282 @@ +import json +import re +from typing import Dict, List, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +class MistralToolParser(ToolParser): + """ + Tool call parser for Mistral 7B Instruct v0.3, intended for use with the + examples/tool_chat_template_mistral.jinja template. + + Used when --enable-auto-tool-choice --tool-call-parser mistral are all set + """ + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + if isinstance(self.model_tokenizer, MistralTokenizer): + self.model_tokenizer = self.model_tokenizer.tokenizer + else: + logger.info("Non-Mistral tokenizer detected when using a Mistral " + "model...") + + # initialize properties used for state when parsing tool calls in + # streaming mode + self.prev_tool_call_arr: List[Dict] = [] + self.current_tool_id: int = -1 + self.current_tool_name_sent: bool = False + self.streamed_args_for_tool: List[str] = [ + ] # map what has been streamed for each tool so far to a list + self.bot_token = "[TOOL_CALLS]" + self.bot_token_id = self.model_tokenizer.vocab[self.bot_token] + self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) + + def extract_tool_calls(self, + model_output: str) -> ExtractedToolCallInformation: + """ + Extract the tool calls from a complete model response. Requires + find-and-replacing single quotes with double quotes for JSON parsing, + make sure your tool call arguments don't ever include quotes! + """ + + # case -- if a tool call token is not present, return a text response + if self.bot_token not in model_output: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + try: + + # use a regex to find the tool call. remove the BOT token + # and make sure to replace single quotes with double quotes + raw_tool_call = self.tool_call_regex.findall( + model_output.replace(self.bot_token, ""))[0] + + # load the JSON, and then use it to build the Function and + # Tool Call + function_call_arr = json.loads(raw_tool_call) + tool_calls: List[ToolCall] = [ + ToolCall( + type="function", + function=FunctionCall( + name=raw_function_call["name"], + # function call args are JSON but as a string + arguments=json.dumps(raw_function_call["arguments"]))) + for raw_function_call in function_call_arr + ] + + # get any content before the tool call + content = model_output.split(self.bot_token)[0] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=content if len(content) > 0 else None) + + except Exception as e: + logger.error("Error in extracting tool call from response: %s", e) + # return information to just treat the tool call as regular JSON + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=model_output) + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> Union[DeltaMessage, None]: + + # if the tool call token is not in the tokens generated so far, append + # output to contents since it's not a tool + if self.bot_token not in current_text: + return DeltaMessage(content=delta_text) + + # if the tool call token ID IS in the tokens generated so far, that + # means we're parsing as tool calls now + + # handle if we detected the BOT token which means the start of tool + # calling + if (self.bot_token_id in delta_token_ids + and len(delta_token_ids) == 1): + # if it's the only token, return None, so we don't send a chat + # completion any don't send a control token + return None + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + try: + + # replace BOT token with empty string, and convert single quotes + # to double to allow parsing as JSON since mistral uses single + # quotes instead of double for tool calls + parsable_arr = current_text.split(self.bot_token)[-1] + + # tool calls are generated in an array, so do partial JSON + # parsing on the entire array + try: + tool_call_arr: List[Dict] = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # select as the current tool call the one we're on the state at + + current_tool_call: Dict = tool_call_arr[self.current_tool_id] \ + if len(tool_call_arr) > 0 else {} + + # case -- if no tokens have been streamed for the tool, e.g. + # only the array brackets, stream nothing + if len(tool_call_arr) == 0: + return None + + # case: we are starting a new tool in the array + # -> array has > 0 length AND length has moved past cursor + elif (len(tool_call_arr) > 0 + and len(tool_call_arr) > self.current_tool_id + 1): + + # if we're moving on to a new call, first make sure we + # haven't missed anything in the previous one that was + # auto-generated due to JSON completions, but wasn't + # streamed to the client yet. + if self.current_tool_id >= 0: + diff: Union[str, None] = current_tool_call.get("arguments") + + if diff: + diff = json.dumps(diff).replace( + self.streamed_args_for_tool[self.current_tool_id], + "") + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += diff + else: + delta = None + else: + delta = None + # re-set stuff pertaining to progress in the current tool + self.current_tool_id = len(tool_call_arr) - 1 + self.current_tool_name_sent = False + self.streamed_args_for_tool.append("") + logger.debug("starting on new tool %d", self.current_tool_id) + return delta + + # case: update an existing tool - this is handled below + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = current_tool_call.get("name") + if function_name: + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + else: + delta = None + + # now we know we're on the same tool call and we're streaming + # arguments + else: + + prev_arguments = self.prev_tool_call_arr[ + self.current_tool_id].get("arguments") + cur_arguments = current_tool_call.get("arguments") + + new_text = delta_text.replace("\'", "\"") + + if not cur_arguments and not prev_arguments: + + delta = None + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + logger.debug("finding %s in %s", new_text, + cur_arguments_json) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(new_text) + + len(new_text)] + logger.debug("First tokens in arguments received: %s", + arguments_delta) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). + model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + logger.debug("Searching for diff between \n%s\n%s", + cur_args_json, prev_args_json) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + logger.debug("got arguments diff: %s", argument_diff) + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + else: + # try parsing it with regular JSON - if it works we're + # at the end, and we need to send the difference between + # tokens streamed so far and the valid JSON + delta = None + + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + self.prev_tool_call_arr = tool_call_arr + return delta + + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None diff --git a/vllm/entrypoints/openai/tool_parsers/utils.py b/vllm/entrypoints/openai/tool_parsers/utils.py new file mode 100644 index 0000000000000..db7fc5259fc4e --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/utils.py @@ -0,0 +1,87 @@ +def find_common_prefix(s1: str, s2: str) -> str: + """ + Finds a common prefix that is shared between two strings, if there is one. + Order of arguments is NOT important. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. + + e.g. find_common_prefix('{"fruit": "ap"}', '{"fruit": "apple"}') -> + '{"fruit": "ap' + """ + prefix = '' + min_length = min(len(s1), len(s2)) + for i in range(0, min_length): + if s1[i] == s2[i]: + prefix += s1[i] + else: + break + return prefix + + +def find_common_suffix(s1: str, s2: str) -> str: + """ + Finds a common suffix shared between two strings, if there is one. Order of + arguments is NOT important. + Stops when the suffix ends OR it hits an alphanumeric character + + e.g. find_common_suffix('{"fruit": "ap"}', '{"fruit": "apple"}') -> '"}' + """ + suffix = '' + min_length = min(len(s1), len(s2)) + for i in range(1, min_length + 1): + if s1[-i] == s2[-i] and not s1[-i].isalnum(): + suffix = s1[-i] + suffix + else: + break + return suffix + + +def extract_intermediate_diff(curr: str, old: str) -> str: + """ + Given two strings, extract the difference in the middle between two strings + that are known to have a common prefix and/or suffix. + + This function is provided as a UTILITY for extracting information from JSON + generated by partial_json_parser, to help in ensuring that the right tokens + are returned in streaming, so that close-quotes, close-brackets and + close-braces are not returned prematurely. The order of arguments IS + important - the new version of the partially-parsed JSON must be the first + argument, and the secnod argument must be from the previous generation. + + What it returns, is tokens that should be streamed to the client. + + e.g. extract_intermediate_diff('{"fruit": "apple"}', '{"fruit": "ap"}') + -> 'ple' + + """ + suffix = find_common_suffix(curr, old) + + old = old[::-1].replace(suffix[::-1], '', 1)[::-1] + prefix = find_common_prefix(curr, old) + diff = curr + if len(suffix): + diff = diff[::-1].replace(suffix[::-1], '', 1)[::-1] + + if len(prefix): + # replace the prefix only once in case it's mirrored + diff = diff.replace(prefix, '', 1) + + return diff + + +def find_all_indices(string, substring): + """ + Find all (starting) indices of a substring in a given string. Useful for + tool call extraction + """ + indices = [] + index = -1 + while True: + index = string.find(substring, index + 1) + if index == -1: + break + indices.append(index) + return indices diff --git a/vllm/envs.py b/vllm/envs.py index 24e09ee0e055f..ed45047e9f8fc 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -31,6 +31,7 @@ VLLM_TRACE_FUNCTION: int = 0 VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_USE_FLASHINFER_SAMPLER: bool = False + VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" @@ -60,6 +61,7 @@ VLLM_ALLOW_ENGINE_USE_RAY: bool = False VLLM_PLUGINS: Optional[List[str]] = None VLLM_TORCH_PROFILER_DIR: Optional[str] = None + VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False def get_default_cache_root(): @@ -196,6 +198,10 @@ def get_default_config_root(): # Internal flag to enable Dynamo graph capture "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), + "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": + lambda: + (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in + ("true", "1")), # local rank of the process in the distributed setting, used to determine # the GPU device id @@ -348,7 +354,7 @@ def get_default_config_root(): os.path.join(get_default_cache_root(), "vllm", "xla_cache"), )), "VLLM_FUSED_MOE_CHUNK_SIZE": - lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "65536")), + lambda: int(os.getenv("VLLM_FUSED_MOE_CHUNK_SIZE", "32768")), # If set, vllm will skip the deprecation warnings. "VLLM_NO_DEPRECATION_WARNING": @@ -400,6 +406,16 @@ def get_default_config_root(): "VLLM_TORCH_PROFILER_DIR": lambda: (None if os.getenv("VLLM_TORCH_PROFILER_DIR", None) is None else os .path.expanduser(os.getenv("VLLM_TORCH_PROFILER_DIR", "."))), + + # If set, vLLM will use Triton implementations of AWQ. + "VLLM_USE_TRITON_AWQ": + lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))), + + # If set, allow loading or unloading lora adapters in runtime, + "VLLM_ALLOW_RUNTIME_LORA_UPDATING": + lambda: + (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in + ("1", "true")), } # end-env-vars-definition diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 37d12725bd1e4..ec9b24ce1318f 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -11,8 +11,9 @@ ResultHandler, WorkerMonitor) from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (GiB_bytes, get_distributed_init_method, get_open_port, get_vllm_instance_id, make_async) from vllm.worker.worker_base import WorkerWrapperBase @@ -295,6 +296,12 @@ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: for result in parallel_worker_tasks: result.get() + def start_profile(self) -> None: + self.driver_method_invoker(self.driver_worker, "start_profile") + + def stop_profile(self) -> None: + self.driver_method_invoker(self.driver_worker, "stop_profile") + class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase): diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 1a35a7c3b8f75..ad84422ee2129 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -6,7 +6,8 @@ from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest logger = init_logger(__name__) diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 422bef107f352..c96cb0f2c2981 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -6,8 +6,9 @@ PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.sequence import ExecuteModelRequest class ExecutorBase(ABC): diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 795692195f84d..2185c9cf6cead 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -3,8 +3,9 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.prompt_adapter.request import PromptAdapterRequest -from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase @@ -168,6 +169,12 @@ def check_health(self) -> None: # it's running. return + def start_profile(self) -> None: + self.driver_worker.start_profile() + + def stop_profile(self) -> None: + self.driver_worker.stop_profile() + class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 7b98fbea5cd0a..9c6d4051eb3f8 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -14,7 +14,8 @@ from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper, ResultHandler, WorkerMonitor) from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.triton_utils import maybe_set_triton_cache_manager from vllm.utils import (_run_task_with_lock, cuda_device_count_stateless, get_distributed_init_method, get_open_port, @@ -126,7 +127,7 @@ def shutdown(signum, frame): max_parallel_loading_workers) def _check_executor_parameters(self): - world_size = self.parallel_config.tensor_parallel_size + world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 02627de3e0be7..f2fcfa58b26e1 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -3,7 +3,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index 867859d8d3d79..78606e223aa7b 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -9,7 +9,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip, get_open_port, make_async) diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 760c06cb6c06f..1359a0d310a70 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -12,7 +12,8 @@ from vllm.executor.msgspec_utils import encode_hook from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (_run_task_with_lock, get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) @@ -426,18 +427,34 @@ def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None: async_run_remote_workers_only to complete.""" ray.get(parallel_worker_tasks) - def _compiled_ray_dag(self, enable_asyncio: bool): + def _check_ray_adag_installation(self): import pkg_resources from packaging import version - required_version = version.parse("2.32") + required_version = version.parse("2.35") current_version = version.parse( pkg_resources.get_distribution("ray").version) if current_version < required_version: raise ValueError(f"Ray version {required_version} or greater is " f"required, but found {current_version}") + import importlib.util + adag_spec = importlib.util.find_spec( + "ray.experimental.compiled_dag_ref") + if adag_spec is None: + raise ValueError("Ray accelerated DAG is not installed. " + "Run `pip install ray[adag]` to install it.") + + cupy_spec = importlib.util.find_spec("cupy") + if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: + raise ValueError( + "cupy is not installed but required since " + "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set." + "Run `pip install ray[adag]` and check cupy installation.") + + def _compiled_ray_dag(self, enable_asyncio: bool): assert self.parallel_config.use_ray + self._check_ray_adag_installation() from ray.dag import InputNode, MultiOutputNode from ray.experimental.channel.torch_tensor_type import TorchTensorType diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index 7048d47980723..8c8b5f741488b 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -10,7 +10,8 @@ from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.executor.tpu_executor import TPUExecutor from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) @@ -70,6 +71,19 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", worker_module_name = "vllm.worker.tpu_worker" worker_class_name = "TPUWorker" + # GKE does not fetch environment information from metadata server + # and instead sets these from within the Ray process. Therefore we + # need to override the Ray environment variables manually. + override_env = {} + if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ: + override_env.update({ + "TPU_CHIPS_PER_HOST_BOUNDS": + os.environ["TPU_CHIPS_PER_HOST_BOUNDS"] + }) + if "TPU_HOST_BOUNDS" in os.environ: + override_env.update( + {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]}) + worker = ray.remote( num_cpus=0, resources={"TPU": 1}, @@ -80,6 +94,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", worker_class_name=worker_class_name, trust_remote_code=self.model_config.trust_remote_code, ) + if override_env: + worker.override_env_vars.remote(override_env) worker_ip = ray.get(worker.get_node_ip.remote()) if worker_ip == driver_ip and self.driver_dummy_worker is None: @@ -95,12 +111,40 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Else, added to the list of workers. self.workers.append(worker) + logger.debug("workers: %s", self.workers) + logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker) if self.driver_dummy_worker is None: raise ValueError( "Ray does not allocate any TPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " "TPU node.") + worker_ips = [ + ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] + for worker in self.workers + ] + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. + """ + ip = ray.get(worker.get_node_ip.remote()) + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + # Get the set of TPU IDs used on each node. worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids", use_dummy_driver=True) diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py index bfdd0f5cf97b3..59e9854393b6b 100644 --- a/vllm/executor/ray_utils.py +++ b/vllm/executor/ray_utils.py @@ -1,3 +1,4 @@ +import os import time from collections import defaultdict from typing import Dict, List, Optional, Tuple, Union @@ -84,6 +85,9 @@ def execute_model_spmd( return output + def override_env_vars(self, vars: Dict[str, str]): + os.environ.update(vars) + ray_import_err = None except ImportError as e: @@ -291,3 +295,28 @@ def initialize_ray_cluster( _verify_bundles(current_placement_group, parallel_config, device_str) # Set the placement group in the parallel config parallel_config.placement_group = current_placement_group + + +def get_num_tpu_nodes() -> int: + from ray._private.accelerators import TPUAcceleratorManager + cluster_resources = ray.cluster_resources() + total_tpus = int(cluster_resources["TPU"]) + tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators() + assert total_tpus % tpus_per_node == 0 + return total_tpus // tpus_per_node + + +def get_num_nodes_in_placement_group() -> int: + pg_table = ray.util.placement_group_table() + current_pg = ray.util.get_current_placement_group() + num_nodes = 0 + + if current_pg: + nodes_in_pg = set() + for pg_key, pg in pg_table.items(): + if pg_key == current_pg.id.hex(): + for _, node in pg["bundles_to_node_id"].items(): + nodes_in_pg.add(node) + num_nodes = len(nodes_in_pg) + + return num_nodes diff --git a/vllm/executor/tpu_executor.py b/vllm/executor/tpu_executor.py index 253c8abdc1ada..0af8ba41e24d5 100644 --- a/vllm/executor/tpu_executor.py +++ b/vllm/executor/tpu_executor.py @@ -5,7 +5,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index 774204dd4612a..bada56068507a 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -9,7 +9,8 @@ from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, PoolerOutput, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import make_async from vllm.worker.worker_base import WorkerBase diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index a8ea67991a375..b9ac498b23a7b 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -39,7 +39,7 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device: # unquantizedLinear if hasattr(base_layer, "weight"): return base_layer.weight.device - # GPTQ/AWQ/SqueezeLLM + # GPTQ/AWQ elif hasattr(base_layer, "qweight"): return base_layer.qweight.device # marlin diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py index 0bbc1844ef455..619408b9315cf 100644 --- a/vllm/lora/ops/bgmv_expand.py +++ b/vllm/lora/ops/bgmv_expand.py @@ -160,6 +160,9 @@ def _bgmv_expand( return -bgmv_expand = torch.library.custom_op("lora::bgmv_expand", - _bgmv_expand, - mutates_args=["output_tensor"]) +try: + bgmv_expand = torch.library.custom_op("lora::bgmv_expand", + _bgmv_expand, + mutates_args=["output_tensor"]) +except AttributeError: + bgmv_expand = _bgmv_expand diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py index 87d7d9902a4c1..c16db233891a5 100644 --- a/vllm/lora/ops/bgmv_expand_slice.py +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -173,6 +173,9 @@ def _bgmv_expand_slice( return -bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice", - _bgmv_expand_slice, - mutates_args=["output_tensor"]) +try: + bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice", + _bgmv_expand_slice, + mutates_args=["output_tensor"]) +except AttributeError: + bgmv_expand_slice = _bgmv_expand_slice diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/bgmv_shrink.py index c979d758492db..0846ff36b1692 100644 --- a/vllm/lora/ops/bgmv_shrink.py +++ b/vllm/lora/ops/bgmv_shrink.py @@ -142,6 +142,9 @@ def _bgmv_shrink( return -bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink", - _bgmv_shrink, - mutates_args=["output_tensor"]) +try: + bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink", + _bgmv_shrink, + mutates_args=["output_tensor"]) +except AttributeError: + bgmv_shrink = _bgmv_shrink diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 80a0b605b0fe2..c71332d8bdfb2 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -192,6 +192,9 @@ def _sgmv_expand( return -sgmv_expand = torch.library.custom_op("lora::sgmv_expand", - _sgmv_expand, - mutates_args=["output_tensor"]) +try: + sgmv_expand = torch.library.custom_op("lora::sgmv_expand", + _sgmv_expand, + mutates_args=["output_tensor"]) +except AttributeError: + sgmv_expand = _sgmv_expand diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index 53237166a1c68..b4ae9a2acbb5c 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -205,6 +205,9 @@ def _sgmv_expand_slice( return -sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice", - _sgmv_expand_slice, - mutates_args=["output_tensor"]) +try: + sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice", + _sgmv_expand_slice, + mutates_args=["output_tensor"]) +except AttributeError: + sgmv_expand_slice = _sgmv_expand_slice diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index 51d2a09eee94b..c0791c260e915 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -189,6 +189,9 @@ def _sgmv_shrink( return -sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink", - _sgmv_shrink, - mutates_args=["output_tensor"]) +try: + sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink", + _sgmv_shrink, + mutates_args=["output_tensor"]) +except AttributeError: + sgmv_shrink = _sgmv_shrink diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index d666fc293757b..6d5c834299961 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -10,10 +10,8 @@ import torch from vllm.triton_utils import HAS_TRITON -from vllm.utils import is_xpu -# FIXME: xpu path doesn't support torch.library.custom_op -if HAS_TRITON and not is_xpu(): +if HAS_TRITON: from vllm.lora.ops.bgmv_expand import bgmv_expand from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.bgmv_shrink import bgmv_shrink diff --git a/vllm/lora/request.py b/vllm/lora/request.py index d770da4f2407d..47a59d80d3a45 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -28,7 +28,6 @@ class LoRARequest( lora_path: str = "" lora_local_path: Optional[str] = msgspec.field(default=None) long_lora_max_len: Optional[int] = None - __hash__ = AdapterRequest.__hash__ def __post_init__(self): if 'lora_local_path' in self.__struct_fields__: @@ -75,3 +74,21 @@ def local_path(self, value): DeprecationWarning, stacklevel=2) self.lora_path = value + + def __eq__(self, value: object) -> bool: + """ + Overrides the equality method to compare LoRARequest + instances based on lora_name. This allows for identification + and comparison lora adapter across engines. + """ + return isinstance(value, + self.__class__) and self.lora_name == value.lora_name + + def __hash__(self) -> int: + """ + Overrides the hash method to hash LoRARequest instances + based on lora_name. This ensures that LoRARequest instances + can be used in hash-based collections such as sets and dictionaries, + identified by their names across engines. + """ + return hash(self.lora_name) diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index f9fcdead980a2..7161e83952a3d 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -59,8 +59,9 @@ def _adapt_request_for_tool_use(request: Union[CompletionRequest, if type(request) is CompletionRequest: return request - # user has chosen to not use any tool - if request.tool_choice == "none": + # user has chosen to not use any tool, + # OR is allowing the model to choose a tool. + if request.tool_choice == "none" or request.tool_choice == "auto": return request # user has chosen to use a named tool diff --git a/vllm/model_executor/guided_decoding/outlines_decoding.py b/vllm/model_executor/guided_decoding/outlines_decoding.py index bfc658ef7d26b..e1f5b380120c5 100644 --- a/vllm/model_executor/guided_decoding/outlines_decoding.py +++ b/vllm/model_executor/guided_decoding/outlines_decoding.py @@ -8,8 +8,9 @@ from pydantic import BaseModel from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, - CompletionRequest) +from vllm.entrypoints.openai.protocol import ( + ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, + CompletionRequest) from vllm.model_executor.guided_decoding.guided_fields import ( GuidedDecodingRequest) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( @@ -101,16 +102,30 @@ def _get_guide_and_mode( request: Union[CompletionRequest, ChatCompletionRequest, GuidedDecodingRequest] ) -> Union[Tuple[str, GuidedDecodingMode], Tuple[None, None]]: + # if the request is a chat completion request, AND the tool choice is a + # named tool choice, do guided decoding + # using that tool as the JSON schema + if isinstance(request, ChatCompletionRequest) and isinstance( + request.tool_choice, ChatCompletionNamedToolChoiceParam): + # Guided generation for tools/functions parameters + if request.tool_choice.type == "function": + for tool in request.tools: + if (tool.type == "function" and tool.function.name + == request.tool_choice.function.name): + json = json_dumps(tool.function.parameters, sort_keys=True) + return json, GuidedDecodingMode.JSON + return None, None - if request.guided_json: - json = request.guided_json - if isinstance(json, dict): + elif request.guided_json: + if isinstance(request.guided_json, dict): # turn dict into hashable string - json = json_dumps(json) - elif isinstance(json, BaseModel): + json = json_dumps(request.guided_json) + elif isinstance(request.guided_json, BaseModel): # use pydantic signature so that different model classes # with the same fields will get hashed the same - json = str(json.__signature__) + json = str(request.guided_json.__signature__) + else: + json = request.guided_json return json, GuidedDecodingMode.JSON elif request.guided_regex: return request.guided_regex, GuidedDecodingMode.REGEX diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index fd6f41b90042e..e9b5703ca28be 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -2,16 +2,22 @@ FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON -__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"] +__all__ = [ + "FusedMoE", + "FusedMoEMethodBase", + "FusedMoeWeightScaleSupported", +] if HAS_TRITON: - + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe, single_marlin_moe) from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_marlin_moe, fused_moe, fused_topk, - get_config_file_name, grouped_topk) + fused_experts, fused_moe, fused_topk, get_config_file_name, + grouped_topk) __all__ += [ "fused_marlin_moe", + "single_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..cd0cdbea0c337 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "3328": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "768": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "2560": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3584": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3840": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "2304": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..ba9041d008507 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "3840": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3584": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 4 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 2 + }, + "768": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3328": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "2560": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "2304": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json new file mode 100644 index 0000000000000..57055453aa24c --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json @@ -0,0 +1,130 @@ +{ + "2048": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "1792": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 4 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "3328": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 8, + "num_stages": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2560": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "num_warps": 4, + "num_stages": 4 + }, + "768": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 2 + }, + "2816": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4096": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "2304": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 8, + "num_warps": 8, + "num_stages": 2 + }, + "1280": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "3840": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "3584": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + } +} \ No newline at end of file diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py new file mode 100644 index 0000000000000..200a6148978aa --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -0,0 +1,219 @@ +"""Fused MoE utilities for GPTQ.""" +import functools +from typing import Any, Dict, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_topk, moe_align_block_size, try_get_optimal_moe_config) + + +def single_marlin_moe( + hidden_states: torch.Tensor, + w: torch.Tensor, + scales: torch.Tensor, + gating_output: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None) -> torch.Tensor: + """ + This function computes the multiplication of hidden_states with expert + weights used in Marlin MoE, using weights w and top-k gating mechanism. + Its purpose is testing and debugging the fused MoE kernel. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the Marlin Mul. + - w (torch.Tensor): The set of expert weights. + - scales (torch.Tensor): The quantization scales. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - g_idx (torch.Tensor): The act_order indices. + - perm (torch.Tensor): The act_order input permutation. + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[1] == w.shape[1] * 16, "Hidden size mismatch" + assert gating_output.shape[1] == w.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w.is_contiguous(), "Expert weights must be contiguous" + assert hidden_states.dtype == torch.float16 + + M, K = hidden_states.shape + E = w.shape[0] + N = w.shape[2] // 2 + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + # This might not be an optimal config for a single MMM + get_config_func = functools.partial(try_get_optimal_moe_config, + w.shape, + w.shape, + topk_ids.shape[1], + None, + override_config=override_config, + is_marlin=True) + config = get_config_func(M) + + block_size_m = config['BLOCK_SIZE_M'] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = (N // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + intermediate_cache = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, w, sorted_token_ids, topk_weights, topk_ids, scales, + g_idx, perm, workspace, M, N, K, True, E, topk, block_size_m, True, + False) + + return torch.sum(intermediate_cache.view(*intermediate_cache.shape), dim=1) + + +def fused_marlin_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + g_idx1: torch.Tensor, + g_idx2: torch.Tensor, + perm1: torch.Tensor, + perm2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + override_config: Optional[Dict[str, Any]] = None, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - g_idx1 (torch.Tensor): The first set of act_order indices. + - g_idx2 (torch.Tensor): The second set of act_order indices. + - perm1 (torch.Tensor): The first act_order input permutation. + - perm2 (torch.Tensor): The second act_order input permutation. + - topk_weights (torch.Tensor): Top-k weights. + - topk_ids (torch.Tensor): Indices of topk-k elements. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert hidden_states.shape[0] == gating_output.shape[ + 0], "Number of tokens mismatch" + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[ + 1] == w2.shape[2] // 2, "Hidden size mismatch w2" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype == torch.float16 + + M, K = hidden_states.shape + E = w1.shape[0] + N = w2.shape[1] * 16 + topk = topk_ids.shape[1] + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + None, + override_config=override_config, + is_marlin=True, + ) + config = get_config_func(M) + + block_size_m = config["BLOCK_SIZE_M"] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + intermediate_cache2 = torch.empty( + (M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + + intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, + w1, + sorted_token_ids, + topk_weights, + topk_ids, + w1_scale, + g_idx1, + perm1, + workspace, + M, + 2 * N, + K, + True, + E, + topk, + block_size_m, + True, + False, + ) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + + intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache2, + w2, + sorted_token_ids, + topk_weights, + topk_ids, + w2_scale, + g_idx2, + perm2, + workspace, + M, + K, + N, + True, + E, + topk, + block_size_m, + False, + True, + ) + + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index d2b152320e11e..bd13d8fecbb96 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2,7 +2,7 @@ import functools import json import os -from typing import Any, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple import torch import triton @@ -323,15 +323,22 @@ def get_moe_configs(E: int, N: int, return None -def get_default_config(M: int, E: int, N: int, K: int, topk: int, - dtype: Optional[str], - is_marlin: bool) -> Dict[str, int]: +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, +) -> Dict[str, int]: config = { 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 } + # A heuristic: fused marlin works faster with this config for small M if M <= E or (is_marlin and M <= 32): config = { 'BLOCK_SIZE_M': 16, @@ -342,14 +349,15 @@ def get_default_config(M: int, E: int, N: int, K: int, topk: int, return config -def try_get_optimal_moe_config(w1_shape: Tuple[int, ...], - w2_shape: Tuple[int, ...], - top_k: int, - dtype: Optional[str], - M: int, - override_config: Optional[Dict[str, - Any]] = None, - is_marlin: bool = False): +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, Any]] = None, + is_marlin: bool = False, +): if override_config: config = override_config else: @@ -391,6 +399,7 @@ def fused_topk( topk, dtype=torch.int32, device=hidden_states.device) + ops.topk_softmax( topk_weights, topk_ids, @@ -437,108 +446,6 @@ def grouped_topk(hidden_states: torch.Tensor, return topk_weights, topk_ids -def fused_marlin_moe(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - gating_output: torch.Tensor, - g_idx1: torch.Tensor, - g_idx2: torch.Tensor, - rand_perm1: torch.Tensor, - rand_perm2: torch.Tensor, - topk: int, - renormalize: bool, - override_config: Optional[Dict[str, Any]] = None, - use_fp8: bool = False, - w1_scale: Optional[torch.Tensor] = None, - w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: - """ - This function computes a Mixture of Experts (MoE) layer using two sets of - weights, w1 and w2, and top-k gating mechanism. - Parameters: - - hidden_states (torch.Tensor): The input tensor to the MoE layer. - - w1 (torch.Tensor): The first set of expert weights. - - w2 (torch.Tensor): The second set of expert weights. - - gating_output (torch.Tensor): The output of the gating operation - (before softmax). - - topk (int): The number of top-k experts to select. - - renormalize (bool): If True, renormalize the top-k weights to sum to 1. - - inplace (bool): If True, perform the operation in-place. - Defaults to False. - - override_config (Optional[Dict[str, Any]]): Optional override - for the kernel configuration. - - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner - products for w1 and w2. Defaults to False. - - w1_scale (Optional[torch.Tensor]): Optional scale to be used for - w1. - - w2_scale (Optional[torch.Tensor]): Optional scale to be used for - w2. - Returns: - - torch.Tensor: The output tensor after applying the MoE layer. - """ - # Check constraints. - assert hidden_states.shape[0] == gating_output.shape[0], ( - "Number of tokens mismatch") - assert hidden_states.shape[ - 1] == w1.shape[1] * 16, "Hidden size mismatch w1" - assert hidden_states.shape[ - 1] == w2.shape[2] // 2, "Hidden size mismatch w2" - assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - assert hidden_states.dtype in [ - torch.float32, torch.float16, torch.bfloat16 - ] - - #TODO fp8 is not implemented yet - assert not use_fp8 - - M, K = hidden_states.shape - E = w1.shape[0] - N = w2.shape[1] * 16 - - topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, - renormalize) - - get_config_func = functools.partial(try_get_optimal_moe_config, - w1.shape, - w2.shape, - topk_ids.shape[1], - "float8" if use_fp8 else None, - override_config=override_config, - is_marlin=True) - config = get_config_func(M) - - block_size_m = config['BLOCK_SIZE_M'] - - sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) - - max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 - workspace = torch.zeros(max_workspace_size, - dtype=torch.int, - device="cuda", - requires_grad=False) - - intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N), - device=hidden_states.device, - dtype=hidden_states.dtype) - - intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( - hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale, - g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk, - block_size_m, True, False) - - ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) - - intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( - intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids, - w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk, - block_size_m, False, True) - - return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), - dim=1) - - def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False): @@ -695,6 +602,7 @@ def fused_moe( use_grouped_topk: bool = False, num_expert_group: Optional[int] = None, topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, use_fp8_w8a8: bool = False, use_int8_w8a16: bool = False, w1_scale: Optional[torch.Tensor] = None, @@ -742,9 +650,12 @@ def fused_moe( topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, topk, renormalize, num_expert_group, topk_group) - else: + elif custom_routing_function is None: topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize) return fused_experts(hidden_states, w1, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 61ebef5e11f43..f6c6f5f529408 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,6 +1,6 @@ from abc import abstractmethod from enum import Enum -from typing import List, Optional, Tuple +from typing import Callable, List, Optional, Tuple import torch @@ -62,15 +62,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None + ) -> torch.Tensor: return self.forward(x=x, layer=layer, @@ -79,17 +82,21 @@ def apply(self, renormalize=renormalize, use_grouped_topk=use_grouped_topk, topk_group=topk_group, - num_expert_group=num_expert_group) - - def forward_cuda(self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + def forward_cuda( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_experts) @@ -101,7 +108,8 @@ def forward_cuda(self, top_k=top_k, renormalize=renormalize, topk_group=topk_group, - num_expert_group=num_expert_group) + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) return fused_experts(hidden_states=x, w1=layer.w13_weight, @@ -114,20 +122,24 @@ def forward_cpu(self, *args, **kwargs): raise NotImplementedError( "The CPU backend currently does not support MoE.") - def forward_tpu(self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + def forward_tpu( + self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe.moe_pallas import fused_moe assert not use_grouped_topk assert num_expert_group is None assert topk_group is None + assert custom_routing_function is None return fused_moe(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -172,6 +184,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, tp_size: Optional[int] = None, prefix: str = "", + custom_routing_function: Optional[Callable] = None, ): super().__init__() @@ -190,6 +203,7 @@ def __init__( assert num_expert_group is not None and topk_group is not None self.num_expert_group = num_expert_group self.topk_group = topk_group + self.custom_routing_function = custom_routing_function if quant_config is None: self.quant_method: Optional[QuantizeMethodBase] = ( @@ -292,10 +306,28 @@ def _load_single_value(self, param: torch.nn.Parameter, # Input scales can be loaded directly and should be equal. param_data[expert_id] = loaded_weight + def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, + shard_dim: int, loaded_weight: torch.tensor, tp_rank: int): + + if shard_id == "w2": + self._load_w2(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + else: + assert shard_id in ("w1", "w3") + expert_data.copy_(loaded_weight) + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: + # compressed-tensors represents weights on disk which are flipped + loaded_weight = loaded_weight.t().contiguous() if ( + self.quant_method.__class__.__name__ + == "CompressedTensorsMoEMethod") else loaded_weight + if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}.") @@ -311,19 +343,41 @@ def weight_loader(self, param: torch.nn.Parameter, expert_data = param.data[expert_id] tp_rank = get_tensor_model_parallel_rank() - # is_transposed: whether or not the parameter is transposed on disk - # If transposed, the loaded weight will be transposed and the dim - # to shard the loaded weight will be flipped. + # is_transposed: if the dim to shard the weight + # should be flipped. Required by GPTQ, compressed-tensors + # should be whatever dimension intermediate_size is is_transposed = getattr(param, "is_transposed", False) shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] if is_transposed: - loaded_weight = loaded_weight.t().contiguous() shard_dim = ~shard_dim - # Case weight_scales - if "weight_scale" in weight_name: - # load the weight scaling based on the quantization scheme - # supported weight scales can be found in + # Case input scale: input_scale loading is only supported for fp8 + if "input_scale" in weight_name: + if param.data[expert_id] != 1 and (param.data[expert_id] - + loaded_weight).abs() > 1e-5: + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param.data[expert_id]} " + f"vs. {loaded_weight}") + + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + + # Case g_idx + if "g_idx" in weight_name: + self._load_g_idx(shard_dim=0, + shard_id=shard_id, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + return + + # Case weight scales and zero_points + if ("scale" in weight_name or "zero" in weight_name): + # load the weight scales and zp based on the quantization scheme + # supported weight scales/zp can be found in # FusedMoeWeightScaleSupported # TODO @dsikka: once hardened, refactor to use vLLM Parameters # specific to each case @@ -352,22 +406,9 @@ def weight_loader(self, param: torch.nn.Parameter, f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") return + # Case weight_shape if "weight_shape" in weight_name: - self._load_single_value(param=param, - loaded_weight=loaded_weight, - expert_id=expert_id) - return - - # Case input scale - if "input_scale" in weight_name: - # Note: input_scale loading is only supported for fp8 - if param.data[expert_id] != 1 and (param.data[expert_id] - - loaded_weight).abs() > 1e-5: - raise ValueError( - "input_scales of w1 and w3 of a layer " - f"must be equal. But got {param.data[expert_id]} " - f"vs. {loaded_weight}") - + # only required by compressed-tensors self._load_single_value(param=param, loaded_weight=loaded_weight, expert_id=expert_id) @@ -390,7 +431,8 @@ def select_experts(hidden_states: torch.Tensor, use_grouped_topk: bool, renormalize: bool, topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None): + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None): from vllm.model_executor.layers.fused_moe.fused_moe import ( fused_topk, grouped_topk) @@ -405,11 +447,17 @@ def select_experts(hidden_states: torch.Tensor, renormalize=renormalize, num_expert_group=num_expert_group, topk_group=topk_group) - else: + elif custom_routing_function is None: topk_weights, topk_ids = fused_topk(hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states=hidden_states, + gating_output=router_logits, + topk=top_k, + renormalize=renormalize) return topk_weights, topk_ids @@ -426,7 +474,8 @@ def forward(self, hidden_states: torch.Tensor, renormalize=self.renormalize, use_grouped_topk=self.use_grouped_topk, topk_group=self.topk_group, - num_expert_group=self.num_expert_group) + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function) if self.reduce_results and self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -476,4 +525,4 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, param_data[expert_id][idx] = loaded_weight # If we are in the row parallel case (down_proj) else: - param_data[expert_id] = loaded_weight \ No newline at end of file + param_data[expert_id] = loaded_weight diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 1cad4e55f51ee..b997507ea738d 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -14,8 +14,10 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.parameter import (BasevLLMParameter, + PackedColumnParameter, PackedvLLMParameter, - PerTensorScaleParameter) + PerTensorScaleParameter, + RowvLLMParameter) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -23,7 +25,8 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", - "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod" + "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod", + "TPUInt8LinearMethod", "GPTQLinearMethod", "FBGEMMFp8LinearMethod" ] @@ -35,9 +38,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size -def adjust_bitsandbytes_shard(param: Parameter, - qkv_offsets: Dict[str, Tuple[int, int]], - loaded_shard_id: str) -> Tuple[int, int]: +def adjust_bitsandbytes_4bit_shard(param: Parameter, + qkv_offsets: Dict[str, Tuple[int, int]], + loaded_shard_id: str) -> Tuple[int, int]: """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" total, _ = qkv_offsets["total"] @@ -504,8 +507,9 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - use_bitsandbytes = getattr(param, "use_bitsandbytes", False) - if use_bitsandbytes: + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + if use_bitsandbytes_4bit: shard_size = loaded_weight.shape[output_dim] shard_offset = loaded_weight.shape[output_dim] * \ loaded_shard_id @@ -572,8 +576,8 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, # Special case for Quantization. # If quantized, we need to adjust the offset and size to account # for the packing. - if isinstance(param, PackedvLLMParameter - ) and param.packed_dim == param.output_dim: + if isinstance(param, (PackedColumnParameter, PackedvLLMParameter + )) and param.packed_dim == param.output_dim: shard_size, shard_offset = \ param.adjust_shard_indexes_for_packing( shard_size=shard_size, shard_offset=shard_offset) @@ -592,9 +596,10 @@ def weight_loader_v2(self, param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0) return - elif type(param) is BasevLLMParameter: + elif type(param) in (RowvLLMParameter, BasevLLMParameter): param.load_merged_column_weight(loaded_weight=loaded_weight) return + # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) return @@ -722,8 +727,8 @@ def _load_fused_module_from_checkpoint(self, param: BasevLLMParameter, # Special case for Quantization. # If quantized, we need to adjust the offset and size to account # for the packing. - if isinstance(param, PackedvLLMParameter - ) and param.packed_dim == param.output_dim: + if isinstance(param, (PackedColumnParameter, PackedvLLMParameter + )) and param.packed_dim == param.output_dim: shard_size, shard_offset = \ param.adjust_shard_indexes_for_packing( shard_size=shard_size, shard_offset=shard_offset) @@ -739,12 +744,12 @@ def weight_loader_v2(self, loaded_shard_id: Optional[str] = None): if loaded_shard_id is None: # special case for certain models if isinstance(param, PerTensorScaleParameter): - param.load_merged_column_weight(loaded_weight=loaded_weight, - shard_id=0) + param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0) return - elif type(param) is BasevLLMParameter: - param.load_merged_column_weight(loaded_weight=loaded_weight) + elif type(param) in (RowvLLMParameter, BasevLLMParameter): + param.load_qkv_weight(loaded_weight=loaded_weight) return + # TODO: @dsikka - move to parameter.py self._load_fused_module_from_checkpoint(param, loaded_weight) return @@ -857,8 +862,9 @@ def weight_loader(self, shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) - use_bitsandbytes = getattr(param, "use_bitsandbytes", False) - if use_bitsandbytes: + use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", + False) + if use_bitsandbytes_4bit: orig_qkv_offsets = { "q": (0, self.num_heads * self.head_size), "k": (self.num_heads * self.head_size, @@ -870,7 +876,7 @@ def weight_loader(self, ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, 0) } - shard_size, shard_offset = adjust_bitsandbytes_shard( + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( param, orig_qkv_offsets, loaded_shard_id) if is_gguf_weight: diff --git a/vllm/model_executor/layers/mamba/__init__.py b/vllm/model_executor/layers/mamba/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/mamba/ops/__init__.py b/vllm/model_executor/layers/mamba/ops/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/model_executor/layers/mamba/ops/causal_conv1d.py b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py new file mode 100644 index 0000000000000..413c8bc227ae8 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/causal_conv1d.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024, Tri Dao. + +from typing import Optional + +import torch + +from vllm import _custom_ops as ops + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + seq_idx: Optional[torch.Tensor] = None, + initial_states: Optional[torch.Tensor] = None, + return_final_states: bool = False, + final_states_out=None, + activation: str = "silu", +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert (initial_states is + None), "initial_states must be None if seq_idx is not None" + assert (not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and (initial_states.stride(2) != 1 + and initial_states.stride(1) != 1): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert (final_states_out.stride(2) == 1 + or final_states_out.stride(1) == 1) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty(batch, + width - 1, + dim, + device=x.device, + dtype=x.dtype).transpose(1, 2) + else: + final_states_out = None + + out = ops.causal_conv1d_fwd(x, weight, bias, seq_idx, initial_states, + final_states_out, activation + in ["silu", "swish"]) + return (out, None) if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None): + """ + x: (batch, dim) + conv_state: (batch, dim, width) + weight: (dim, width) + bias: (dim,) + + out: (batch, dim) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation_bool = activation in ["silu", "swish"] + return ops.causal_conv1d_update(x, conv_state, weight, bias, + activation_bool) diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py new file mode 100644 index 0000000000000..869c69214caf2 --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -0,0 +1,346 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. + +import torch +import triton +import triton.language as tl +from packaging import version + +from vllm import _custom_ops as ops + +TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0") + +if TRITON3: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt) + return dt +else: + + @triton.jit + def softplus(dt): + dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt) + return dt + + +@triton.heuristics( + {"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) +@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) +@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) +@triton.heuristics( + {"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])}) +@triton.jit +def _selective_scan_update_kernel( + # Pointers to matrices + state_ptr, + x_ptr, + dt_ptr, + dt_bias_ptr, + A_ptr, + B_ptr, + C_ptr, + D_ptr, + z_ptr, + out_ptr, + # Matrix dimensions + batch, + nheads, + dim, + dstate, + nheads_ngroups_ratio, + # Strides + stride_state_batch, + stride_state_head, + stride_state_dim, + stride_state_dstate, + stride_x_batch, + stride_x_head, + stride_x_dim, + stride_dt_batch, + stride_dt_head, + stride_dt_dim, + stride_dt_bias_head, + stride_dt_bias_dim, + stride_A_head, + stride_A_dim, + stride_A_dstate, + stride_B_batch, + stride_B_group, + stride_B_dstate, + stride_C_batch, + stride_C_group, + stride_C_dstate, + stride_D_head, + stride_D_dim, + stride_z_batch, + stride_z_head, + stride_z_dim, + stride_out_batch, + stride_out_head, + stride_out_dim, + # Meta-parameters + DT_SOFTPLUS: tl.constexpr, + TIE_HDIM: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + HAS_DT_BIAS: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + BLOCK_SIZE_DSTATE: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_b = tl.program_id(axis=1) + pid_h = tl.program_id(axis=2) + state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head + x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head + dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head + if HAS_DT_BIAS: + dt_bias_ptr += pid_h * stride_dt_bias_head + A_ptr += pid_h * stride_A_head + B_ptr += pid_b * stride_B_batch + (pid_h // + nheads_ngroups_ratio) * stride_B_group + C_ptr += pid_b * stride_C_batch + (pid_h // + nheads_ngroups_ratio) * stride_C_group + if HAS_Z: + z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head + out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = tl.arange(0, BLOCK_SIZE_DSTATE) + state_ptrs = state_ptr + (offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate) + x_ptrs = x_ptr + offs_m * stride_x_dim + dt_ptrs = dt_ptr + offs_m * stride_dt_dim + if HAS_DT_BIAS: + dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim + if HAS_D: + D_ptr += pid_h * stride_D_head + A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + + offs_n[None, :] * stride_A_dstate) + B_ptrs = B_ptr + offs_n * stride_B_dstate + C_ptrs = C_ptr + offs_n * stride_C_dstate + if HAS_D: + D_ptrs = D_ptr + offs_m * stride_D_dim + if HAS_Z: + z_ptrs = z_ptr + offs_m * stride_z_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + state = tl.load(state_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0) + x = tl.load(x_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if not TIE_HDIM: + dt = tl.load(dt_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, + other=0.0).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptrs, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), + other=0.0).to(tl.float32) + dA = tl.exp(A * dt[:, None]) + else: + dt = tl.load(dt_ptr).to(tl.float32) + if HAS_DT_BIAS: + dt += tl.load(dt_bias_ptr).to(tl.float32) + if DT_SOFTPLUS: + dt = softplus(dt) + A = tl.load(A_ptr).to(tl.float32) + dA = tl.exp(A * dt) # scalar, not a matrix + + B = tl.load(B_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + C = tl.load(C_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32) + if HAS_D: + D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + if HAS_Z: + z = tl.load(z_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + + dB = B[None, :] * dt[:, None] if not TIE_HDIM else B * dt + state = state * dA + dB * x[:, None] + tl.store(state_ptrs, + state, + mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)) + out = tl.sum(state * C[None, :], axis=1) + if HAS_D: + out += x * D + if HAS_Z: + out *= z * tl.sigmoid(z) + tl.store(out_ptrs, out, mask=offs_m < dim) + + +def selective_state_update(state, + x, + dt, + A, + B, + C, + D=None, + z=None, + dt_bias=None, + dt_softplus=False): + """ + Argument: + state: (batch, dim, dstate) or (batch, nheads, dim, dstate) + x: (batch, dim) or (batch, nheads, dim) + dt: (batch, dim) or (batch, nheads, dim) + A: (dim, dstate) or (nheads, dim, dstate) + B: (batch, dstate) or (batch, ngroups, dstate) + C: (batch, dstate) or (batch, ngroups, dstate) + D: (dim,) or (nheads, dim) + z: (batch, dim) or (batch, nheads, dim) + dt_bias: (dim,) or (nheads, dim) + Return: + out: (batch, dim) or (batch, nheads, dim) + """ + has_heads = state.dim() > 3 + if state.dim() == 3: + state = state.unsqueeze(1) + if x.dim() == 2: + x = x.unsqueeze(1) + if dt.dim() == 2: + dt = dt.unsqueeze(1) + if A.dim() == 2: + A = A.unsqueeze(0) + if B.dim() == 2: + B = B.unsqueeze(1) + if C.dim() == 2: + C = C.unsqueeze(1) + if D is not None and D.dim() == 1: + D = D.unsqueeze(0) + if z is not None and z.dim() == 2: + z = z.unsqueeze(1) + if dt_bias is not None and dt_bias.dim() == 1: + dt_bias = dt_bias.unsqueeze(0) + batch, nheads, dim, dstate = state.shape + assert x.shape == (batch, nheads, dim) + assert dt.shape == x.shape + assert A.shape == (nheads, dim, dstate) + ngroups = B.shape[1] + assert nheads % ngroups == 0, "nheads must be divisible by ngroups" + assert B.shape == (batch, ngroups, dstate) + assert C.shape == B.shape + if D is not None: + assert D.shape == (nheads, dim) + if z is not None: + assert z.shape == x.shape + if dt_bias is not None: + assert dt_bias.shape == (nheads, dim) + out = torch.empty_like(x) + grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads) + z_strides = ((z.stride(0), z.stride(1), z.stride(2)) if z is not None else + (0, 0, 0)) + # We don't want autotune since it will overwrite the state + # We instead tune by hand. + BLOCK_SIZE_M, num_warps = ((32, 4) if dstate <= 16 else + ((16, 4) if dstate <= 32 else + ((8, 4) if dstate <= 64 else + ((4, 4) if dstate <= 128 else ((4, 8)))))) + tie_hdim = A.stride(-1) == 0 and A.stride(-2) == 0 and dt.stride( + -1) == 0 and dt_bias.stride(-1) == 0 + with torch.cuda.device(x.device.index): + _selective_scan_update_kernel[grid]( + state, + x, + dt, + dt_bias, + A, + B, + C, + D, + z, + out, + batch, + nheads, + dim, + dstate, + nheads // ngroups, + state.stride(0), + state.stride(1), + state.stride(2), + state.stride(3), + x.stride(0), + x.stride(1), + x.stride(2), + dt.stride(0), + dt.stride(1), + dt.stride(2), + *(dt_bias.stride(0), + dt_bias.stride(1)) if dt_bias is not None else 0, + A.stride(0), + A.stride(1), + A.stride(2), + B.stride(0), + B.stride(1), + B.stride(2), + C.stride(0), + C.stride(1), + C.stride(2), + *(D.stride(0), D.stride(1)) if D is not None else 0, + z_strides[0], + z_strides[1], + z_strides[2], + out.stride(0), + out.stride(1), + out.stride(2), + dt_softplus, + tie_hdim, + BLOCK_SIZE_M, + num_warps=num_warps, + ) + if not has_heads: + out = out.squeeze(1) + return out + + +def selective_scan_fn(u, + delta, + A, + B, + C, + D=None, + z=None, + delta_bias=None, + delta_softplus=False, + return_last_state=False, + position_indices=None, + prev_state=None): + """if return_last_state is True, returns (out, last_state) + last_state has shape (batch, dim, dstate). + """ + if u.stride(-1) != 1: + u = u.contiguous() + if delta.stride(-1) != 1: + delta = delta.contiguous() + if D is not None: + D = D.contiguous() + if B.stride(-1) != 1: + B = B.contiguous() + if C.stride(-1) != 1: + C = C.contiguous() + if z is not None and z.stride(-1) != 1: + z = z.contiguous() + if B.dim() == 3: + B = B.unsqueeze(1) + if C.dim() == 3: + C = C.unsqueeze(1) + n_chunks = int((u.shape[-1] + 2048 - 1) / 2048) + x = torch.zeros(( + u.shape[0], + u.shape[1], + n_chunks, + int(A.shape[1] * 2), + ), + device=u.device, + dtype=torch.float32, + requires_grad=False) + x[:, :, 0, 0::2] = 1 + if prev_state is not None: + x[:, :, 0, 1::2].copy_(prev_state) + out, x, *rest = ops.selective_scan_fwd(u, delta, A, B, C, D, z, delta_bias, + delta_softplus, position_indices, x) + last_state = x[:, :, -1, 1::2] # (batch, dim, dstate) + if z is None: + return out if not return_last_state else (out, last_state) + else: + out_z = rest[0] + return out_z if not return_last_state else (out_z, last_state) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index 95b160f4287f9..aa5c288962d91 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -22,8 +22,9 @@ from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( GPTQMarlin24Config) from vllm.model_executor.layers.quantization.marlin import MarlinConfig +from vllm.model_executor.layers.quantization.neuron_quant import ( + NeuronQuantConfig) from vllm.model_executor.layers.quantization.qqq import QQQConfig -from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { @@ -41,11 +42,11 @@ "gptq_marlin": GPTQMarlinConfig, "awq_marlin": AWQMarlinConfig, "gptq": GPTQConfig, - "squeezellm": SqueezeLLMConfig, "compressed-tensors": CompressedTensorsConfig, "bitsandbytes": BitsAndBytesConfig, "qqq": QQQConfig, "experts_int8": ExpertsInt8Config, + "neuron_quant": NeuronQuantConfig, } diff --git a/vllm/model_executor/layers/quantization/awq_triton.py b/vllm/model_executor/layers/quantization/awq_triton.py new file mode 100644 index 0000000000000..d0b210c3a2747 --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq_triton.py @@ -0,0 +1,316 @@ +import torch +import triton +import triton.language as tl + +AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +@triton.jit +def awq_dequantize_kernel( + qweight_ptr, # quantized matrix + scales_ptr, # scales, per group + zeros_ptr, # zeros, per group + group_size, # Should always be one of the supported group sizes + result_ptr, # Output matrix + num_cols, # input num cols in qweight + num_rows, # input num rows in qweight + BLOCK_SIZE_X: tl.constexpr, + BLOCK_SIZE_Y: tl.constexpr): + # Setup the pids. + pid_x = tl.program_id(axis=0) + pid_y = tl.program_id(axis=1) + + # Compute offsets and masks for qweight_ptr. + offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + offsets = num_cols * offsets_y[:, None] + offsets_x[None, :] + + masks_y = offsets_y < num_rows + masks_x = offsets_x < num_cols + + masks = masks_y[:, None] & masks_x[None, :] + + # Compute offsets and masks for result output ptr. + result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y) + result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange( + 0, BLOCK_SIZE_X * 8) + result_offsets = (8 * num_cols * result_offsets_y[:, None] + + result_offsets_x[None, :]) + + result_masks_y = result_offsets_y < num_rows + result_masks_x = result_offsets_x < num_cols * 8 + result_masks = result_masks_y[:, None] & result_masks_x[None, :] + + # Load the weights. + iweights = tl.load(qweight_ptr + offsets, masks) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + iweights = tl.interleave(iweights, iweights) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + + tl.arange(0, 4)[:, None]).reshape(8) + + # Use this to compute a set of shifts that can be used to unpack and + # reorder the values in iweights and zeros. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + iweights = (iweights >> shifts) & 0xF + + # Compute zero offsets and masks. + zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X) + zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :] + + zero_masks_y = zero_offsets_y < num_rows // group_size + zero_masks_x = zero_offsets_x < num_cols + zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :] + + # Load the zeros. + zeros = tl.load(zeros_ptr + zero_offsets, zero_masks) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Unpack and reorder: shift out the correct 4-bit value and mask. + zeros = (zeros >> shifts) & 0xF + + # Compute scale offsets and masks. + scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1) + scale_offsets_x = (pid_x * BLOCK_SIZE_X * 8 + + tl.arange(0, BLOCK_SIZE_X * 8)) + scale_offsets = (num_cols * 8 * scale_offsets_y[:, None] + + scale_offsets_x[None, :]) + scale_masks_y = scale_offsets_y < num_rows // group_size + scale_masks_x = scale_offsets_x < num_cols * 8 + scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :] + + # Load the scales. + scales = tl.load(scales_ptr + scale_offsets, scale_masks) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8)) + + # Dequantize. + iweights = (iweights - zeros) * scales + iweights = iweights.to(result_ptr.type.element_ty) + + # Finally, store. + tl.store(result_ptr + result_offsets, iweights, result_masks) + + +@triton.jit +def awq_gemm_kernel(a_ptr, b_ptr, c_ptr, zeros_ptr, scales_ptr, M, N, K, + group_size, BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, + SPLIT_K: tl.constexpr): + pid = tl.program_id(axis=0) + pid_z = tl.program_id(1) + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + pid_m = pid // num_pid_n + pid_n = pid % num_pid_n + + accumulator_dtype = c_ptr.type.element_ty + + # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead. + # accumulator = tl.arange(0, BLOCK_SIZE_N) + # accumulator = tl.broadcast_to(accumulator[None, :], + # (BLOCK_SIZE_M, BLOCK_SIZE_N)) + # accumulator = accumulator & 0x0 + # accumulator = accumulator.to(accumulator_dtype) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), + dtype=accumulator_dtype) + + # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7] + # that will map given indices to the correct order. + reverse_awq_order_tensor = ((tl.arange(0, 2) * 4)[None, :] + + tl.arange(0, 4)[:, None]).reshape(8) + + # Create the necessary shifts to use to unpack. + shifts = reverse_awq_order_tensor * 4 + shifts = tl.broadcast_to(shifts[None, :], + (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8)) + shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + # Offsets and masks. + offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + masks_am = offsets_am < M + + offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) + masks_bn = offsets_bn < N // 8 + + offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8) + masks_zn = offsets_zn < N // 8 + + offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + masks_sn = offsets_sn < N + + offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offsets_a = K * offsets_am[:, None] + offsets_k[None, :] + offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :] + + a_ptrs = a_ptr + offsets_a + b_ptrs = b_ptr + offsets_b + + # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv + # block_offset = BLOCK_SIZE_K * SPLIT_K + # for k in range(0, (K + block_offset - 1) // (block_offset)): + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)): + masks_k = offsets_k < K + masks_a = masks_am[:, None] & masks_k[None, :] + a = tl.load(a_ptrs, mask=masks_a) + + masks_b = masks_k[:, None] & masks_bn[None, :] + b = tl.load(b_ptrs, mask=masks_b) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + b = tl.interleave(b, b) + + # Dequantize b. + offsets_szk = ( + (BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K) // group_size + + tl.arange(0, 1)) + offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :] + masks_zk = offsets_szk < K // group_size + masks_z = masks_zk[:, None] & masks_zn[None, :] + zeros_ptrs = zeros_ptr + offsets_z + zeros = tl.load(zeros_ptrs, mask=masks_z) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.interleave(zeros, zeros) + zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :] + masks_sk = offsets_szk < K // group_size + masks_s = masks_sk[:, None] & masks_sn[None, :] + scales_ptrs = scales_ptr + offsets_s + scales = tl.load(scales_ptrs, mask=masks_s) + scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N)) + + b = (b >> shifts) & 0xF + zeros = (zeros >> shifts) & 0xF + b = (b - zeros) * scales + b = b.to(c_ptr.type.element_ty) + + # Accumulate results. + accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype) + + offsets_k += BLOCK_SIZE_K * SPLIT_K + a_ptrs += BLOCK_SIZE_K * SPLIT_K + b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8) + + c = accumulator.to(c_ptr.type.element_ty) + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + N * offs_cm[:, None] + offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptrs, c, mask=c_mask) + else: + tl.atomic_add(c_ptrs, c, mask=c_mask) + + +# qweights - [K , M // 8], int32 +# scales - [K // G, M ], float16 +# zeros - [K // G, M // 8], int32 +def awq_dequantize_triton(qweight: torch.Tensor, + scales: torch.Tensor, + zeros: torch.Tensor, + block_size_x: int = 32, + block_size_y: int = 32) -> torch.Tensor: + K = qweight.shape[0] + M = scales.shape[1] + group_size = qweight.shape[0] // scales.shape[0] + + assert K > 0 and M > 0 + assert scales.shape[0] == K // group_size and scales.shape[1] == M + assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + # Result tensor: + # number of rows = same as input tensor + # number of cols = 8 x input tensor num cols + result = torch.empty(qweight.shape[0], + qweight.shape[1] * 8, + device=qweight.device, + dtype=scales.dtype) + + Y = qweight.shape[0] # num rows + X = qweight.shape[1] # num cols + + grid = lambda META: ( + triton.cdiv(X, META['BLOCK_SIZE_X']), + triton.cdiv(Y, META['BLOCK_SIZE_Y']), + ) + awq_dequantize_kernel[grid](qweight, + scales, + zeros, + group_size, + result, + X, + Y, + BLOCK_SIZE_X=block_size_x, + BLOCK_SIZE_Y=block_size_y) + + return result + + +# input - [M, K] +# qweight - [K, N // 8] +# qzeros - [K // G, N // 8] +# scales - [K // G, N] +# split_k_iters - parallelism along K-dimension, int, power of 2. +def awq_gemm_triton(input: torch.Tensor, + qweight: torch.Tensor, + scales: torch.Tensor, + qzeros: torch.Tensor, + split_k_iters: int, + block_size_m: int = 32, + block_size_n: int = 32, + block_size_k: int = 32) -> torch.Tensor: + M, K = input.shape + N = qweight.shape[1] * 8 + group_size = qweight.shape[0] // qzeros.shape[0] + + assert N > 0 and K > 0 and M > 0 + assert qweight.shape[0] == K and qweight.shape[1] == N // 8 + assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8 + assert scales.shape[0] == K // group_size and scales.shape[1] == N + assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0 + assert split_k_iters <= 32 + assert group_size <= K + assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv( + N, META['BLOCK_SIZE_N']), + split_k_iters, + ) + + result = torch.zeros((M, N), dtype=scales.dtype, device=input.device) + + # A = input, B = qweight, C = result + # A = M x K, B = K x N, C = M x N + awq_gemm_kernel[grid](input, + qweight, + result, + qzeros, + scales, + M, + N, + K, + group_size, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + SPLIT_K=split_k_iters) + + return result diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index c143d1a8f2bc7..66bc5395dbd7a 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -1,7 +1,6 @@ from typing import Any, Dict, List, Optional import torch -from torch.nn.parameter import Parameter from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, set_weight_attrs) @@ -15,8 +14,28 @@ class BitsAndBytesConfig(QuantizationConfig): Reference: https://arxiv.org/abs/2305.14314 """ - def __init__(self, ) -> None: - pass + def __init__( + self, + load_in_8bit: bool = False, + load_in_4bit: bool = True, + bnb_4bit_compute_dtype: str = "float32", + bnb_4bit_quant_type: str = "fp4", + bnb_4bit_use_double_quant: bool = False, + llm_int8_enable_fp32_cpu_offload: bool = False, + llm_int8_has_fp16_weight: bool = False, + llm_int8_skip_modules: Optional[Any] = None, + llm_int8_threshold: float = 0.0, + ) -> None: + + self.load_in_8bit = load_in_8bit + self.load_in_4bit = load_in_4bit + self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype + self.bnb_4bit_quant_type = bnb_4bit_quant_type + self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant + self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload + self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight + self.llm_int8_skip_modules = llm_int8_skip_modules + self.llm_int8_threshold = llm_int8_threshold def __repr__(self) -> str: return "BitsAndBytesConfig" @@ -41,7 +60,46 @@ def get_config_filenames() -> List[str]: @classmethod def from_config(cls, config: Dict[str, Any]) -> "BitsAndBytesConfig": - return cls() + + def get_safe_value(config, keys, default_value=None): + try: + value = cls.get_from_keys(config, keys) + return value if value is not None else default_value + except ValueError: + return default_value + + load_in_8bit = get_safe_value(config, ["load_in_8bit"], + default_value=False) + load_in_4bit = get_safe_value(config, ["load_in_4bit"], + default_value=True) + bnb_4bit_compute_dtype = get_safe_value(config, + ["bnb_4bit_compute_dtype"], + default_value="float32") + bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"], + default_value="fp4") + bnb_4bit_use_double_quant = get_safe_value( + config, ["bnb_4bit_use_double_quant"], default_value=False) + llm_int8_enable_fp32_cpu_offload = get_safe_value( + config, ["llm_int8_enable_fp32_cpu_offload"], default_value=False) + llm_int8_has_fp16_weight = get_safe_value(config, + ["llm_int8_has_fp16_weight"], + default_value=False) + llm_int8_skip_modules = get_safe_value(config, + ["llm_int8_skip_modules"], + default_value=[]) + llm_int8_threshold = get_safe_value(config, ["llm_int8_threshold"], + default_value=0.0) + + return cls( + load_in_8bit=load_in_8bit, + load_in_4bit=load_in_4bit, + bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, + bnb_4bit_quant_type=bnb_4bit_quant_type, + bnb_4bit_use_double_quant=bnb_4bit_use_double_quant, + llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload, + llm_int8_has_fp16_weight=llm_int8_has_fp16_weight, + llm_int8_skip_modules=llm_int8_skip_modules, + llm_int8_threshold=llm_int8_threshold) def get_quant_method(self, layer: torch.nn.Module, prefix: str) -> Optional["BitsAndBytesLinearMethod"]: @@ -78,39 +136,58 @@ def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - quant_ratio = 0 - if params_dtype.is_floating_point: - quant_ratio = torch.finfo(params_dtype).bits // torch.iinfo( - torch.uint8).bits + from bitsandbytes.nn import Int8Params + + def calculate_quant_ratio(dtype): + if dtype.is_floating_point: + return torch.finfo(dtype).bits // torch.iinfo(torch.uint8).bits + else: + return torch.iinfo(dtype).bits // torch.iinfo(torch.uint8).bits + + def create_qweight_for_8bit(): + qweight = Int8Params( + data=torch.empty(sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + has_fp16_weights=self.quant_config.llm_int8_has_fp16_weight, + requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 0, + "pack_factor": 1, + "use_bitsandbytes_8bit": True, + "generation": 0 + }) + return qweight + + def create_qweight_for_4bit(): + quant_ratio = calculate_quant_ratio(params_dtype) + + total_size = input_size_per_partition * sum(output_partition_sizes) + if total_size % quant_ratio != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape.") + + qweight = torch.nn.Parameter(torch.empty(total_size // quant_ratio, + 1, + dtype=torch.uint8), + requires_grad=False) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 0, + "pack_factor": quant_ratio, + "use_bitsandbytes_4bit": True + }) + return qweight + + if self.quant_config.load_in_8bit: + qweight = create_qweight_for_8bit() else: - quant_ratio = torch.iinfo(params_dtype).bits // torch.iinfo( - torch.uint8).bits - - if input_size_per_partition * sum( - output_partition_sizes) % quant_ratio != 0: - raise ValueError( - "The input size is not aligned with the quantized " - "weight shape. ") - qweight = Parameter( - torch.empty( - input_size_per_partition * sum(output_partition_sizes) // - quant_ratio, - 1, - dtype=torch.uint8, - ), - requires_grad=False, - ) - - set_weight_attrs( - qweight, - { - "input_dim": 0, - # In bitsandbytes, a tensor of shape [n,m] is quantized to - #[n*m/pack_ratio, 1],so the output_dim is 0 - "output_dim": 0, - "pack_factor": quant_ratio, - "use_bitsandbytes": True, - }) + qweight = create_qweight_for_4bit() + layer.register_parameter("qweight", qweight) set_weight_attrs(qweight, extra_weight_attrs) @@ -119,6 +196,88 @@ def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.quant_config.load_in_8bit: + return self._apply_8bit_weight(layer, x, bias) + else: + return self._apply_4bit_weight(layer, x, bias) + + def _apply_8bit_weight( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + # only load the bitsandbytes module when needed + from bitsandbytes import MatmulLtState, matmul + + original_type = x.dtype + bf_x = x.to(torch.bfloat16) + + qweight = layer.qweight + offsets = qweight.bnb_shard_offsets + quant_states = qweight.bnb_quant_state + matmul_states = qweight.matmul_state + generation = qweight.generation + + out_dim_0 = x.shape[0] + out_dim_1 = sum( + [quant_state[1].shape[0] for quant_state in quant_states.items()]) + out = torch.empty(out_dim_0, + out_dim_1, + dtype=torch.float16, + device=x.device) + + current_index = 0 + for i in range(len(quant_states)): + output_size = quant_states[i].shape[0] + + # in profile_run or the first generation of inference, + # create new matmul_states + if generation == 0 or generation == 1: + matmul_states[i] = MatmulLtState() + matmul_states[i].CB = qweight[offsets[i]:offsets[i + 1]] + matmul_states[i].SCB = quant_states[i] + matmul_states[i].threshold = ( + self.quant_config.llm_int8_threshold) + matmul_states[i].has_fp16_weights = ( + self.quant_config.llm_int8_has_fp16_weight) + matmul_states[i].is_training = False + if matmul_states[i].threshold > 0.0 and not matmul_states[ + i].has_fp16_weights: + matmul_states[i].use_pool = True + + new_x = bf_x.unsqueeze(0) + + out[:, current_index:current_index + output_size] = matmul( + new_x, + qweight[offsets[i]:offsets[i + 1]], + state=matmul_states[i]) + + current_index += output_size + + # only update the matmul_states if it is not profile_run + if (generation > 0 + and not self.quant_config.llm_int8_has_fp16_weight + and matmul_states[i].CB is not None + and matmul_states[i].CxB is not None): + del matmul_states[i].CB + qweight[offsets[i]:offsets[i + 1]] = matmul_states[i].CxB + + out = out.to(original_type) + + if bias is not None: + out += bias + + qweight.generation += 1 + + return out + + def _apply_4bit_weight( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + # only load the bitsandbytes module when needed from bitsandbytes import matmul_4bit diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 0768b37044aac..1170d55f31993 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -232,7 +232,8 @@ def _get_scheme_from_parts( return CompressedTensorsWNA16( num_bits=weight_quant.num_bits, strategy=weight_quant.strategy, - group_size=weight_quant.group_size) + group_size=weight_quant.group_size, + actorder=weight_quant.actorder) # Detect If Activation Quantization. # TODO @dsikka: clean-up conditions diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index 0e0ab9ce9169f..49c29c2775cb6 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -1,13 +1,11 @@ import enum from enum import Enum -from typing import List, Optional +from typing import Callable, List, Optional import torch from vllm import _custom_ops as ops -from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - WNA16_SUPPORTED_BITS) +from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat) from vllm.model_executor.utils import set_weight_attrs @@ -40,11 +38,10 @@ def __init__( if not (self.quant_config.quant_format == CompressionFormat.pack_quantized.value - and self.num_bits in WNA16_SUPPORTED_BITS): + and self.num_bits == 4): raise ValueError("For Fused MoE layers, only ", f"{CompressionFormat.pack_quantized.value} ", - "is supported for the following bits: ", - f"{WNA16_SUPPORTED_BITS}") + "is supported for 4 bits") def create_weights(self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, @@ -256,28 +253,43 @@ def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, ) replace_tensor("w2_weight_scale", marlin_w2_scales) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: - - from vllm.model_executor.layers.fused_moe.fused_moe import ( + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( fused_marlin_moe) - return fused_marlin_moe(x, - layer.w13_weight_packed, - layer.w2_weight_packed, - router_logits, - layer.w13_g_idx, - layer.w2_g_idx, - layer.w13_g_idx_sort_indices, - layer.w2_g_idx_sort_indices, - top_k, - renormalize=renormalize, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale) + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) + + return fused_marlin_moe( + x, + layer.w13_weight_packed, + layer.w2_weight_packed, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + topk_weights, + topk_ids, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py index 7ca8eecb9283e..3cade3d3fbcd0 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py @@ -5,20 +5,24 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsScheme) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + ActivationOrdering) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, marlin_make_empty_g_idx, marlin_make_workspace, - marlin_permute_scales, replace_tensor, verify_marlin_supported, + marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, replace_tensor, verify_marlin_supported, verify_marlin_supports_shape) from vllm.model_executor.parameter import (BasevLLMParameter, ChannelQuantScaleParameter, GroupQuantScaleParameter, - PackedvLLMParameter) + PackedvLLMParameter, + RowvLLMParameter) from vllm.scalar_type import scalar_types __all__ = ["CompressedTensorsWNA16"] WNA16_SUPPORTED_TYPES_MAP = { 4: scalar_types.uint4b8, - 8: scalar_types.uint8b128, + 8: scalar_types.uint8b128 } WNA16_SUPPORTED_BITS = list(WNA16_SUPPORTED_TYPES_MAP.keys()) @@ -28,11 +32,13 @@ class CompressedTensorsWNA16(CompressedTensorsScheme): def __init__(self, strategy: str, num_bits: int, - group_size: Optional[int] = None): + group_size: Optional[int] = None, + actorder: Optional[ActivationOrdering] = None): self.pack_factor = 32 // num_bits self.strategy = strategy self.group_size = -1 if group_size is None else group_size + self.has_g_idx = actorder == ActivationOrdering.GROUP if self.group_size == -1 and self.strategy != "channel": raise ValueError("Marlin kernels require group quantization or " @@ -64,12 +70,10 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, output_size_per_partition = sum(output_partition_sizes) # If group_size is -1, we are in channelwise case. - channelwise = (self.group_size == -1) group_size = self.group_size if self.group_size != -1 else input_size row_parallel = (input_size != input_size_per_partition) - # In the case of channelwise quantization, we need to replicate the - # scales across all gpus. - partition_scales = (row_parallel and not channelwise) + partition_scales = not marlin_repeat_scales_on_all_ranks( + self.has_g_idx, self.group_size, row_parallel) verify_marlin_supports_shape( output_size_per_partition=output_size_per_partition, @@ -123,6 +127,16 @@ def create_weights(self, layer: torch.nn.Module, input_size: int, layer.register_parameter("weight_scale", weight_scale) layer.register_parameter("weight_shape", weight_shape) + # group index (for activation reordering) + if self.has_g_idx: + weight_g_idx = RowvLLMParameter(data=torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight_g_idx", weight_g_idx) + layer.input_size_per_partition = input_size_per_partition layer.output_size_per_partition = output_size_per_partition layer.input_size = input_size @@ -137,9 +151,14 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.workspace = marlin_make_workspace( layer.output_size_per_partition, device) - # Act-order not supported in compressed-tensors yet, so set to empty. - layer.g_idx = marlin_make_empty_g_idx(device) - layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + # Handle sorting for activation reordering if needed. + if self.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.weight_g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + replace_tensor(layer, "weight_g_idx", g_idx) + else: + layer.weight_g_idx = marlin_make_empty_g_idx(device) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) # No zero-point layer.weight_zp = marlin_make_empty_g_idx(device) @@ -159,9 +178,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: replace_tensor(layer, "weight_packed", marlin_qweight) # Permute scales from compressed-tensors format to marlin format. + # scale is required on all partitions if activation reordering marlin_scales = marlin_permute_scales( layer.weight_scale, - size_k=layer.input_size_per_partition, + size_k=(layer.input_size + if self.has_g_idx else layer.input_size_per_partition), size_n=layer.output_size_per_partition, group_size=layer.group_size) replace_tensor(layer, "weight_scale", marlin_scales) @@ -174,7 +195,7 @@ def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, weight=layer.weight_packed, weight_scale=layer.weight_scale, weight_zp=layer.weight_zp, - g_idx=layer.g_idx, + g_idx=layer.weight_g_idx, g_idx_sort_indices=layer.g_idx_sort_indices, workspace=layer.workspace, wtype=self.quant_type, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py index 7912cbde5721f..fc531b9d666e3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/utils.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/utils.py @@ -1,8 +1,8 @@ import re from enum import Enum -from typing import Any, Dict, Iterable, Optional +from typing import Any, Dict, Iterable, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator from torch.nn import Module from vllm.model_executor.layers.quantization.utils.quant_utils import ( @@ -40,6 +40,19 @@ class QuantizationStrategy(str, Enum): TOKEN = "token" +class ActivationOrdering(str, Enum): + """ + Enum storing strategies for activation ordering + + Group: reorder groups and weight\n + Weight: only reorder weight, not groups. Slightly lower latency and + accuracy compared to group actorder\n + """ + + GROUP = "group" + WEIGHT = "weight" + + class QuantizationArgs(BaseModel): """ User facing arguments used to define a quantization config @@ -58,6 +71,8 @@ class QuantizationArgs(BaseModel): observed with every sample. Defaults to False for static quantization. Note that enabling dynamic quantization will change the default observer to a memoryless one + :param actorder: whether to apply group quantization in decreasing order of + activation. Defaults to None for arbitrary ordering """ num_bits: int = 8 @@ -67,6 +82,7 @@ class QuantizationArgs(BaseModel): strategy: Optional[QuantizationStrategy] = None block_structure: Optional[str] = None dynamic: bool = False + actorder: Union[ActivationOrdering, bool, None] = None observer: str = Field( default="minmax", description=("The class to use to compute the quantization param - " @@ -79,6 +95,16 @@ class QuantizationArgs(BaseModel): "Observers constructor excluding quantization range or symmetry"), ) + @field_validator("actorder", mode="before") + def validate_actorder(cls, value) -> Optional[ActivationOrdering]: + if isinstance(value, bool): + return ActivationOrdering.GROUP if value else None + + if isinstance(value, str): + return ActivationOrdering(value.lower()) + + return value + def is_activation_quantization_format(format: str) -> bool: _ACTIVATION_QUANTIZATION_FORMATS = [ diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index dabf17df78fef..116a4ea0aed89 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch @@ -96,15 +96,18 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int, requires_grad=False) layer.register_parameter("w2_scale", w2_scale) - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool = True, - use_grouped_topk: bool = False, - num_expert_group: Optional[int] = None, - topk_group: Optional[int] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts topk_weights, topk_ids = FusedMoE.select_experts( @@ -114,7 +117,8 @@ def apply(self, top_k=top_k, renormalize=renormalize, topk_group=topk_group, - num_expert_group=num_expert_group) + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) return fused_experts(x, layer.w13_weight, diff --git a/vllm/model_executor/layers/quantization/fbgemm_fp8.py b/vllm/model_executor/layers/quantization/fbgemm_fp8.py index e7c3859967c71..3ccf1af9eb898 100644 --- a/vllm/model_executor/layers/quantization/fbgemm_fp8.py +++ b/vllm/model_executor/layers/quantization/fbgemm_fp8.py @@ -15,8 +15,9 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped) from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( - apply_fp8_linear, create_per_channel_scale_param) -from vllm.model_executor.utils import set_weight_attrs + apply_fp8_linear) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter) from vllm.platforms import current_platform logger = init_logger(__name__) @@ -85,6 +86,7 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + weight_loader = extra_weight_attrs.get("weight_loader") del input_size, output_size output_size_per_partition = sum(output_partition_sizes) @@ -95,20 +97,21 @@ def create_weights( layer.orig_dtype = params_dtype # WEIGHT - weight = Parameter(torch.empty(output_size_per_partition, - input_size_per_partition, - dtype=torch.float8_e4m3fn), - requires_grad=False) + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) layer.register_parameter("weight", weight) - set_weight_attrs(weight, { - "input_dim": 1, - "output_dim": 0, - **extra_weight_attrs, - }) # WEIGHT SCALE - weight_scale = create_per_channel_scale_param(output_partition_sizes, - **extra_weight_attrs) + weight_scale = ChannelQuantScaleParameter(data=torch.empty( + (sum(output_partition_sizes), 1), dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + weight_scale[:] = torch.finfo(torch.float32).min layer.register_parameter("weight_scale", weight_scale) # INPUT SCALE UPPER BOUND @@ -118,6 +121,11 @@ def create_weights( layer.input_scale_ub = input_scale_ub def process_weights_after_loading(self, layer: Module) -> None: + # required by torch.compile + layer.weight_scale = Parameter(layer.weight_scale.data, + requires_grad=False) + layer.weight = Parameter(layer.weight.data, requires_grad=False) + weight = layer.weight layer.weight = Parameter(weight.t(), requires_grad=False) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 1817dbcb023a7..32affe06b89b7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional import torch from torch.nn import Module @@ -468,15 +468,18 @@ def process_weights_after_loading(self, layer: Module) -> None: requires_grad=False) return - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None) -> torch.Tensor: + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: from vllm.model_executor.layers.fused_moe import fused_experts @@ -487,7 +490,8 @@ def apply(self, top_k=top_k, renormalize=renormalize, topk_group=topk_group, - num_expert_group=num_expert_group) + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function) return fused_experts(x, layer.w13_weight, diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index f456286899a53..c067a76405df6 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -11,7 +11,11 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedColumnParameter, + PackedvLLMParameter, + RowvLLMParameter) class GPTQConfig(QuantizationConfig): @@ -108,6 +112,7 @@ def create_weights( **extra_weight_attrs, ): del output_size # Unused. + weight_loader = extra_weight_attrs.get("weight_loader") if input_size_per_partition % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -138,73 +143,81 @@ def create_weights( scale_and_zero_size = input_size_per_partition // group_size scale_and_zero_input_dim = 0 - qweight = Parameter( - torch.empty( + qweight = PackedvLLMParameter( + data=torch.empty( input_size_per_partition // self.quant_config.pack_factor, output_size_per_partition, dtype=torch.int32, ), - requires_grad=False, - ) - set_weight_attrs( - qweight, { - "input_dim": 0, - "output_dim": 1, - "packed_dim": 0, - "pack_factor": self.quant_config.pack_factor, - }) - g_idx = Parameter( - torch.tensor( - [ - i // self.quant_config.group_size - for i in range(input_size_per_partition) - ], - dtype=torch.int32, - ), - requires_grad=False, - ) - # Ignore warning from fused linear layers such as QKVParallelLinear. - set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True}) - qzeros = Parameter( + input_dim=0, + output_dim=1, + packed_dim=0, + packed_factor=self.quant_config.pack_factor, + weight_loader=weight_loader) + + g_idx = RowvLLMParameter(data=torch.tensor( + [ + i // self.quant_config.group_size + for i in range(input_size_per_partition) + ], + dtype=torch.int32, + ), + input_dim=0, + weight_loader=weight_loader) + qzeros_args = { + "data": torch.empty( scale_and_zero_size, output_size_per_partition // self.quant_config.pack_factor, dtype=torch.int32, ), - requires_grad=False, - ) - set_weight_attrs( - qzeros, { - "input_dim": scale_and_zero_input_dim, - "output_dim": 1, - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - }) - scales = Parameter( + "weight_loader": + weight_loader + } + weight_scale_args = { + "data": torch.empty( scale_and_zero_size, output_size_per_partition, dtype=params_dtype, ), - requires_grad=False, - ) - set_weight_attrs(scales, { - "input_dim": scale_and_zero_input_dim, - "output_dim": 1, - }) + "weight_loader": + weight_loader + } + if scale_and_zero_input_dim is None: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + qzeros = PackedColumnParameter( + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) + + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) + qzeros = PackedvLLMParameter( + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + **qzeros_args) layer.register_parameter("qweight", qweight) - set_weight_attrs(qweight, extra_weight_attrs) layer.register_parameter("g_idx", g_idx) - set_weight_attrs(g_idx, extra_weight_attrs) layer.register_parameter("qzeros", qzeros) - set_weight_attrs(qzeros, extra_weight_attrs) layer.register_parameter("scales", scales) - set_weight_attrs(scales, extra_weight_attrs) layer.exllama_state = exllama_state def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # for torch.compile + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False) + # exllama needs to shuffle the weight after the weight is loaded # here we do the shuffle on first forward pass if layer.exllama_state == ExllamaState.UNINITIALIZED: diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 94eb3f301541a..3617a32f80fc1 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -1,18 +1,22 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch.nn import Parameter from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.utils.marlin_utils import ( apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full, - marlin_make_empty_g_idx, marlin_make_workspace, marlin_permute_scales, - marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx, replace_tensor, - verify_marlin_supported, verify_marlin_supports_shape) + marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales, + marlin_permute_scales, marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, replace_tensor, verify_marlin_supported, + verify_marlin_supports_shape) from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.parameter import (ChannelQuantScaleParameter, GroupQuantScaleParameter, @@ -33,8 +37,14 @@ class GPTQMarlinConfig(QuantizationConfig): (8, True): scalar_types.uint8b128, } - def __init__(self, weight_bits: int, group_size: int, desc_act: bool, - is_sym: bool, lm_head_quantized: bool) -> None: + def __init__( + self, + weight_bits: int, + group_size: int, + desc_act: bool, + is_sym: bool, + lm_head_quantized: bool, + ) -> None: if desc_act and group_size == -1: # In this case, act_order == True is the same as act_order == False # (since we have only one group per output channel) @@ -51,10 +61,6 @@ def __init__(self, weight_bits: int, group_size: int, desc_act: bool, self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)] - # Verify supported on platform. - verify_marlin_supported(quant_type=self.quant_type, - group_size=self.group_size) - def __repr__(self) -> str: return (f"GPTQMarlinConfig(quant_type={self.quant_type}, " f"group_size={self.group_size}, " @@ -109,11 +115,14 @@ def override_quantization_method(cls, hf_quant_cfg, " faster inference") return None - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional["GPTQMarlinLinearMethod"]: - if (isinstance(layer, LinearBase) or - (isinstance(layer, ParallelLMHead) and self.lm_head_quantized)): + def get_quant_method( + self, layer: torch.nn.Module, prefix: str + ) -> Optional[Union["GPTQMarlinLinearMethod", "GPTQMarlinMoEMethod"]]: + if isinstance(layer, LinearBase) or (isinstance(layer, ParallelLMHead) + and self.lm_head_quantized): return GPTQMarlinLinearMethod(self) + elif isinstance(layer, FusedMoE): + return GPTQMarlinMoEMethod(self) return None def get_scaled_act_names(self) -> List[str]: @@ -153,6 +162,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase): def __init__(self, quant_config: GPTQMarlinConfig) -> None: self.quant_config = quant_config + # Verify supported on platform. + verify_marlin_supported(quant_type=self.quant_config.quant_type, + group_size=self.quant_config.group_size) + def create_weights( self, layer: torch.nn.Module, @@ -179,7 +192,8 @@ def create_weights( output_size_per_partition=output_size_per_partition, input_size_per_partition=input_size_per_partition, input_size=input_size, - group_size=group_size) + group_size=group_size, + ) # Determine sharding if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act, @@ -299,7 +313,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: perm=layer.g_idx_sort_indices, size_k=layer.input_size_per_partition, size_n=layer.output_size_per_partition, - num_bits=self.quant_config.quant_type.size_bits) + num_bits=self.quant_config.quant_type.size_bits, + ) replace_tensor(layer, "qweight", marlin_qweight) # Permute scales from autogptq format to marlin format. @@ -308,7 +323,8 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: size_k=(layer.input_size if self.quant_config.desc_act else layer.input_size_per_partition), size_n=layer.output_size_per_partition, - group_size=self.quant_config.group_size) + group_size=self.quant_config.group_size, + ) replace_tensor(layer, "scales", marlin_scales) def apply( @@ -329,4 +345,270 @@ def apply( output_size_per_partition=layer.output_size_per_partition, input_size_per_partition=layer.input_size_per_partition, is_k_full=layer.is_k_full, - bias=bias) + bias=bias, + ) + + +class GPTQMarlinMoEMethod(FusedMoEMethodBase): + """MoE Marlin method with quantization.""" + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + # Currently assuming is_k_full is always True + # (input size per partition is the same as full input size) + # Supports only sym for now (no zp) + if self.quant_config.group_size != -1: + scales_size13 = hidden_size // self.quant_config.group_size + scales_size2 = intermediate_size // self.quant_config.group_size + strategy = FusedMoeWeightScaleSupported.GROUP.value + else: + scales_size13 = 1 + scales_size2 = 1 + strategy = FusedMoeWeightScaleSupported.CHANNEL.value + + extra_weight_attrs.update({ + "quant_method": strategy, + "is_transposed": True + }) + # Fused gate_up_proj (column parallel) + w13_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size // self.quant_config.pack_factor, + 2 * intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_qweight", w13_qweight) + set_weight_attrs(w13_qweight, extra_weight_attrs) + # down_proj (row parallel) + w2_qweight = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size // self.quant_config.pack_factor, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_qweight", w2_qweight) + set_weight_attrs(w2_qweight, extra_weight_attrs) + # up_proj scales + w13_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size, + dtype=torch.half), + requires_grad=False, + ) + layer.register_parameter("w13_scales", w13_scales) + set_weight_attrs(w13_scales, extra_weight_attrs) + # down_proj scales + w2_scales = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size, + dtype=torch.half), + requires_grad=False, + ) + layer.register_parameter("w2_scales", w2_scales) + set_weight_attrs(w2_scales, extra_weight_attrs) + # up_proj scales + w13_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size13, + 2 * intermediate_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w13_qzeros", w13_qzeros) + set_weight_attrs(w13_qzeros, extra_weight_attrs) + # down_proj scales + w2_qzeros = torch.nn.Parameter( + torch.empty(num_experts, + scales_size2, + hidden_size // self.quant_config.pack_factor, + dtype=params_dtype), + requires_grad=False, + ) + layer.register_parameter("w2_qzeros", w2_qzeros) + set_weight_attrs(w2_qzeros, extra_weight_attrs) + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + num_experts = layer.w13_g_idx.shape[0] + w13_g_idx_sort_indices = torch.empty_like(layer.w13_g_idx) + w2_g_idx_sort_indices = torch.empty_like(layer.w2_g_idx) + w13_sorted_g_idx = torch.empty_like(layer.w13_g_idx) + w2_sorted_g_idx = torch.empty_like(layer.w2_g_idx) + for e in range(num_experts): + w13_g_idx_sort_indices[e] = torch.argsort( + layer.w13_g_idx[e]).to(torch.int32) + w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_g_idx[e]).to( + torch.int32) + w13_sorted_g_idx[e] = layer.w13_g_idx[e][ + w13_g_idx_sort_indices[e]] + w2_sorted_g_idx[e] = layer.w2_g_idx[e][ + w2_g_idx_sort_indices[e]] + replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx) + replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx) + replace_tensor(layer, "w13_g_idx_sort_indices", + w13_g_idx_sort_indices) + replace_tensor(layer, "w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + else: + # Reset g_idx related tensors + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, + device=device), + requires_grad=False, + ) + # Repack weights + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_qweight, + layer.w13_g_idx_sort_indices, + layer.w13_qweight.shape[1] * self.quant_config.pack_factor, + layer.w13_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_tensor(layer, "w13_qweight", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_qweight, + layer.w2_g_idx_sort_indices, + layer.w2_qweight.shape[1] * self.quant_config.pack_factor, + layer.w2_qweight.shape[2], + self.quant_config.quant_type.size_bits, + ) + replace_tensor(layer, "w2_qweight", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + s=layer.w13_scales, + size_k=layer.intermediate_size_per_partition, + size_n=layer.w13_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w13_scales", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + s=layer.w2_scales, + size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor, + size_n=layer.w2_scales.shape[2], + group_size=self.quant_config.group_size, + ) + replace_tensor(layer, "w2_scales", marlin_w2_scales) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe.fused_marlin_moe import ( + fused_marlin_moe) + + # The input must currently be float16 + orig_dtype = x.dtype + x = x.half() + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=None) + + return fused_marlin_moe( + x, + layer.w13_qweight, + layer.w2_qweight, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + topk_weights, + topk_ids, + w1_scale=layer.w13_scales, + w2_scale=layer.w2_scales, + ).to(orig_dtype) diff --git a/vllm/model_executor/layers/quantization/neuron_quant.py b/vllm/model_executor/layers/quantization/neuron_quant.py new file mode 100644 index 0000000000000..2624981f6a614 --- /dev/null +++ b/vllm/model_executor/layers/quantization/neuron_quant.py @@ -0,0 +1,67 @@ +import os +from importlib.util import find_spec +from typing import Any, Dict, List, Optional + +from torch.nn import Module + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +SUPPORTED_QUANT_DTYPE_LIST = ['s8', 'f8e4m3fn'] + + +class NeuronQuantConfig(QuantizationConfig): + """Int8 Quantization Config class for Neuron Backend.""" + + def __init__( + self, + dequant_dtype: str = "f16", + quantize_method: str = "vector_dynamic", + ) -> None: + self.quant_dtype = os.getenv("NEURON_QUANT_DTYPE", "s8") + if self.quant_dtype not in SUPPORTED_QUANT_DTYPE_LIST: + raise ValueError( + f"Neuron quantization datatype {self.quant_dtype} is not valid," + f"the quantization datatype should match one of the below types" + f"{SUPPORTED_QUANT_DTYPE_LIST}") + self.dequant_dtype = dequant_dtype + self.quantize_method = quantize_method + + def get_name(self) -> str: + return "neuron_quant" + + def get_supported_act_dtypes(self) -> List[str]: + return SUPPORTED_QUANT_DTYPE_LIST + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "This function should not be called with Neuron Backend") + + @staticmethod + def get_config_filenames() -> List[str]: + return [] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "NeuronQuantConfig": + quantize_method = cls.get_from_keys(config, ["quantize_method"]) + dequant_dtype = cls.get_from_keys(config, ["dequant_dtype"]) + return cls(dequant_dtype=dequant_dtype, + quantize_method=quantize_method) + + def get_quant_method(self, layer: Module, prefix: str) -> Optional[Any]: + if find_spec("transformers_neuronx") is not None: + return self.get_quantization_config() + else: + raise NotImplementedError( + "Neuron Quantization is only supported through" + " transformers_neuronx.") + + def get_scaled_act_names(self) -> List[str]: + return [] + + def get_quantization_config(self): + from transformers_neuronx.config import QuantizationConfig + return QuantizationConfig(quant_dtype=self.quant_dtype, + dequant_dtype=self.dequant_dtype, + quantize_method=self.quantize_method) diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py deleted file mode 100644 index afb3c04976737..0000000000000 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import Any, Dict, List, Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import LinearBase -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.utils import set_weight_attrs -from vllm.utils import is_hip - - -class SqueezeLLMConfig(QuantizationConfig): - """Config class for SqueezeLLM. - - Reference: https://arxiv.org/pdf/2306.07629 - """ - - def __init__( - self, - weight_bits: int, - ) -> None: - self.weight_bits = weight_bits - - if self.weight_bits != 4: - raise ValueError( - "Currently, only 4-bit weight quantization is supported for " - f"SqueezeLLM, but got {self.weight_bits} bits.") - - self.pack_factor = 32 // self.weight_bits - - def __repr__(self) -> str: - return f"SqueezeLLMConfig(weight_bits={self.weight_bits})" - - def get_name(self) -> str: - return "squeezellm" - - def get_supported_act_dtypes(self) -> List[torch.dtype]: - return [torch.half] - - @classmethod - def get_min_capability(cls) -> int: - return 70 - - @staticmethod - def get_config_filenames() -> List[str]: - return ["quant_config.json"] - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": - weight_bits = cls.get_from_keys(config, ["wbits"]) - return cls(weight_bits) - - def get_quant_method(self, layer: torch.nn.Module, - prefix: str) -> Optional[QuantizeMethodBase]: - if isinstance(layer, LinearBase): - return SqueezeLLMLinearMethod(self) - return None - - def get_scaled_act_names(self) -> List[str]: - return [] - - -class SqueezeLLMLinearMethod(QuantizeMethodBase): - """Linear method for SqueezeLLM. - - Args: - quant_config: The SqueezeLLM quantization config. - """ - - def __init__(self, quant_config: SqueezeLLMConfig): - self.quant_config = quant_config - - def create_weights(self, layer: torch.nn.Module, - input_size_per_partition: int, - output_partition_sizes: List[int], input_size: int, - output_size: int, params_dtype: torch.dtype, - **extra_weight_attrs): - if input_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - "The input size is not aligned with the quantized " - "weight shape. This can be caused by too large " - "tensor parallel size.") - - output_size_per_partition = sum(output_partition_sizes) - qweight = Parameter( - torch.empty( - input_size_per_partition // self.quant_config.pack_factor, - output_size_per_partition, - dtype=torch.int32, - ), - requires_grad=False, - ) - set_weight_attrs( - qweight, { - "input_dim": 0, - "output_dim": 1, - "packed_dim": 0, - "pack_factor": self.quant_config.pack_factor, - }) - lookup_table = Parameter( - torch.empty( - output_size, - self.quant_config.weight_bits**2, - dtype=params_dtype, - ), - requires_grad=False, - ) - set_weight_attrs(lookup_table, { - "output_dim": 0, - }) - - layer.register_parameter("qweight", qweight) - set_weight_attrs(qweight, extra_weight_attrs) - layer.register_parameter("lookup_table", lookup_table) - set_weight_attrs(lookup_table, extra_weight_attrs) - - def apply(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qweight = layer.qweight - lookup_table = layer.lookup_table - out_shape = x.shape[:-1] + (qweight.shape[-1], ) - reshaped_x = x.reshape(-1, x.shape[-1]) - if is_hip(): - out_f = torch.zeros(out_shape, dtype=torch.float) - ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table) - out = out_f.to(dtype=torch.float16) - else: - # NOTE: The output tensor should be zero-initialized. - out = torch.zeros(out_shape, dtype=torch.float16) - ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table) - - if bias is not None: - out.add_(bias) - return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/tpu_int8.py b/vllm/model_executor/layers/quantization/tpu_int8.py index ae34e01497db4..be8235b468f68 100644 --- a/vllm/model_executor/layers/quantization/tpu_int8.py +++ b/vllm/model_executor/layers/quantization/tpu_int8.py @@ -7,7 +7,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.parameter import ModelWeightParameter ACTIVATION_SCHEMES = ["none"] @@ -64,16 +64,16 @@ def create_weights(self, layer: Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - weight = Parameter(torch.empty(sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - requires_grad=False) + + weight_loader = extra_weight_attrs.get("weight_loader") + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) layer.register_parameter("weight", weight) - set_weight_attrs(weight, { - **extra_weight_attrs, - "input_dim": 1, - "output_dim": 0, - }) def _quantize_weight( self, weight: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -92,6 +92,7 @@ def _quantize_weight( return qweight, qscale def process_weights_after_loading(self, layer: Module) -> None: + layer.weight = Parameter(layer.weight.data, requires_grad=False) device = layer.weight.device qweight, qscale = self._quantize_weight(layer.weight) qweight = qweight.to(device) diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/vllm/model_executor/layers/quantization/utils/marlin_utils.py index 0ec68ac5b0f21..699d5f1844146 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -176,6 +176,23 @@ def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, return s +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, num_bits: int) -> torch.Tensor: # Permute zero-points in a similar way to scales, but do not use the diff --git a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py index 7d08ac6f87469..4a06c5d63d52d 100644 --- a/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +++ b/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -1,6 +1,6 @@ """Utility functions used for tests and benchmarks""" -from typing import List +from typing import List, Optional import numpy as np import torch @@ -92,8 +92,11 @@ def get_weight_perm(num_bits: int): return perm -def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, - act_order: bool): +def marlin_quantize(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): size_k, size_n = w.shape num_bits = quant_type.size_bits @@ -104,7 +107,7 @@ def marlin_quantize(w: torch.Tensor, quant_type: ScalarType, group_size: int, # Quantize (and apply act_order if provided) w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( - w, quant_type, group_size, act_order) + w, quant_type, group_size, act_order, test_perm) # For act_order, sort the "weights" and "g_idx" so that group ids are # increasing diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 33f24ff5d54d3..bdfda31de852b 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -1,5 +1,5 @@ """This file is used for /tests and /benchmarks""" -from typing import List +from typing import List, Optional import numpy import torch @@ -53,7 +53,10 @@ def get_pack_factor(num_bits): return 32 // num_bits -def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): +def permute_rows(q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None): assert q_w.shape == w_ref.shape orig_device = q_w.device @@ -64,7 +67,7 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): g_idx[i] = i // group_size # Simulate act_order by doing a random permutation on K - rand_perm = torch.randperm(k_size) + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) g_idx = g_idx[rand_perm].contiguous() q_w = q_w[rand_perm, :].contiguous() @@ -164,8 +167,11 @@ def reshape_w(w): ) -def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, - group_size: int, act_order: bool): +def gptq_quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): size_k, _ = w.shape assert w.is_floating_point(), "w must be float" @@ -186,7 +192,8 @@ def gptq_quantize_weights(w: torch.Tensor, quant_type: ScalarType, ), "For act_order, groupsize = {} must be less than size_k = {}".format( group_size, size_k) - w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size) + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, + test_perm) return w_ref, w_q, w_s, g_idx, rand_perm diff --git a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py index 6cc1c65ddfa82..a54e3cae73b14 100644 --- a/vllm/model_executor/layers/quantization/utils/w8a8_utils.py +++ b/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -1,10 +1,8 @@ from typing import List, Optional, Tuple, Union import torch -from torch.nn import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import is_hip @@ -38,31 +36,6 @@ def all_close_1d(x: torch.Tensor) -> bool: return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) -def create_per_tensor_scale_param( - output_partition_sizes: List[int], - **extra_weight_attrs, -) -> Parameter: - scale = Parameter(torch.empty(len(output_partition_sizes), - dtype=torch.float32), - requires_grad=False) - scale[:] = torch.finfo(torch.float32).min - set_weight_attrs(scale, { - "needs_scalar_to_array": True, - **extra_weight_attrs - }) - return scale - - -def create_per_channel_scale_param(output_partition_sizes: List[int], - **extra_weight_attrs) -> Parameter: - scale = Parameter(torch.empty((sum(output_partition_sizes), 1), - dtype=torch.float32), - requires_grad=False) - scale[:] = torch.finfo(torch.float32).min - set_weight_attrs(scale, {"output_dim": 0, **extra_weight_attrs}) - return scale - - def convert_to_channelwise( weight_scale: torch.Tensor, logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: diff --git a/vllm/model_executor/layers/rejection_sampler.py b/vllm/model_executor/layers/rejection_sampler.py index 2124196d06f9c..b2f333a5bcc80 100644 --- a/vllm/model_executor/layers/rejection_sampler.py +++ b/vllm/model_executor/layers/rejection_sampler.py @@ -1,12 +1,28 @@ from functools import cached_property +from importlib.util import find_spec from typing import Dict, List, Optional, Tuple import torch import torch.jit +import vllm.envs as envs +from vllm.logger import init_logger from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeStochasticBaseSampler) +logger = init_logger(__name__) + +if find_spec("flashinfer"): + """ + Consider utilizing the FlashInfer rejection sampling kernel initially, + as it employs a dedicated kernel rather than relying on + Torch tensor operations. This design choice helps to fuse operations, + reduce memory I/O, and consequently enhances performance. + """ + from flashinfer.sampling import chain_speculative_sampling +else: + chain_speculative_sampling = None + class RejectionSampler(SpecDecodeStochasticBaseSampler): """Apply modified rejection sampling as described in "Accelerating Large @@ -16,7 +32,8 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler): def __init__(self, disable_bonus_tokens: bool = True, - strict_mode: bool = False): + strict_mode: bool = False, + use_flashinfer: Optional[bool] = None): """Create a rejection sampler. Args: @@ -26,13 +43,29 @@ def __init__(self, strict_mode: Whether or not to perform shape/device/dtype checks during sampling. This catches correctness issues but adds nontrivial latency. + use_falshinfer: We will use this parameter to determine whether + to use the FlashInfer rejection sampling kernel or not. If it's + None, we will use the default value from the environment variable. + This parameter is only used for testing purposes. """ super().__init__(disable_bonus_tokens=disable_bonus_tokens, strict_mode=strict_mode) + if use_flashinfer is None: + self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and ( + chain_speculative_sampling is not None) + else: + self.use_flashinfer = use_flashinfer + + if self.use_flashinfer: + assert not disable_bonus_tokens, \ + "flashinfer will enable bonus token by default" + logger.info("Use flashinfer for rejection sampling.") + else: + logger.info("Use pytorch for rejection sampling.") def forward( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, @@ -50,9 +83,9 @@ def forward( sequence. Args: - target_probs: The probability distribution over token ids given - context according to the target model. - shape = [batch_size, num_speculative_tokens, vocab_size] + target_with_bonus_probs: The probability distribution + over token ids given context according to the target model. + shape = [batch_size, num_speculative_tokens + 1, vocab_size] bonus_token_ids: The "bonus" token ids that are accepted iff all speculative tokens in a sequence are accepted. @@ -78,23 +111,52 @@ def forward( # Only perform shape/dtype/device checking in strict mode, as it adds # overhead. if self._strict_mode: - self._raise_if_incorrect_input(target_probs, draft_token_ids, - bonus_token_ids, draft_probs) + self._raise_if_incorrect_input(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) - accepted, recovered_token_ids = ( - self._batch_modified_rejection_sampling( - target_probs, - draft_probs, - draft_token_ids, - seeded_seqs, - )) + batch_size, k, _ = draft_probs.shape - output_token_ids = self._create_output( - accepted, - recovered_token_ids, - draft_token_ids, - bonus_token_ids, - ) + # batch_size = 0 when all requests in the batch are + # non_spec requests. In this case, output_token_ids is + # just an empty tensor. + if batch_size == 0: + return torch.empty(0, k + 1, device=draft_probs.device, dtype=int) + + # If use Flashinfer chain_speculative_sampling kernel + # for rejection sampling + if self.use_flashinfer: + batch_size, k, _ = draft_probs.shape + uniform_samples = self._create_uniform_samples( + seeded_seqs, batch_size, k, draft_probs.device) + output_token_ids, accepted_token_num, emitted_token_num \ + = chain_speculative_sampling( + draft_probs, draft_token_ids, uniform_samples, + target_with_bonus_probs) + + # num_emitted_tokens returned by flashinfer + # does not include the bonus token + # Flashinfer stops at the first token that violates + # the condition p >= q and does not include recovery/bonus token. + # Therefore, we need to add batch_size here. + self.num_accepted_tokens += accepted_token_num.sum() + self.num_emitted_tokens += emitted_token_num.sum() + batch_size + self.num_draft_tokens += batch_size * k + else: + accepted, recovered_token_ids = ( + self._batch_modified_rejection_sampling( + target_with_bonus_probs[:, :-1], + draft_probs, + draft_token_ids, + seeded_seqs, + )) + + output_token_ids = self._create_output( + accepted, + recovered_token_ids, + draft_token_ids, + bonus_token_ids, + ) return output_token_ids @@ -135,6 +197,63 @@ def _batch_modified_rejection_sampling( return accepted, recovered_token_ids + def _create_uniform_samples(self, + seeded_seqs: Optional[Dict[int, + torch.Generator]], + batch_size: int, k: int, + device: torch.device) -> torch.Tensor: + """ + Generates a batch of uniform random samples, with optional seeding + for specific sequences. + + This method creates a tensor of shape `(batch_size, k + 1)` filled + with uniform random values in the range [0, 1). If `seeded_seqs` + is provided, the sequences corresponding to specific indices + will be generated using the provided `torch.Generator` for + reproducibility. The other sequences will be generated without + a seed. + + Args: + seeded_seqs : Optional[Dict[int, torch.Generator]] + A dictionary mapping indices in the batch to + `torch.Generator` objects. If `None`, all samples are + generated without a seed. + batch_size : int + The number of sequences to generate. + k : int + The number of random samples per sequence. + device : torch.device + The device on which to allocate the tensor. + + Returns: + uniform_rand : torch.Tensor + A tensor of shape `(batch_size, k + 1)` containing uniform + random values in the range [0, 1). + """ + if not seeded_seqs: + return torch.rand(batch_size, k + 1, device=device) + + uniform_rand = torch.empty(batch_size, k + 1, device=device) + + non_seeded_indices = [] + for idx in range(batch_size): + generator = seeded_seqs.get(idx) + if generator is None: + non_seeded_indices.append(idx) + else: + uniform_rand[idx, :] = torch.rand(1, + k + 1, + dtype=self.probs_dtype, + device=device, + generator=generator) + if non_seeded_indices: + uniform_rand[non_seeded_indices, :] = torch.rand( + len(non_seeded_indices), + k + 1, + dtype=self.probs_dtype, + device=device) + return uniform_rand + def _get_accepted( self, target_probs: torch.Tensor, # [batch_size, k, vocab_size] @@ -175,29 +294,8 @@ def _get_accepted( selected_target_probs = target_probs[batch_indices, probs_indicies, draft_token_ids] - if not seeded_seqs: - uniform_rand = torch.rand_like(selected_target_probs) - else: - uniform_rand = torch.empty_like(selected_target_probs) - - non_seeded_indices = [] - for idx in range(batch_size): - generator = seeded_seqs.get(idx) - if generator is None: - non_seeded_indices.append(idx) - else: - uniform_rand[idx, :] = torch.rand( - 1, - k, - dtype=self.probs_dtype, - device=target_probs.device, - generator=generator) - if non_seeded_indices: - uniform_rand[non_seeded_indices, :] = torch.rand( - len(non_seeded_indices), - k, - dtype=self.probs_dtype, - device=target_probs.device) + uniform_rand = self._create_uniform_samples(seeded_seqs, batch_size, + k - 1, target_probs.device) capped_ratio = torch.minimum( selected_target_probs / selected_draft_probs, diff --git a/vllm/model_executor/layers/resampler.py b/vllm/model_executor/layers/resampler.py new file mode 100644 index 0000000000000..8cd938fc85fb2 --- /dev/null +++ b/vllm/model_executor/layers/resampler.py @@ -0,0 +1,273 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# https://huggingface.co/Qwen/Qwen-7B/blob/main/modeling_qwen.py +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +# +# Copyright 2023 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Shared resampler perceiver network used in multimodal models and +related helpers for sincos positional embeddings. + +Example models: Qwen (Qwen-VL), Minicpmv2.0 +""" +import math +from functools import partial +from typing import Callable, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.init import trunc_normal_ + +from vllm.model_executor.layers.linear import ReplicatedLinear + +DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) + + +def get_abs_pos(abs_pos: torch.Tensor, tgt_size: Union[torch.Tensor, + int]) -> torch.Tensor: + # abs_pos: L, C + # tgt_size: (H, W) + # return: M, C + src_size = int(math.sqrt(abs_pos.size(0))) + dtype = abs_pos.dtype + if isinstance(tgt_size, int): + tgt_size = (tgt_size, tgt_size) + if (src_size == tgt_size[0] and src_size == tgt_size[1]): + return abs_pos + return (F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size[0], tgt_size[1]), + mode="bicubic", + align_corners=False, + ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) + + +# sin/cos positional embedding helpers are adapted from: +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: np.ndarray, + version: Tuple[int, int] = (2, 0)) -> torch.Tensor: + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) / (H, W) + out: (M, D) / (H, W, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + if version == (2, 0): + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + else: + out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product + emb_sin = np.sin(out) # (H, W, D/2) + emb_cos = np.cos(out) # (H, W, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: np.ndarray, + version: Tuple[int, int] = (2, 0)) -> torch.Tensor: + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2) + + if version == (2, 0): + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + else: + emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0), +) -> torch.Tensor: + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + assert isinstance(grid, np.ndarray) and \ + grid.shape == (2, grid_h_size, grid_w_size) + + if version == (2, 0): + grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], + axis=0) + else: + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + return pos_embed + + +class BaseResampler(nn.Module): + """ + A 2D perceiver-resampler network with one cross attention layers by + (grid_size**2) learnable queries and 2d sincos pos_emb. + Outputs: + A tensor with the shape of (grid_size**2, embed_dim) + """ + + def __init__( + self, + num_queries: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + do_post_projection: bool = True, + ) -> None: + super().__init__() + + self.num_queries = num_queries + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) + trunc_normal_(self.query, std=0.02) + if kv_dim is not None and kv_dim != embed_dim: + self.kv_proj = ReplicatedLinear(kv_dim, embed_dim, bias=False) + else: + # Maintain the same return value with ReplicatedLinear.forward + self.kv_proj = lambda *args, **kwargs: ( # type: ignore # noqa + nn.Identity()(*args, **kwargs), + None, + ) + self.attn = nn.MultiheadAttention(embed_dim, num_heads) + self.ln_q = norm_layer(embed_dim) + self.ln_kv = norm_layer(embed_dim) + self.do_post_projection = do_post_projection + self.ln_post = norm_layer(embed_dim) if do_post_projection else None + self.proj = nn.Parameter( + (embed_dim**-0.5) * + torch.randn(embed_dim, embed_dim)) if do_post_projection else None + + def _init_weights(self, m: nn.Module) -> None: + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def _repeat(self, query, N: int): + return query.unsqueeze(1).repeat(1, N, 1) + + +class Resampler2(BaseResampler): + """Resampler-perceiver network to be used for a variety of model types, + e.g., Qwen-vl / Minicpmv 2.0. The main difference is the addition of the + do_post_projection arg, which indicates whether or not there should be + a post layer normalization and projector after the attention. This is + present in minicpmv2.0, but not qwen-vl. + """ + + def __init__( + self, + grid_size: int, + embed_dim: int, + num_heads: int, + kv_dim: Optional[int] = None, + norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, + adaptive: bool = False, + do_post_projection: bool = True, + ) -> None: + super().__init__(grid_size**2, + embed_dim, + num_heads, + kv_dim, + norm_layer, + do_post_projection=do_post_projection) + + self.adaptive = adaptive + pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, + grid_size, + version=(2, 0)) + + self.pos_embed = nn.Parameter( + torch.from_numpy(pos_embed_arr).requires_grad_(False)) + + self.apply(self._init_weights) + + def forward( + self, + x: torch.Tensor, + tgt_sizes: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if tgt_sizes is None: + tgt_sizes = int(math.sqrt(x.size(1))) + if self.adaptive: + pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, + tgt_sizes, + version=(2, 0)) + pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, + dtype=x.dtype) + else: + pos_embed = get_abs_pos(self.pos_embed, + tgt_sizes).to(device=x.device, + dtype=x.dtype) + + x, _ = self.kv_proj(x) + x = self.ln_kv(x).permute(1, 0, 2) + + N = x.shape[1] + q = self.ln_q(self.query) + out = self.attn( + self._repeat(q, N) + self.pos_embed.unsqueeze(1), + x + pos_embed.unsqueeze(1), + x, + attn_mask=attn_mask, + )[0] + x = out.permute(1, 0, 2) + if self.do_post_projection: + x = self.ln_post(x) + x = x @ self.proj + return x diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 0562b71aa7493..d323f6cc432a2 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -28,7 +28,6 @@ import torch.nn as nn from vllm.model_executor.custom_op import CustomOp -from vllm.platforms import current_platform def _rotate_neox(x: torch.Tensor) -> torch.Tensor: @@ -48,21 +47,29 @@ def _apply_rotary_emb( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, + is_neox_style: bool, ) -> torch.Tensor: """ Args: x: [num_tokens, num_heads, head_size] cos: [num_tokens, head_size // 2] sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. """ - orig_dtype = x.dtype - x = x.float() - x1, x2 = torch.chunk(x, 2, dim=-1) - cos = cos.unsqueeze(-2) - sin = sin.unsqueeze(-2) + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] o1 = x1 * cos - x2 * sin o2 = x2 * cos + x1 * sin - return torch.cat((o1, o2), dim=-1).to(orig_dtype) + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) class RotaryEmbedding(CustomOp): @@ -87,10 +94,9 @@ def __init__( cache = self._compute_cos_sin_cache() cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor self.register_buffer("cos_sin_cache", cache, persistent=False) - self.use_native2 = current_platform.is_tpu() and is_neox_style - def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: """Compute the inverse frequency.""" # NOTE(woosuk): To exactly match the HF implementation, we need to @@ -119,59 +125,7 @@ def forward_native( key: torch.Tensor, offsets: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - """A PyTorch-native implementation equivalent to forward(). - - This method mimics the implementation of the custom CUDA kernel - used in `forward_cuda()`. - """ - query = query.view(*query.shape[:-1], -1, self.head_size) - key = key.view(*key.shape[:-1], -1, self.head_size) - - query_rot = query[..., :self.rotary_dim] - key_rot = key[..., :self.rotary_dim] - if self.rotary_dim < self.head_size: - query_pass = query[..., self.rotary_dim:] - key_pass = key[..., self.rotary_dim:] - - self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( - positions.device, dtype=query.dtype) - cos_sin = self.cos_sin_cache[torch.add(positions, offsets) - if offsets is not None else positions] - cos, sin = cos_sin.chunk(2, dim=-1) - if self.is_neox_style: - # NOTE(woosuk): Here we assume that the positions tensor has the - # shape [batch_size, seq_len]. - cos = cos.repeat(1, 1, 2).unsqueeze(-2) - sin = sin.repeat(1, 1, 2).unsqueeze(-2) - else: - cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) - sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) - - rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj - query_rot = query_rot * cos + rotate_fn(query_rot) * sin - key_rot = key_rot * cos + rotate_fn(key_rot) * sin - - if self.rotary_dim < self.head_size: - query = torch.cat((query_rot, query_pass), dim=-1) - key = torch.cat((key_rot, key_pass), dim=-1) - else: - query = query_rot - key = key_rot - query = query.flatten(-2) - key = key.flatten(-2) - return query, key - - def forward_native2( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - """Another PyTorch-native implementation of forward(). - - This method might perform better than `forward_native()` when compiled. - """ + """A PyTorch-native implementation of forward().""" if offsets is not None: positions = positions + offsets positions = positions.flatten() @@ -183,14 +137,14 @@ def forward_native2( query = query.view(num_tokens, -1, self.head_size) query_rot = query[..., :self.rotary_dim] query_pass = query[..., self.rotary_dim:] - query_rot = _apply_rotary_emb(query_rot, cos, sin) + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., :self.rotary_dim] key_pass = key[..., self.rotary_dim:] - key_rot = _apply_rotary_emb(key_rot, cos, sin) + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) return query, key @@ -203,7 +157,7 @@ def forward_cuda( ) -> Tuple[torch.Tensor, torch.Tensor]: from vllm import _custom_ops as ops - self.cos_sin_cache = self.cos_sin_cache.to(positions.device, + self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype) # ops.rotary_embedding()/batched_rotary_embedding() # are in-place operations that update the query and key tensors. @@ -240,17 +194,6 @@ def forward_xpu( self.cos_sin_cache, self.is_neox_style) return query, key - def forward_tpu( - self, - positions: torch.Tensor, - query: torch.Tensor, - key: torch.Tensor, - offsets: Optional[torch.Tensor] = None, - ) -> Tuple[torch.Tensor, torch.Tensor]: - forward_fn = (self.forward_native2 - if self.use_native2 else self.forward_native) - return forward_fn(positions, query, key, offsets) - def extra_repr(self) -> str: s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" s += f", max_position_embeddings={self.max_position_embeddings}" @@ -503,8 +446,8 @@ def __init__( dtype: torch.dtype, short_factor: List[float], long_factor: List[float], - short_mscale: float = 1.0, - long_mscale: float = 1.0, + short_mscale: Optional[float] = None, + long_mscale: Optional[float] = None, ): super().__init__() @@ -523,18 +466,22 @@ def __init__( self.base = base self.short_factor = short_factor self.long_factor = long_factor - self.short_mscale = short_mscale - self.long_mscale = long_mscale - - scale = (self.max_position_embeddings / - self.original_max_position_embeddings) + scale = self.max_position_embeddings / \ + self.original_max_position_embeddings if scale <= 1.0: - self.scaling_factor = 1.0 + scaling_factor = 1.0 else: - self.scaling_factor = math.sqrt( + scaling_factor = math.sqrt( 1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + if short_mscale is None: + short_mscale = scaling_factor + if long_mscale is None: + long_mscale = scaling_factor + + self.short_mscale = short_mscale + self.long_mscale = long_mscale short_cache = self._compute_cos_sin_cache( original_max_position_embeddings, short_factor, short_mscale) @@ -571,8 +518,8 @@ def _compute_cos_sin_cache( inv_freq = self._compute_inv_freq(rescale_factors) t = torch.arange(max_position_embeddings, dtype=torch.float) freqs = torch.einsum("i,j -> ij", t, inv_freq) - cos = freqs.cos() * mscale * self.scaling_factor - sin = freqs.sin() * mscale * self.scaling_factor + cos = freqs.cos() * mscale + sin = freqs.sin() * mscale cache = torch.cat((cos, sin), dim=-1) return cache diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 7344d59e988f0..c00da106734ae 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1,13 +1,16 @@ """A layer that samples the next tokens from the model's outputs.""" import itertools import warnings +from dataclasses import dataclass from importlib.util import find_spec from math import inf -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Optional, Tuple, Union +import msgspec import torch import torch.nn as nn +from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics from vllm.triton_utils import HAS_TRITON if HAS_TRITON: @@ -19,8 +22,7 @@ SequenceGroupToSample) from vllm.sampling_params import SamplingType from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - PromptLogprobs, SampleLogprobs, SamplerOutput, - SequenceOutput) + PromptLogprobs, SampleLogprobs, SequenceOutput) if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"): import flashinfer.sampling @@ -35,6 +37,116 @@ # (num_token_ids, num_parent_ids) per sequence group. SampleResultType = List[Tuple[List[int], List[int]]] +# Types of temporary data structures used for +# computing sample_result +SampleMetadataType = Dict[SamplingType, Tuple[List[int], + List[SequenceGroupToSample]]] +MultinomialSamplesType = Dict[SamplingType, torch.Tensor] +SampleResultsDictType = Dict[int, Tuple[List[int], List[int]]] + + +# Encapsulates temporary data structures for computing +# sample_result. +# +# * For multi-step scheduling: must be returned +# by `Sampler.forward()` and used later to compute the pythonized +# sample_result +# +# * For single-step scheduling: consumed immediately +# inside `Sampler.forward()` to compute pythonized sample_result. +@dataclass +class SampleResultArgsType: + sample_metadata: SampleMetadataType + multinomial_samples: MultinomialSamplesType + sample_results_dict: SampleResultsDictType + sampling_metadata: SamplingMetadata + greedy_samples: Optional[torch.Tensor] + beam_search_logprobs: Optional[torch.Tensor] + + +# Union of non-deferred (single-step scheduling) +# vs deferred (multi-step scheduling) +# sample result types +MaybeDeferredSampleResultType = Union[SampleResultType, SampleResultArgsType] + +# Abbreviation of the _sample() return type +SampleReturnType = Tuple[MaybeDeferredSampleResultType, Optional[torch.Tensor]] + + +class SamplerOutput( + msgspec.Struct, + omit_defaults=True, # type: ignore[call-arg] + array_like=True): # type: ignore[call-arg] + """For each sequence group, we generate a list of SequenceOutput object, + each of which contains one possible candidate for the next token. + + This data structure implements methods, so it can be used like a list, but + also has optional fields for device tensors. + """ + + outputs: List[CompletionSequenceGroupOutput] + + # On-device tensor containing probabilities of each token. + sampled_token_probs: Optional[torch.Tensor] = None + + # On-device tensor containing the logprobs of each token. + logprobs: Optional["torch.Tensor"] = None + + # Holds either (1) the pythonized sampler result (single-step scheduling) + # or (2) what will be arguments for later deferred pythonization of the + # sampler result (muliti-step scheduling) + deferred_sample_results_args: Optional[SampleResultArgsType] = None + + # On-device tensor containing the sampled token ids. + sampled_token_ids: Optional[torch.Tensor] = None + # CPU tensor containing the sampled token ids. Used during multi-step to + # return the sampled token ids from last rank to AsyncLLMEngine to be + # 'broadcasted' to all other PP ranks for next step. + sampled_token_ids_cpu: Optional[torch.Tensor] = None + + # Spec decode metrics populated by workers. + spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None + + # Optional last hidden states from the model. + hidden_states: Optional[torch.Tensor] = None + + # Optional prefill hidden states from the model + # (used for models like EAGLE). + prefill_hidden_states: Optional[torch.Tensor] = None + + # Time taken in the forward pass for this across all workers + model_forward_time: Optional[float] = None + + # Time taken in the model execute function. This will include model forward, + # block/sync across workers, cpu-gpu sync time and sampling time. + model_execute_time: Optional[float] = None + + def __getitem__(self, idx: int): + return self.outputs[idx] + + def __setitem__(self, idx: int, value): + self.outputs[idx] = value + + def __len__(self): + return len(self.outputs) + + def __eq__(self, other: object): + return isinstance(other, + self.__class__) and self.outputs == other.outputs + + def __repr__(self) -> str: + """Show the shape of a tensor instead of its values to reduce noise. + """ + sampled_token_probs_repr = ("None" if self.sampled_token_probs is None + else self.sampled_token_probs.shape) + sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else + self.sampled_token_ids.shape) + return ( + f"SamplerOutput(outputs={self.outputs}, " + f"sampled_token_probs={sampled_token_probs_repr}, " + f"sampled_token_ids={sampled_token_ids_repr}, " + f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") + class Sampler(nn.Module): """Samples the next tokens from the model's outputs. @@ -98,6 +210,19 @@ def forward( sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: """ + Single-step scheduling: + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Pythonize sampling result & logprobs tensor + + Multi-step scheduling: + * Perform GPU-side sampling computation & compute + GPU-side logprobs tensor + * Defer Pythonization of sampling result & logprobs + tensor + * Encapsulate arguments required for deferred Pythonization + in the :class:`SamplerOutput` structure + Args: logits: (num_tokens, vocab_size). sampling_metadata: Metadata for sampling. @@ -150,7 +275,7 @@ def forward( logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) # Sample the next tokens. - sample_results, maybe_sampled_tokens_tensor = _sample( + maybe_deferred_sample_results, maybe_sampled_tokens_tensor = _sample( probs, logprobs, sampling_metadata, @@ -160,20 +285,28 @@ def forward( ) if self.include_gpu_probs_tensor: + # Since we will defer sampler result Pythonization, + # preserve GPU-side tensors in support of later + # deferred pythonization of logprobs assert maybe_sampled_tokens_tensor is not None on_device_tensors = (probs, logprobs, maybe_sampled_tokens_tensor) else: + # Since Pythonization has already happened, don't preserve + # GPU-side tensors. on_device_tensors = None # Get the logprobs query results. prompt_logprobs = None sample_logprobs = None if not sampling_metadata.skip_sampler_cpu_output: - prompt_logprobs, sample_logprobs = _get_logprobs( - logprobs, sampling_metadata, sample_results) + # Pythonize logprobs now (GPU -> CPU); do not defer. + assert not isinstance(maybe_deferred_sample_results, + SampleResultArgsType) + prompt_logprobs, sample_logprobs = get_logprobs( + logprobs, sampling_metadata, maybe_deferred_sample_results) return _build_sampler_output( - sample_results, + maybe_deferred_sample_results, sampling_metadata, prompt_logprobs, sample_logprobs, @@ -543,6 +676,60 @@ def _top_k_top_p_multinomial_with_flashinfer( return batch_next_token_ids.view(-1, num_samples) +def get_pythonized_sample_results( + sample_result_args: SampleResultArgsType) -> SampleResultType: + '''This function consumes GPU-side sampler results and computes + Pythonized CPU-side sampler results (GPU -> CPU sync.) + + Single-step scheduling: this function is invoked at sampling-time + for immediate Pythonization. + + Multi-step scheduling: Pythonization is deferred until after multiple + GPU-side steps have been completed. + + Args: + sample_result_args: GPU-side inputs to the Pythonization process + + Returns: + Pythonized sampler results + ''' + + ( + sample_metadata, + sampling_metadata, + greedy_samples, + multinomial_samples, + beam_search_logprobs, + sample_results_dict, + ) = ( + sample_result_args.sample_metadata, + sample_result_args.sampling_metadata, + sample_result_args.greedy_samples, + sample_result_args.multinomial_samples, + sample_result_args.beam_search_logprobs, + sample_result_args.sample_results_dict, + ) + + for sampling_type in SamplingType: + if sampling_type not in sample_metadata: + continue + (seq_group_id, seq_groups) = sample_metadata[sampling_type] + if sampling_type == SamplingType.GREEDY: + sample_results = _greedy_sample(seq_groups, greedy_samples) + elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED): + sample_results = _random_sample(seq_groups, + multinomial_samples[sampling_type]) + elif sampling_type == SamplingType.BEAM: + sample_results = _beam_search_sample(seq_groups, + beam_search_logprobs) + sample_results_dict.update(zip(seq_group_id, sample_results)) + + return [ + sample_results_dict.get(i, ([], [])) + for i in range(len(sampling_metadata.seq_groups)) + ] + + def _sample_with_torch( probs: torch.Tensor, logprobs: torch.Tensor, @@ -550,7 +737,19 @@ def _sample_with_torch( sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool, -) -> Tuple[SampleResultType, Optional[torch.Tensor]]: +) -> SampleReturnType: + '''Torch-oriented _sample() implementation. + + Single-step scheduling: + * Perform GPU-side sampling computation + * Immediately Pythonize sampling result + + Multi-step scheduling: + * Perform GPU-side sampling computation + * Defer Pythonization & preserve GPU-side + tensors required for Pythonization + ''' + categorized_seq_group_ids: Dict[SamplingType, List[int]] = {t: [] for t in SamplingType} @@ -560,10 +759,11 @@ def _sample_with_torch( sampling_type = sampling_params.sampling_type categorized_seq_group_ids[sampling_type].append(i) - sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {} - sample_metadata: Dict[SamplingType, - Tuple[List[int], List[SequenceGroupToSample]]] = {} - multinomial_samples: Dict[SamplingType, torch.Tensor] = {} + sample_results_dict: SampleResultsDictType = {} + sample_metadata: SampleMetadataType = {} + multinomial_samples: MultinomialSamplesType = {} + greedy_samples: Optional[torch.Tensor] = None + beam_search_logprobs: Optional[torch.Tensor] = None # Create output tensor for sampled token ids. if include_gpu_probs_tensor: @@ -638,32 +838,29 @@ def _sample_with_torch( else: raise ValueError(f"Unsupported sampling type: {sampling_type}") - # GPU<->CPU sync happens in the loop below. - # This also converts the sample output to Python objects. + # Encapsulate arguments for computing Pythonized sampler + # results, whether deferred or otherwise. + maybe_deferred_args = SampleResultArgsType( + sampling_metadata=sampling_metadata, + sample_metadata=sample_metadata, + multinomial_samples=multinomial_samples, + greedy_samples=greedy_samples, + beam_search_logprobs=beam_search_logprobs, + sample_results_dict=sample_results_dict) + if not sampling_metadata.skip_sampler_cpu_output: - for sampling_type in SamplingType: - if sampling_type not in sample_metadata: - continue - (seq_group_id, seq_groups) = sample_metadata[sampling_type] - if sampling_type == SamplingType.GREEDY: - sample_results = _greedy_sample(seq_groups, greedy_samples) - elif sampling_type in (SamplingType.RANDOM, - SamplingType.RANDOM_SEED): - sample_results = _random_sample( - seq_groups, multinomial_samples[sampling_type]) - elif sampling_type == SamplingType.BEAM: - sample_results = _beam_search_sample(seq_groups, - beam_search_logprobs) - sample_results_dict.update(zip(seq_group_id, sample_results)) - - sample_results = [ - sample_results_dict.get(i, ([], [])) - for i in range(len(sampling_metadata.seq_groups)) - ] + # GPU<->CPU sync happens here. + # This also converts the sampler output to a Python object. + # Return Pythonized sampler result & sampled token ids + return get_pythonized_sample_results( + maybe_deferred_args), sampled_token_ids_tensor else: - sample_results = [] - - return sample_results, sampled_token_ids_tensor + # Defer sampler result Pythonization; return deferred + # Pythonization args & sampled token ids + return ( + maybe_deferred_args, + sampled_token_ids_tensor, + ) def _sample_with_triton_kernel( @@ -755,7 +952,7 @@ def _sample( sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool, -) -> Tuple[SampleResultType, Optional[torch.Tensor]]: +) -> SampleReturnType: """ Args: probs: (num_query_tokens_in_batch, num_vocab) @@ -803,7 +1000,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: return result.sum(1).add_(1) -def _get_logprobs( +def get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sample_results: SampleResultType, @@ -1126,7 +1323,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, def _build_sampler_output( - sample_results: SampleResultType, + maybe_deferred_sample_results: MaybeDeferredSampleResultType, sampling_metadata: SamplingMetadata, prompt_logprobs: Optional[List[Optional[PromptLogprobs]]], sample_logprobs: Optional[List[SampleLogprobs]], @@ -1143,14 +1340,21 @@ def _build_sampler_output( speculative decoding rejection sampling. """ sampler_output: List[CompletionSequenceGroupOutput] = [] - if not skip_sampler_cpu_output: + + if skip_sampler_cpu_output: + assert isinstance(maybe_deferred_sample_results, SampleResultArgsType) + deferred_sample_results_args = maybe_deferred_sample_results + else: assert prompt_logprobs is not None assert sample_logprobs is not None + assert not isinstance(maybe_deferred_sample_results, + SampleResultArgsType) + deferred_sample_results_args = None for (seq_group, sample_result, group_prompt_logprobs, group_sample_logprobs) in zip(sampling_metadata.seq_groups, - sample_results, prompt_logprobs, - sample_logprobs): + maybe_deferred_sample_results, + prompt_logprobs, sample_logprobs): seq_ids = seq_group.seq_ids next_token_ids, parent_ids = sample_result seq_outputs: List[SequenceOutput] = [] @@ -1176,7 +1380,7 @@ def _build_sampler_output( sampled_token_probs=sampled_token_probs, sampled_token_ids=sampled_token_ids, logprobs=logprobs_tensor, - ) + deferred_sample_results_args=deferred_sample_results_args) def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: diff --git a/vllm/model_executor/layers/spec_decode_base_sampler.py b/vllm/model_executor/layers/spec_decode_base_sampler.py index 467c43c41550e..f9532dffa92c0 100644 --- a/vllm/model_executor/layers/spec_decode_base_sampler.py +++ b/vllm/model_executor/layers/spec_decode_base_sampler.py @@ -130,29 +130,35 @@ def _create_output( def _raise_if_incorrect_input( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, draft_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: - self._raise_if_incorrect_shape(target_probs, draft_token_ids, - bonus_token_ids, draft_probs) - self._raise_if_incorrect_dtype(target_probs, draft_token_ids, - bonus_token_ids, draft_probs) - self._raise_if_inconsistent_device(target_probs, draft_token_ids, - bonus_token_ids, draft_probs) - self._raise_if_out_of_bounds_vocab(target_probs.shape[-1], + self._raise_if_incorrect_shape(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_incorrect_dtype(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_inconsistent_device(target_with_bonus_probs, + draft_token_ids, bonus_token_ids, + draft_probs) + self._raise_if_out_of_bounds_vocab(target_with_bonus_probs.shape[-1], draft_token_ids, bonus_token_ids) def _raise_if_incorrect_shape( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, draft_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: (target_batch_size, num_target_probs, - target_vocab_size) = target_probs.shape + target_vocab_size) = target_with_bonus_probs.shape + + # Does not count the extra token + num_target_probs -= 1 # validate the shape of draft token ids. draft_token_ids_batch_size, num_draft_token_ids = draft_token_ids.shape @@ -175,12 +181,12 @@ def _raise_if_incorrect_shape( def _raise_if_incorrect_dtype( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, draft_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: - assert target_probs.dtype == self.probs_dtype + assert target_with_bonus_probs.dtype == self.probs_dtype assert draft_token_ids.dtype == self.token_id_dtype assert bonus_token_ids.dtype == self.token_id_dtype if draft_probs is not None: @@ -188,15 +194,16 @@ def _raise_if_incorrect_dtype( def _raise_if_inconsistent_device( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, draft_token_ids: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: Optional[torch.Tensor] = None, ) -> None: devices = [ - t.device for t in - [target_probs, bonus_token_ids, draft_probs, draft_token_ids] - if t is not None + t.device for t in [ + target_with_bonus_probs, bonus_token_ids, draft_probs, + draft_token_ids + ] if t is not None ] assert all([devices[0] == device for device in devices]) @@ -220,7 +227,7 @@ class SpecDecodeDeterministicBaseSampler(SpecDecodeBaseSampler): @abstractmethod def forward( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, @@ -236,7 +243,7 @@ class SpecDecodeStochasticBaseSampler(SpecDecodeBaseSampler): @abstractmethod def forward( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, diff --git a/vllm/model_executor/layers/typical_acceptance_sampler.py b/vllm/model_executor/layers/typical_acceptance_sampler.py index a87ea0eee57de..7428d33ea720d 100644 --- a/vllm/model_executor/layers/typical_acceptance_sampler.py +++ b/vllm/model_executor/layers/typical_acceptance_sampler.py @@ -41,7 +41,7 @@ def __init__( def forward( self, - target_probs: torch.Tensor, + target_with_bonus_probs: torch.Tensor, bonus_token_ids: torch.Tensor, draft_probs: torch.Tensor, draft_token_ids: torch.Tensor, @@ -80,8 +80,9 @@ def forward( # Only perform shape/dtype/device checking in strict mode, as it adds # overhead. if self._strict_mode: - self._raise_if_incorrect_input(target_probs, draft_token_ids, - bonus_token_ids) + self._raise_if_incorrect_input(target_with_bonus_probs, + draft_token_ids, bonus_token_ids) + target_probs = target_with_bonus_probs[:, :-1] accepted = self._evaluate_accepted_tokens(target_probs, draft_token_ids) recovered_token_ids = self._replacement_token_ids(target_probs) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 3ba15573c217b..ef6d401be2070 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -10,6 +10,7 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding) +from vllm.model_executor.parameter import BasevLLMParameter from vllm.model_executor.utils import set_weight_attrs DEFAULT_VOCAB_PADDING_SIZE = 64 @@ -351,7 +352,10 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param.weight_type = loaded_weight.item() return elif isinstance(param, UninitializedParameter): - param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype) + shape = list(loaded_weight.shape) + if output_dim is not None: + shape[output_dim] = shape[output_dim] // self.tp_size + param.materialize(tuple(shape), dtype=loaded_weight.dtype) # If parameter does not have output dim, then it should # be copied onto all gpus (e.g. g_idx for act_order gptq). @@ -367,10 +371,12 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): # If param packed on the same dim we are sharding on, then # need to adjust offsets of loaded weight by pack_factor. if packed_dim is not None and packed_dim == output_dim: + packed_factor = param.packed_factor if isinstance( + param, BasevLLMParameter) else param.pack_factor assert loaded_weight.shape[output_dim] == (self.org_vocab_size // - param.pack_factor) - start_idx = start_idx // param.pack_factor - shard_size = shard_size // param.pack_factor + param.packed_factor) + start_idx = start_idx // packed_factor + shard_size = shard_size // packed_factor else: assert loaded_weight.shape[output_dim] == self.org_vocab_size diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 2f6cdbc6ce3e9..f59eb805ea907 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -17,6 +17,7 @@ from huggingface_hub import HfApi, hf_hub_download from torch import nn from transformers import AutoModelForCausalLM, PretrainedConfig +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoadFormat, LoRAConfig, ModelConfig, MultiModalConfig, @@ -184,6 +185,11 @@ class BaseModelLoader(ABC): def __init__(self, load_config: LoadConfig): self.load_config = load_config + @abstractmethod + def download_model(self, model_config: ModelConfig) -> None: + """Download a model so that it can be immediately loaded.""" + raise NotImplementedError + @abstractmethod def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, @@ -192,7 +198,7 @@ def load_model(self, *, model_config: ModelConfig, scheduler_config: SchedulerConfig, cache_config: CacheConfig) -> nn.Module: """Load a model with the given configurations.""" - ... + raise NotImplementedError class DefaultModelLoader(BaseModelLoader): @@ -241,12 +247,17 @@ def _prepare_weights(self, model_name_or_path: str, is_local = os.path.isdir(model_name_or_path) load_format = self.load_config.load_format use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME # Some quantized models use .pt files for storing the weights. if load_format == LoadFormat.AUTO: allow_patterns = ["*.safetensors", "*.bin"] elif load_format == LoadFormat.SAFETENSORS: use_safetensors = True allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" elif load_format == LoadFormat.PT: allow_patterns = ["*.pt"] elif load_format == LoadFormat.NPCACHE: @@ -284,10 +295,10 @@ def _prepare_weights(self, model_name_or_path: str, # any files not found in the index. if not is_local: download_safetensors_index_file_from_hf( - model_name_or_path, self.load_config.download_dir, - revision) + model_name_or_path, index_file, + self.load_config.download_dir, revision) hf_weights_files = filter_duplicate_safetensors_files( - hf_weights_files, hf_folder) + hf_weights_files, hf_folder, index_file) else: hf_weights_files = filter_files_not_needed_for_inference( hf_weights_files) @@ -329,6 +340,11 @@ def _xla_weights_iterator(iterator: Generator): weights_iterator = _xla_weights_iterator(weights_iterator) return weights_iterator + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, + model_config.revision, + fall_back_to_pt=True) + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], @@ -371,6 +387,9 @@ def __init__(self, load_config: LoadConfig): raise ValueError(f"Model loader extra config is not supported for " f"load format {load_config.load_format}") + def download_model(self, model_config: ModelConfig) -> None: + pass # Nothing to download + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], @@ -461,6 +480,12 @@ def _load_model_serialized( model = load_with_tensorizer(tensorizer_config, **extra_kwargs) return model.eval() + def download_model(self, model_config: ModelConfig) -> None: + self.tensorizer_config.verify_with_model_config(model_config) + + with self.tensorizer_config.open_stream(): + pass + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], @@ -562,6 +587,9 @@ def _prepare_weights(self, model_name_or_path: str, ignore_patterns=self.load_config.ignore_patterns, ) + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], @@ -771,7 +799,11 @@ def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): return pt_weights_iterator(hf_weights_files) def _get_quantized_weights_iterator( - self, model_name_or_path: str, revision: Optional[str], pre_quant: bool + self, + model_name_or_path: str, + revision: Optional[str], + pre_quant: bool, + load_8bit: bool, ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, Any]]: """Get an iterator to the model weights with bitsandbytes quantization, @@ -780,11 +812,9 @@ def _get_quantized_weights_iterator( # only load the bitsandbytes module when needed try: import bitsandbytes - from bitsandbytes.functional import QuantState if bitsandbytes.__version__ < "0.42.0": raise ImportError("bitsandbytes version is wrong. Please " "install bitsandbytes>=0.42.0.") - from bitsandbytes.functional import quantize_4bit except ImportError as err: raise ImportError("Please install bitsandbytes>=0.42.0 via " "`pip install bitsandbytes>=0.42.0` to use " @@ -793,80 +823,111 @@ def _get_quantized_weights_iterator( hf_weights_files, use_safetensors = self._prepare_weights( model_name_or_path, revision) - quant_state_dict = {} - - def quantized_checkpoint() -> Generator: - # First iterate over all quant state weights - weight_iterator = self._hf_weight_iter(hf_weights_files, - use_safetensors) - temp_state_dict = {} - for weight_name, weight_tensor in weight_iterator: - if weight_name.endswith(".weight"): - continue - # TODO: only nf4 quantization is supported for now - if weight_name.endswith(".quant_state.bitsandbytes__fp4"): - raise NotImplementedError( - "Only bitsandbytes_nf4 quantization" - f"is supported for now. {weight_name} is fp4 quantized" - ) - temp_state_dict[weight_name] = weight_tensor + quant_state_dict: Dict[str, Any] = {} - # Closure to parse quant_state for each prequant weight - def _parse_quant_state(param_name: str, - temp_state_dict: Dict) -> QuantState: - quant_state = {} - for k in temp_state_dict: - if param_name + "." in k: - quant_state[k] = temp_state_dict[k] - # bitsandbytes library requires - # weight.quant_state.bitsandbytes__nf4 in CPU - quant_state[param_name + - ".quant_state.bitsandbytes__nf4"] = quant_state[ - param_name + - ".quant_state.bitsandbytes__nf4"].cpu().data - return QuantState.from_dict(quant_state, device="cuda") - - # Second iterate over all prequant and normal weights - # pre quantized weights would have a quant_state - for weight_name, weight_tensor in self._hf_weight_iter( - hf_weights_files, use_safetensors): - # Filter out all weights whose suffix is not ".weight" - if not weight_name.endswith(".weight"): - continue - if weight_name + ".quant_state.bitsandbytes__nf4" \ - in temp_state_dict: - quant_state = _parse_quant_state(weight_name, - temp_state_dict) - weight_name = weight_name.replace(".weight", ".qweight") - quant_state_dict[weight_name] = quant_state - yield weight_name.replace(".weight", - ".qweight"), weight_tensor - else: - yield weight_name, weight_tensor - - def generator() -> Generator: - for weight_name, weight_tensor in self._hf_weight_iter( - hf_weights_files, use_safetensors): - if any(target_module in weight_name - for target_module in self.target_modules): - weight_name = weight_name.replace(".weight", ".qweight") - # bitsandbytes requires data in GPU - loaded_weight = weight_tensor.cuda().data - with set_default_torch_dtype(torch.float32): - processed_weight, quant_state = quantize_4bit( - loaded_weight, - compress_statistics=True, - quant_type="nf4") - - quant_state_dict[weight_name] = quant_state - else: - processed_weight = weight_tensor + if pre_quant: + if load_8bit: + return self._quantized_8bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict + else: + return self._quantized_4bit_generator( + hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict - yield weight_name, processed_weight + return self._unquantized_generator(hf_weights_files, use_safetensors, + quant_state_dict), quant_state_dict - if pre_quant: - return quantized_checkpoint(), quant_state_dict - return generator(), quant_state_dict + def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + if not weight_name.lower().endswith(".scb"): + continue + + weight_key = weight_name.lower().replace(".scb", ".qweight") + quant_state_dict[weight_key] = weight_tensor + + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + + if not weight_name.endswith(".weight"): + continue + + qweight_name = weight_name.replace(".weight", ".qweight") + if qweight_name in quant_state_dict: + set_weight_attrs(weight_tensor, {"load_in_8bit": True}) + yield qweight_name, weight_tensor + else: + yield weight_name, weight_tensor + + def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import QuantState + + # First iterate over all quant state weights + weight_iterator = self._hf_weight_iter(hf_weights_files, + use_safetensors) + temp_state_dict = {} + for weight_name, weight_tensor in weight_iterator: + if weight_name.endswith(".weight"): + continue + # bitsandbytes library requires + # weight.quant_state.bitsandbytes__* in CPU + if "quant_state.bitsandbytes" in weight_name: + temp_state_dict[weight_name] = weight_tensor.cpu().data + else: + temp_state_dict[weight_name] = weight_tensor + + # Closure to parse quant_state for each prequant weight + def _parse_quant_state(param_name: str, + temp_state_dict: Dict) -> QuantState: + quant_state = {} + for k in temp_state_dict: + if param_name + "." in k: + quant_state[k] = temp_state_dict[k] + + return QuantState.from_dict(quant_state, device="cuda") + + # Second iterate over all prequant and normal weights + # pre quantized weights would have a quant_state + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + # Filter out all weights whose suffix is not ".weight" + if not weight_name.endswith(".weight"): + continue + if (f"{weight_name}.quant_state.bitsandbytes__nf4" \ + in temp_state_dict) or \ + (f"{weight_name}.quant_state.bitsandbytes__fp4" \ + in temp_state_dict): + quant_state = _parse_quant_state(weight_name, temp_state_dict) + weight_name = weight_name.replace(".weight", ".qweight") + quant_state_dict[weight_name] = quant_state + yield weight_name.replace(".weight", ".qweight"), weight_tensor + else: + yield weight_name, weight_tensor + + def _unquantized_generator(self, hf_weights_files, use_safetensors, + quant_state_dict) -> Generator: + from bitsandbytes.functional import quantize_4bit + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + if any(target_module in weight_name + for target_module in self.target_modules): + weight_name = weight_name.replace(".weight", ".qweight") + # bitsandbytes requires data in GPU + loaded_weight = weight_tensor.cuda().data + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, + compress_statistics=True, + quant_type="nf4") + + quant_state_dict[weight_name] = quant_state + else: + processed_weight = weight_tensor + + yield weight_name, processed_weight def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None: @@ -883,16 +944,26 @@ def _load_weights(self, model_config: ModelConfig, logger.info("Loading weights with BitsAndBytes quantization. " " May take a while ...") - is_quantized_checkpoint = False quant_config = getattr(model_config.hf_config, "quantization_config", None) - if quant_config is not None and quant_config.get( - 'quant_method') == "bitsandbytes": - is_quantized_checkpoint = True + + pre_quant = False + if quant_config is not None: + quant_method = quant_config.get('quant_method') + if quant_method == "bitsandbytes": + pre_quant = True + else: + raise ValueError( + f"BitsAndBytes loader does not support {quant_method} " + "quantization") + + load_8bit = False + if pre_quant: + load_8bit = quant_config.get('load_in_8bit', False) qweight_iterator, quant_state_dict = \ self._get_quantized_weights_iterator( - model_config.model, model_config.revision, is_quantized_checkpoint) + model_config.model, model_config.revision, pre_quant, load_8bit) model.load_weights(qweight_iterator) @@ -942,6 +1013,13 @@ def _load_weights(self, model_config: ModelConfig, offsets = np.concatenate(([0], np.cumsum(num_elements))) set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + if load_8bit: + set_weight_attrs( + param, {"matmul_state": [None] * len(quant_states)}) + + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model, model_config.revision) + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], @@ -1017,6 +1095,9 @@ def _get_weights_iterator( return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map) + def download_model(self, model_config: ModelConfig) -> None: + self._prepare_weights(model_config.model) + def load_model(self, *, model_config: ModelConfig, device_config: DeviceConfig, lora_config: Optional[LoRAConfig], diff --git a/vllm/model_executor/model_loader/neuron.py b/vllm/model_executor/model_loader/neuron.py index 07e23aca6cc5f..594ae442ef328 100644 --- a/vllm/model_executor/model_loader/neuron.py +++ b/vllm/model_executor/model_loader/neuron.py @@ -1,7 +1,7 @@ """Utilities for selecting and loading neuron models.""" import importlib import os -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch import torch.nn as nn @@ -10,9 +10,9 @@ from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.quantization import get_quantization_config +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput TORCH_DTYPE_TO_NEURON_AMP = { "auto": "f32", @@ -82,8 +82,7 @@ def load_weights(self, model_name_or_path: str, **kwargs): neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name) split_model_dir = f"{model_name_or_path}-split" - if os.path.isdir(os.path.join(model_name_or_path, - "pytorch_model.bin")): + if _is_pretrained_neuron_checkpoint(model_name_or_path): split_model_dir = model_name_or_path elif not os.path.exists(f"{model_name_or_path}-split"): hf_model_cls = getattr(transformers, hf_model_cls_name) @@ -98,6 +97,23 @@ def load_weights(self, model_name_or_path: str, **kwargs): self.model.to_neuron() +def _is_pretrained_neuron_checkpoint(model_name_or_path: str) -> bool: + # Checking if the neuron checkpoint is saved in the old format. + if os.path.isdir(os.path.join(model_name_or_path, "pytorch_model.bin")): + return True + # Checking if the neuron checkpoint is saved in the new format. + pretrained_split_files = ["config.json", "generation_config.json"] + pretrained_split_format = ".safetensors" + for file in pretrained_split_files: + file_path = os.path.join(model_name_or_path, file) + if not os.path.isfile(file_path): + return False + for file in os.listdir(model_name_or_path): + if file.endswith(pretrained_split_format): + return True + return False + + def _get_model_architecture(config: PretrainedConfig) -> str: architectures = getattr(config, "architectures", []) for arch in architectures: @@ -109,28 +125,75 @@ def _get_model_architecture(config: PretrainedConfig) -> str: f"{list(_NEURON_SUPPORTED_MODELS.keys())}") +def _get_buckets(env: str, default_value: List[int]) -> List[int]: + env_value = os.getenv(env) + if env_value is None: + return default_value + buckets_remove_empty = filter( + lambda x: x is not None and len(x.strip()) > 0, env_value.split(",")) + buckets_int = map(int, buckets_remove_empty) + buckets_list = list(buckets_int) + return buckets_list + + +def _get_default_neuron_config(model_config: ModelConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig): + from transformers_neuronx.config import ContinuousBatchingConfig + from transformers_neuronx.constants import LAYOUT_BSH + + continuous_batching_config = ContinuousBatchingConfig( + batch_size_for_shared_caches=scheduler_config.max_num_seqs) + quant_config = dict( + dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + quantize_method="vector_dynamic") + neuron_quantization_config_builder = lambda quant: get_quantization_config( + quant).from_config(quant_config).get_quant_method(None, "") + # TODO: Add Paged attention config to the default neuron arguments. + default_neuron_args = dict( + collectives_layout=LAYOUT_BSH, + attention_layout=LAYOUT_BSH, + fuse_qkv=True, + quant=neuron_quantization_config_builder(model_config.quantization) + if model_config.quantization else None, + continuous_batching=continuous_batching_config, + weight_tiling=bool(model_config.quantization)) + return default_neuron_args + + +def _get_neuron_config_after_override(default_neuron_config, + overridden_neuron_config): + from transformers_neuronx.config import NeuronConfig + overridden_neuron_config = overridden_neuron_config or {} + default_neuron_config.update(overridden_neuron_config) + return NeuronConfig(**default_neuron_config) + + def get_neuron_model(model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig) -> nn.Module: - from transformers_neuronx.config import (ContinuousBatchingConfig, - NeuronConfig) # Create a model instance. model = NeuronCasualLM(model_config.hf_config) - continuous_batching_config = ContinuousBatchingConfig( - batch_size_for_shared_caches=scheduler_config.max_num_seqs) - neuron_config = NeuronConfig( - continuous_batching=continuous_batching_config) + default_neuron_config_args = _get_default_neuron_config( + model_config, parallel_config, scheduler_config) + + neuron_config = _get_neuron_config_after_override( + default_neuron_config_args, model_config.override_neuron_config) + + context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS", + [scheduler_config.max_model_len]) + n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS", + [scheduler_config.max_model_len]) # Load the weights from the cached or downloaded files. - model.load_weights( - model_config.model, - tp_degree=parallel_config.tensor_parallel_size, - amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], - neuron_config=neuron_config, - context_length_estimate=[scheduler_config.max_model_len], - n_positions=[scheduler_config.max_model_len], - batch_size=scheduler_config.max_num_seqs) + model.load_weights(model_config.model, + tp_degree=parallel_config.tensor_parallel_size, + amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype], + neuron_config=neuron_config, + context_length_estimate=context_length_estimates, + n_positions=n_positions, + batch_size=scheduler_config.max_num_seqs) return model.eval() diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py index 5c522a61732a4..3c1f6fa769894 100644 --- a/vllm/model_executor/model_loader/openvino.py +++ b/vllm/model_executor/model_loader/openvino.py @@ -15,9 +15,8 @@ from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import (LogitsProcessor, _prune_hidden_states) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput logger = init_logger(__name__) diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index b009ad8c882d4..3aac5cd2b43a5 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -99,6 +99,13 @@ def verify_with_model_config(self, model_config: "ModelConfig") -> None: "Loading a model using Tensorizer with quantization on vLLM" " is unstable and may lead to errors.") + def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None): + if tensorizer_args is None: + tensorizer_args = self._construct_tensorizer_args() + + return open_stream(self.tensorizer_uri, + **tensorizer_args.stream_params) + def load_with_tensorizer(tensorizer_config: TensorizerConfig, **extra_kwargs) -> nn.Module: diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 4bb943ab3afe4..0052489d99dc4 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -24,10 +24,18 @@ def get_model_architecture( # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. mixtral_supported = ["fp8", "compressed-tensors"] + # for gptq_marlin, only run fused MoE for int4 + if model_config.quantization == "gptq_marlin": + hf_quant_config = getattr(model_config.hf_config, + "quantization_config", None) + if hf_quant_config and hf_quant_config.get("bits") == 4: + mixtral_supported.append("gptq_marlin") + if (model_config.quantization is not None and model_config.quantization not in mixtral_supported and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] + return ModelRegistry.resolve_model_cls(architectures) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 0666457756b02..075451292a8e4 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -16,7 +16,6 @@ from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm -from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import LoadConfig, ModelConfig from vllm.distributed import get_tensor_model_parallel_rank @@ -251,6 +250,7 @@ def download_weights_from_hf( def download_safetensors_index_file_from_hf( model_name_or_path: str, + index_file: str, cache_dir: Optional[str], revision: Optional[str] = None, ) -> None: @@ -269,36 +269,37 @@ def download_safetensors_index_file_from_hf( # Download the safetensors index file. hf_hub_download( repo_id=model_name_or_path, - filename=SAFE_WEIGHTS_INDEX_NAME, + filename=index_file, cache_dir=cache_dir, revision=revision, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, ) # If file not found on remote or locally, we should not fail since - # only some models will have SAFE_WEIGHTS_INDEX_NAME. + # only some models will have index_file. except huggingface_hub.utils.EntryNotFoundError: - logger.info("No %s found in remote.", SAFE_WEIGHTS_INDEX_NAME) + logger.info("No %s found in remote.", index_file) except huggingface_hub.utils.LocalEntryNotFoundError: - logger.info("No %s found in local cache.", SAFE_WEIGHTS_INDEX_NAME) + logger.info("No %s found in local cache.", index_file) # For models like Mistral-7B-v0.3, there are both sharded # safetensors files and a consolidated safetensors file. # Passing both of these to the weight loader functionality breaks. -# So, we use the SAFE_WEIGHTS_INDEX_NAME to +# So, we use the index_file to # look up which safetensors files should be used. def filter_duplicate_safetensors_files(hf_weights_files: List[str], - hf_folder: str) -> List[str]: + hf_folder: str, + index_file: str) -> List[str]: # model.safetensors.index.json is a mapping from keys in the # torch state_dict to safetensors file holding that weight. - index_file_name = os.path.join(hf_folder, SAFE_WEIGHTS_INDEX_NAME) + index_file_name = os.path.join(hf_folder, index_file) if not os.path.isfile(index_file_name): return hf_weights_files # Iterate through the weight_map (weight_name: safetensors files) # to identify weights that we should use. - with open(index_file_name) as index_file: - weight_map = json.load(index_file)["weight_map"] + with open(index_file_name, "r") as f: + weight_map = json.load(f)["weight_map"] weight_files_in_index = set() for weight_name in weight_map: weight_files_in_index.add( diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 8591c276b0013..4db847029566f 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -22,6 +22,7 @@ "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), + "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), "FalconForCausalLM": ("falcon", "FalconForCausalLM"), "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), @@ -49,7 +50,7 @@ "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), "PhiForCausalLM": ("phi", "PhiForCausalLM"), "Phi3ForCausalLM": ("llama", "LlamaForCausalLM"), - "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), + "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), "RWForCausalLM": ("falcon", "FalconForCausalLM"), @@ -63,6 +64,7 @@ "EAGLEModel": ("eagle", "EAGLE"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "GraniteForCausalLM": ("granite", "GraniteForCausalLM") } _EMBEDDING_MODELS = { @@ -85,6 +87,7 @@ "PaliGemmaForConditionalGeneration"), "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), "UltravoxModel": ("ultravox", "UltravoxModel"), + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), } _CONDITIONAL_GENERATION_MODELS = { "BartModel": ("bart", "BartForConditionalGeneration"), diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 28f69cfbc46bd..efa044d0b5e92 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -23,13 +23,13 @@ from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig, DeepSpeedFPParameter) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.arctic import ArcticConfig logger = init_logger(__name__) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 73711d8eb5185..bdd76b11384c2 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -38,12 +38,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index f78400b0df7b3..9b4c4be7fcb09 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors logger = logging.get_logger(__name__) diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py index 830680fd990bf..583d5d217903b 100644 --- a/vllm/model_executor/models/blip.py +++ b/vllm/model_executor/models/blip.py @@ -10,15 +10,23 @@ from transformers.models.blip.modeling_blip import BlipAttention from vllm.config import ModelConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +try: + from xformers import ops as xops + USE_XFORMERS_OPS = True +except ImportError: + USE_XFORMERS_OPS = False + def get_blip_patch_grid_length(*, image_size: int, patch_size: int) -> int: assert image_size % patch_size == 0 @@ -154,6 +162,77 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return embeddings +class BlipParallelAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: BlipVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) + self.projection = RowParallelLinear( + self.embed_dim, + self.embed_dim, + quant_config=quant_config, + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + + qkv_states, _ = self.qkv(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + query_states = query_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + out = out.view(bsz, tgt_len, -1) + attn_output, _ = self.projection(out) + + return attn_output, None + + class BlipMLP(nn.Module): def __init__(self, @@ -188,7 +267,16 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.self_attn = BlipAttention(config) + # fallback to sdpa attention if tp unavailable + num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + if USE_XFORMERS_OPS and num_heads % tp_size == 0: + self.self_attn = BlipParallelAttention(config, + quant_config=quant_config) + else: + # Blip doesn't have SDPA attention implemented in transformers + # use eager attention instead for cpu backend + self.self_attn = BlipAttention(config) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = BlipMLP(config, quant_config=quant_config) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 7c9123079c44f..39f2b2d853a6b 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -13,13 +13,13 @@ from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.opt import OPTModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SamplerOutput, SequenceData) + SequenceData) from .blip import (BlipVisionModel, dummy_image_for_blip, get_max_blip_image_tokens) @@ -40,13 +40,13 @@ class Blip2ImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor - """Shape: (batch_size, num_channels, height, width)""" + """Shape: `(batch_size * num_images, num_channels, height, width)`""" class Blip2ImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: torch.Tensor - """Shape: `(batch_size, image_feature_size, hidden_size)` + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ @@ -714,8 +714,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): use_default_weight_loading = False if "vision" in name: if self.vision_model is not None: - # We only do sharding for language model and - # not vision model for now. + # BlipVisionModel does not need sharding use_default_weight_loading = True else: for (param_name, weight_name, diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 07ee0e3c531d0..831b3f20457a9 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 2d4f172ce0be6..47e020e8ecb73 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -22,7 +22,7 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( @@ -33,7 +33,7 @@ from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SamplerOutput, SequenceData) + SequenceData) from vllm.utils import print_warning_once from .interfaces import SupportsMultiModal @@ -53,7 +53,7 @@ class ChameleonImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor - """Shape: `(batch_size, num_channels, height, width)`""" + """Shape: `(batch_size * num_images, num_channels, height, width)`""" def get_max_chameleon_image_tokens(ctx: InputContext): diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 4949d0232fabb..35f1ed5ef5d33 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -20,12 +20,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 69bb9f6f3afee..70f1522ae2524 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -7,12 +7,14 @@ import torch.nn as nn from PIL import Image from transformers import CLIPVisionConfig -from transformers.models.clip.modeling_clip import CLIPAttention +from transformers.models.clip.modeling_clip import CLIPSdpaAttention from vllm.config import ModelConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -20,6 +22,12 @@ repeat_and_pad_placeholder_tokens) from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +try: + from xformers import ops as xops + USE_XFORMERS_OPS = True +except ImportError: + USE_XFORMERS_OPS = False + def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int: assert image_size % patch_size == 0 @@ -97,7 +105,7 @@ def input_processor_for_clip( if isinstance(image_data, Image.Image): image_feature_size = get_clip_image_feature_size(hf_config) elif isinstance(image_data, torch.Tensor): - image_feature_size = image_data.shape[0] + num_images, image_feature_size, hidden_size = image_data.shape else: raise TypeError(f"Invalid image type: {type(image_data)}") else: @@ -160,6 +168,78 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: return embeddings +class CLIPParallelAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: CLIPVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + "embed_dim must be divisible by num_heads " + f"(got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads}).") + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.qkv_proj = QKVParallelLinear( + hidden_size=self.embed_dim, + head_size=self.head_dim, + total_num_heads=self.num_heads, + quant_config=quant_config, + ) + + self.out_proj = RowParallelLinear( + input_size=self.embed_dim, + output_size=self.embed_dim, + quant_config=quant_config, + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, + self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + ): + """Input shape: Batch x Time x Channel""" + bsz, tgt_len, _ = hidden_states.size() + + qkv_states, _ = self.qkv_proj(hidden_states) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) + + query_states = query_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(bsz, tgt_len, + self.num_heads_per_partition, + self.head_dim) + + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + out = out.view(bsz, tgt_len, -1) + attn_output, _ = self.out_proj(out) + + return attn_output, None + + class CLIPMLP(nn.Module): def __init__(self, @@ -192,7 +272,13 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None): super().__init__() - self.self_attn = CLIPAttention(config) + num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + if USE_XFORMERS_OPS and num_heads % tp_size == 0: + self.self_attn = CLIPParallelAttention(config, + quant_config=quant_config) + else: + self.self_attn = CLIPSdpaAttention(config) self.layer_norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.mlp = CLIPMLP(config, quant_config=quant_config) @@ -291,6 +377,10 @@ def __init__(self, quant_config: Optional[QuantizationConfig] = None, num_hidden_layers_override: Optional[int] = None): super().__init__() + tp_size = get_tensor_model_parallel_world_size() + num_heads = config.num_attention_heads + self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 + self.vision_model = CLIPVisionTransformer( config=config, quant_config=quant_config, @@ -304,7 +394,15 @@ def forward(self, pixel_values: Optional[torch.Tensor] = None): def device(self): return next(self.parameters()).device + # (TODO) Add prefix argument for filtering out weights to be loaded + # ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986 def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] if self.shard_weight else [] params_dict = dict(self.named_parameters()) layer_count = len(self.vision_model.encoder.layers) @@ -318,7 +416,16 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if layer_idx >= layer_count: continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index f63cf246e510a..649dc798d22dc 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -38,14 +38,16 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, row_parallel_weight_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA @torch.compile @@ -292,8 +294,7 @@ def forward( return hidden_states -class CohereForCausalLM(nn.Module): - +class CohereForCausalLM(nn.Module, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index dca959798e8b2..6160197dc19de 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -17,13 +17,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.dbrx import DbrxConfig diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 7a27e1388e987..61cc917ab6207 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -43,12 +43,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class DeepseekMLP(nn.Module): diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index c7f3af0ccb266..8cbd9435ec7ca 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -43,12 +43,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/eagle.py b/vllm/model_executor/models/eagle.py index 99c825ff63572..ad1ab0231d861 100644 --- a/vllm/model_executor/models/eagle.py +++ b/vllm/model_executor/models/eagle.py @@ -5,12 +5,13 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models import ModelRegistry from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.eagle import EAGLEConfig diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py new file mode 100644 index 0000000000000..4a1c367de3f62 --- /dev/null +++ b/vllm/model_executor/models/exaone.py @@ -0,0 +1,617 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/modeling_exaone.py +# Copyright 2024 The LG U+ CTO AI Tech Lab. +# Copyright 2021 The LG AI Research EXAONE Lab +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only Exaone model compatible with HuggingFace weights.""" + +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.exaone import ExaoneConfig +from vllm.utils import is_hip + +from .interfaces import SupportsLoRA +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers + + +class ExaoneGatedMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) + self.c_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.c_proj", + ) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.c_proj(x) + return x + + +class ExaoneAttention(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.out_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.out_proj", + ) + + is_neox_style = True + if quant_config is not None and quant_config.get_name() == "gguf": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.out_proj(attn_output) + return output + + +class ExaoneBlockAttention(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.attention = ExaoneAttention( + config=config, + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=bias, + cache_config=cache_config, + prefix=prefix, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + return self.attention( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + +class ExaoneDecoderLayer(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.attn = ExaoneBlockAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.attn", + ) + self.mlp = ExaoneGatedMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.activation_function, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + else: + hidden_states, residual = self.ln_1(hidden_states, residual) + hidden_states = self.attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.ln_2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class ExaoneModel(nn.Module): + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.wte = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.wte = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.wte = PPMissingLayer() + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: ExaoneDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.h", + ) + if get_pp_group().is_last_rank: + self.ln_f = RMSNorm(config.hidden_size, + eps=config.layer_norm_epsilon) + else: + self.ln_f = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.wte(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.h[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.ln_f(hidden_states, residual) + return hidden_states + + +class ExaoneForCausalLM(nn.Module, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "c_fc_0", + "c_fc_1", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "out_proj", + "gate_up_proj", + "c_proj", + "wte", + "lm_head", + ] + embedding_modules = { + "wte": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "c_fc_0": ("gate_up_proj", 0), + "c_fc_1": ("gate_up_proj", 1), + } + + def __init__( + self, + config: ExaoneConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.transformer = ExaoneModel( + config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model", + ) + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.transformer.wte.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + "residual": + torch.zeros( + (batch_size, self.config.hidden_size), + dtype=dtype, + device=device, + ), + }) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".c_fc_0", 0), + (".gate_up_proj", ".c_fc_1", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, + tp_rank, + tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type, + ): + if not isinstance(self.transformer.h[layer_idx], nn.Identity): + layer_self_attn = self.transformer.h[layer_idx].attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 7b97b3d255dfa..b474d35baf89d 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -39,12 +39,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import RWConfig FalconConfig = Union[HF_FalconConfig, RWConfig] diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 6cdf331fed8b7..beeae14229575 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -31,6 +31,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.linear import ColumnParallelLinear from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.persimmon import PersimmonForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -39,7 +40,7 @@ from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SamplerOutput, SequenceData) + SequenceData) from .interfaces import SupportsMultiModal from .utils import merge_multimodal_embeddings diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index e1041edf81b0a..36fd389831282 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 5e0f8b70d4b80..90449ec51ef0b 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -33,12 +33,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index bfc231282952a..fb5a297661ddc 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index b93fb8d69b2d7..fe5ec10827608 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 4d52b448049b4..664d775c8ba40 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -33,12 +33,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class GPTJAttention(nn.Module): diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 2adecf7fa9ef8..5f6f1e3880547 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -33,12 +33,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class GPTNeoXAttention(nn.Module): diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py new file mode 100644 index 0000000000000..b0325e8b616c8 --- /dev/null +++ b/vllm/model_executor/models/granite.py @@ -0,0 +1,543 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only IBM Granite model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.granite import GraniteConfig +from vllm.utils import is_hip + +from .interfaces import SupportsLoRA +from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers + + +class GraniteMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class GraniteAttention(nn.Module): + + def __init__( + self, + config: GraniteConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = config.attention_multiplier + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class GraniteDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + self.residual_multiplier = config.residual_multiplier + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = GraniteAttention( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + + self.mlp = GraniteMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + return hidden_states + + +class GraniteModel(nn.Module): + + def __init__( + self, + config: GraniteConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GraniteDecoderLayer(config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers") + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + hidden_states *= self.config.embedding_multiplier + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states = self.norm(hidden_states) + return hidden_states + + +class GraniteForCausalLM(nn.Module, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + + def __init__( + self, + config: GraniteConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = GraniteModel(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + logits /= self.config.logits_scaling + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, tp_rank, tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type): + if not isinstance(self.model.layers[layer_idx], nn.Identity): + layer_self_attn = self.model.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py index 54c933e3e4959..33b4a3acaa559 100644 --- a/vllm/model_executor/models/intern_vit.py +++ b/vllm/model_executor/models/intern_vit.py @@ -11,13 +11,21 @@ import torch.nn.functional as F from transformers import PretrainedConfig +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader +try: + from xformers import ops as xops + USE_XFORMERS_OPS = True +except ImportError: + USE_XFORMERS_OPS = False + NORM2FN = { 'rms_norm': RMSNorm, 'layer_norm': nn.LayerNorm, @@ -78,7 +86,73 @@ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: return embeddings -class InternAttention(nn.Module): +class InternParallelAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f'embed_dim must be divisible by num_heads ' + f'(got `embed_dim`: {self.embed_dim} and `num_heads`:' + f' {self.num_heads}).') + + self.scale = self.head_dim**-0.5 + self.qkv = QKVParallelLinear( + self.embed_dim, + self.head_dim, + self.num_heads, + bias=config.qkv_bias, + quant_config=quant_config, + ) + + self.qk_normalization = config.qk_normalization + + if self.qk_normalization: + self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) + self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps) + + self.proj = RowParallelLinear( + self.embed_dim, + self.embed_dim, + quant_config=quant_config, + ) + + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) + + def forward(self, x): + B, N, C = x.shape + qkv, _ = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) + + q = q.view(B, N, self.num_heads_per_partition, self.head_dim) + k = k.view(B, N, self.num_heads_per_partition, self.head_dim) + v = v.view(B, N, self.num_heads_per_partition, self.head_dim) + + if self.qk_normalization: + B_, N_, H_, D_ = q.shape + q = self.q_norm.forward_native(q.flatten(-2, + -1)).view(B_, N_, H_, D_) + k = self.k_norm.forward_native(k.flatten(-2, + -1)).view(B_, N_, H_, D_) + + x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale) + x = x.view(B, N, -1) + + x, _ = self.proj(x) + return x + + +class InternSdpaAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" def __init__(self, config: PretrainedConfig): @@ -108,19 +182,25 @@ def __init__(self, config: PretrainedConfig): def forward(self, x): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, - C // self.num_heads).permute(2, 0, 3, 1, 4) - q, k, v = qkv.unbind(0) + qkv = self.qkv(x) + q, k, v = qkv.chunk(3, dim=-1) + + q = q.view(B, N, self.num_heads, self.head_dim) + k = k.view(B, N, self.num_heads, self.head_dim) + v = v.view(B, N, self.num_heads, self.head_dim) if self.qk_normalization: - B_, H_, N_, D_ = q.shape - q = self.q_norm.forward_native(q.transpose(1, 2).flatten( - -2, -1)).view(B_, N_, H_, D_).transpose(1, 2) - k = self.k_norm.forward_native(k.transpose(1, 2).flatten( - -2, -1)).view(B_, N_, H_, D_).transpose(1, 2) + B_, N_, H_, D_ = q.shape + q = self.q_norm.forward_native(q.flatten(-2, + -1)).view(B_, N_, H_, D_) + k = self.k_norm.forward_native(k.flatten(-2, + -1)).view(B_, N_, H_, D_) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) x = F.scaled_dot_product_attention(q, k, v, scale=self.scale) - x = x.transpose(1, 2).reshape(B, N, C) + x = x.transpose(1, 2).view(B, N, -1) x = self.proj(x) return x @@ -161,7 +241,14 @@ def __init__(self, self.intermediate_size = config.intermediate_size self.norm_type = config.norm_type - self.attn = InternAttention(config) + # fallback to sdpa attention if tp unavailable + tp_size = get_tensor_model_parallel_world_size() + num_heads = config.num_attention_heads + if USE_XFORMERS_OPS and num_heads % tp_size == 0: + self.attn = InternParallelAttention(config, + quant_config=quant_config) + else: + self.attn = InternSdpaAttention(config) self.mlp = InternMLP(config, quant_config=quant_config) self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 499cdb43fc8b2..11a8431a5e7f7 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- -from typing import Any, Dict, Iterable, List, Optional, Tuple +from functools import partial +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -7,7 +8,10 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -17,12 +21,15 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors + +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class InternLM2MLP(nn.Module): @@ -70,20 +77,21 @@ def __init__( ) -> None: super().__init__() self.hidden_size = hidden_size - tp_size = get_tensor_model_parallel_world_size() + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() self.total_num_heads = num_heads - assert self.total_num_heads % tp_size == 0 - self.num_heads = self.total_num_heads // tp_size + assert self.total_num_heads % self.tp_size == 0 + self.num_heads = self.total_num_heads // self.tp_size self.total_num_kv_heads = num_kv_heads - if self.total_num_kv_heads >= tp_size: + if self.total_num_kv_heads >= self.tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. - assert self.total_num_kv_heads % tp_size == 0 + assert self.total_num_kv_heads % self.tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. - assert tp_size % self.total_num_kv_heads == 0 - self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + assert self.tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -122,11 +130,27 @@ def __init__( quant_config=quant_config) def split_qkv(self, qkv: torch.Tensor): - qkv = qkv.view(-1, self.num_kv_heads, self.key_value_groups + 2, 128) - q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=2) - q = q.reshape(-1, self.q_size) - k = k.reshape(-1, self.kv_size) - v = v.reshape(-1, self.kv_size) + seq_len = qkv.shape[0] + if self.tp_size > 1: + qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size + qkv = tensor_model_parallel_all_gather(qkv) + qkv = torch.split(qkv, qkv_map, dim=-1) + qkv = qkv[::3] + qkv[1::3] + qkv[2::3] + qkv = torch.cat(qkv, dim=-1) + + qkv = qkv.view(seq_len, self.total_num_kv_heads, + self.key_value_groups + 2, self.head_dim) + q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2) + q = q.reshape(seq_len, self.q_size * self.tp_size) + k = k.reshape(seq_len, self.kv_size * self.tp_size) + v = v.reshape(seq_len, self.kv_size * self.tp_size) + + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, + num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + v = splitter(v)[self.tp_rank] return q, k, v def forward( @@ -213,6 +237,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -222,11 +247,15 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: InternLMDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.tok_embeddings(input_ids) @@ -239,21 +268,31 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: IntermediateTensors = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.tok_embeddings(input_ids) + residual = None else: - hidden_states = self.tok_embeddings(input_ids) - residual = None - for i in range(len(self.layers)): + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -277,6 +316,8 @@ def __init__( self.output.weight = self.model.tok_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -287,7 +328,7 @@ def forward( intermediate_tensors: IntermediateTensors, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -324,6 +365,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -332,6 +375,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 7f213287f33b4..0cf63d9e1fb22 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -5,6 +5,7 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- import itertools +import re from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -18,18 +19,20 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.intern_vit import InternVisionModel from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) from .interfaces import SupportsMultiModal -from .utils import (filter_weights, init_vllm_registered_model, +from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, merge_multimodal_embeddings) IMG_START = '' @@ -42,19 +45,17 @@ class InternVLImagePixelInputs(TypedDict): type: Literal["pixel_values"] - data: Union[torch.Tensor, List[torch.Tensor]] + data: torch.Tensor """ - Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` - - Note that `num_patches` may be different for each batch, in which case - the data is passed as a list instead of a batched tensor. + Shape: + `(batch_size * num_images * (1 + num_patches), num_channels, height, width)` """ class InternVLImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] - data: Union[torch.Tensor, List[torch.Tensor]] - """Shape: `(batch_size, image_feature_size, hidden_size)` + data: torch.Tensor + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ @@ -96,8 +97,8 @@ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, - max_num: int, - image_size: int) -> Tuple[int, int, int]: + max_num: int, image_size: int, + use_thumbnail: bool) -> Tuple[int, int, int]: aspect_ratio = orig_width / orig_height # calculate the existing image aspect ratio @@ -115,17 +116,26 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, target_width = image_size * target_aspect_ratio[0] target_height = image_size * target_aspect_ratio[1] blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + # add thumbnail image if num_blocks > 1 + if use_thumbnail and blocks > 1: + blocks += 1 return blocks, target_width, target_height # adapted from https://huggingface.co/OpenGVLab/InternVL2-1B def dynamic_preprocess(image: Image.Image, min_num: int, max_num: int, image_size: int, - use_thumbnail: int) -> List[Image.Image]: + use_thumbnail: bool) -> List[Image.Image]: orig_width, orig_height = image.size + # calculate the number of blocks without thumbnail blocks, target_width, target_height = calculate_num_blocks( - orig_width, orig_height, min_num, max_num, image_size) + orig_width, + orig_height, + min_num, + max_num, + image_size, + use_thumbnail=False) # resize the image resized_img = image.resize((target_width, target_height)) processed_images = [] @@ -198,19 +208,25 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): downsample_ratio) image_data = multi_modal_data["image"] + min_num = hf_config.min_dynamic_patch + max_num = hf_config.max_dynamic_patch + use_thumbnail = hf_config.use_thumbnail if isinstance(image_data, Image.Image): width, height = image_data.size - min_num = hf_config.min_dynamic_patch - max_num = hf_config.max_dynamic_patch num_blocks, _, _ = calculate_num_blocks(width, height, min_num, - max_num, image_size) - # add thumbnail image if num_blocks > 1 - if hf_config.use_thumbnail and num_blocks > 1: - num_blocks += 1 - image_feature_size = num_blocks * num_patches - + max_num, image_size, + use_thumbnail) + image_feature_size = [num_blocks * num_patches] + elif is_list_of(image_data, Image.Image): + image_feature_size = [] + for image in image_data: + width, height = image.size + num_blocks, _, _ = calculate_num_blocks(width, height, min_num, + max_num, image_size, + use_thumbnail) + image_feature_size.append(num_blocks * num_patches) elif isinstance(image_data, torch.Tensor): - image_feature_size = image_data.shape[0] + num_images, image_feature_size, hidden_size = image_data.shape else: raise TypeError(f"Invalid image type: {type(image_data)}") @@ -221,8 +237,14 @@ def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs): prompt_token_ids = llm_inputs["prompt_token_ids"] if prompt is None: prompt = tokenizer.decode(prompt_token_ids) - image_prompt = IMG_START + IMG_CONTEXT * image_feature_size + IMG_END - new_prompt = prompt.replace('', image_prompt, 1) + + new_prompt = prompt + image_idx = sorted(map(int, re.findall(r"Image-(\d+): \n", prompt))) + for idx, feature_size in enumerate(image_feature_size, start=1): + image_prompt = IMG_START + IMG_CONTEXT * feature_size + IMG_END + if not image_idx: + image_prompt = f"Image-{idx}: {image_prompt}" + new_prompt = new_prompt.replace('', image_prompt, 1) new_prompt_token_ids = tokenizer.encode(new_prompt) return LLMInputs(prompt=prompt, @@ -246,6 +268,15 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): use_thumbnail=use_thumbnail) # Add an N dimension for number of images per prompt (currently 1). data = data.unsqueeze(0) + elif is_list_of(data, Image.Image): + data = [ + image_to_pixel_values(img, + image_size, + min_num, + max_num, + use_thumbnail=use_thumbnail) for img in data + ] + data = torch.stack(data) model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) @@ -342,6 +373,8 @@ def __init__(self, nn.Linear(llm_hidden_size, llm_hidden_size)) self.img_context_token_id = None + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() @@ -357,7 +390,7 @@ def pixel_shuffle(self, x, scale_factor=0.5): x = x.permute(0, 2, 1, 3).contiguous() return x - def extract_feature(self, pixel_values): + def extract_feature(self, pixel_values: torch.Tensor) -> torch.Tensor: vit_embeds = self.vision_model(pixel_values=pixel_values) vit_embeds = vit_embeds[:, 1:, :] @@ -370,17 +403,7 @@ def extract_feature(self, pixel_values): vit_embeds = self.mlp1(vit_embeds) return vit_embeds - def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: - if list(data.shape[1:]) != [2]: - raise ValueError( - f"The expected image sizes shape is batch dimension plus " - f"{[2]}. You supplied {data.shape}.") - - return data - - def _validate_pixel_values( - self, data: Union[torch.Tensor, List[torch.Tensor]] - ) -> Union[torch.Tensor, List[torch.Tensor]]: + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -389,10 +412,11 @@ def _validate_shape(d: torch.Tensor): actual_dims = tuple(d.shape) if actual_dims != expected_dims: - expected_expr = ("num_patches", *map(str, expected_dims)) + expected_expr = str(expected_dims) raise ValueError( - "The expected shape of pixel values in each batch element " - f"is {expected_expr}. You supplied {tuple(d.shape)}.") + "The expected shape of pixel values per image per batch " + f" per patch is {expected_expr}. " + f"You supplied {tuple(d.shape)}.") for d in data: _validate_shape(d) @@ -413,12 +437,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") - # Flatten the B and N dimensions - image_embeds = image_embeds.flatten(0, 2) - return InternVLImageEmbeddingInputs( type="image_embeds", - data=image_embeds, + data=flatten_bn(image_embeds), ) self.img_context_token_id = image_token_id[0] @@ -428,12 +449,10 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - # Flatten the B and N dimensions - pixel_values = pixel_values.flatten(0, 2) - return InternVLImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values(pixel_values), + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True).flatten(0, 1)), ) raise AssertionError("This line should be unreachable.") @@ -476,7 +495,7 @@ def forward( positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index a550f7e6c97a1..b0fbb7e9829e0 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -35,12 +35,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import JAISConfig from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index caeda4e42d8a0..29dd09afac5ad 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -4,9 +4,6 @@ from typing import Dict, Iterable, List, Optional, Tuple import torch -from causal_conv1d import causal_conv1d_fn, causal_conv1d_update -from mamba_ssm.ops.selective_scan_interface import selective_scan_fn -from mamba_ssm.ops.triton.selective_state_update import selective_state_update from torch import nn from torch.nn.parameter import Parameter from transformers import JambaConfig @@ -24,19 +21,25 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.mamba.ops.causal_conv1d import ( + causal_conv1d_fn, causal_conv1d_update) +from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( + selective_scan_fn, selective_state_update) from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import HasInnerState from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) +from .interfaces import SupportsLoRA + KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -161,7 +164,7 @@ def mamba_forward(self, (self.conv_kernel_size - hidden_states.shape[-1], 0)) cache_params.conv_state.copy_(conv_states) - hidden_states = causal_conv1d_fn( + hidden_states, _ = causal_conv1d_fn( hidden_states, conv_weights, self.conv1d.bias, @@ -538,7 +541,7 @@ def forward( return hidden_states -class JambaForCausalLM(nn.Module, HasInnerState): +class JambaForCausalLM(nn.Module, HasInnerState, SupportsLoRA): packed_modules_mapping = { "qkv_proj": [ "q_proj", diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 0c67a9b8e198b..5ff31e3833ec9 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -42,13 +42,13 @@ from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_hip from .interfaces import SupportsLoRA @@ -375,6 +375,25 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): "gate_proj": ("gate_up_proj", 0), "up_proj": ("gate_up_proj", 1), } + # Mistral/Llama models can also be loaded with --load-format mistral + # from consolidated.safetensors checkpoints + mistral_mapping = { + "layers": "model.layers", + "attention": "self_attn", + "wq": "q_proj", + "wk": "k_proj", + "wv": "v_proj", + "wo": "o_proj", + "attention_norm": "input_layernorm", + "feed_forward": "mlp", + "w1": "gate_proj", + "w2": "down_proj", + "w3": "up_proj", + "ffn_norm": "post_attention_layernorm", + "tok_embeddings": "model.embed_tokens", + "output": "lm_head", + "norm": "model.norm" + } def __init__( self, @@ -472,6 +491,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: + name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight) + if "rotary_emb.inv_freq" in name: continue if ("rotary_emb.cos_cached" in name @@ -549,3 +570,33 @@ def load_kv_cache_scales(self, quantization_param_path: str) -> None: else: raise RuntimeError("Self attention has no KV cache scaling " "factor attribute!") + + # This function is used to remap the mistral format as + # used by Mistral and Llama <=2 + def maybe_remap_mistral( + self, name: str, + loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]: + + def permute(w, n_heads): + attn_in = self.config.head_dim * n_heads + attn_out = self.config.hidden_size + + return w.view(n_heads, attn_in // n_heads // 2, 2, + attn_out).transpose(1, 2).reshape(attn_in, attn_out) + + mapping = self.mistral_mapping + modules = name.split(".") + + # rotary embeds should be sliced + if "wk" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_key_value_heads) + elif "wq" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_attention_heads) + + for item in modules: + if item in mapping and mapping[item] not in name: + name = name.replace(item, mapping[item]) + + return name, loaded_weight diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 03a0abf1db481..7a6c991fb133a 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +from PIL import Image from transformers import CLIPVisionConfig, LlavaConfig, SiglipVisionConfig from vllm.attention import AttentionMetadata @@ -11,10 +12,12 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors +from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, dummy_seq_data_for_clip, get_max_clip_image_tokens, @@ -23,20 +26,20 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens, input_processor_for_siglip) -from .utils import (filter_weights, init_vllm_registered_model, +from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, merge_multimodal_embeddings) class LlavaImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor - """Shape: `(batch_size, num_channels, height, width)`""" + """Shape: `(batch_size * num_images, num_channels, height, width)`""" class LlavaImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: torch.Tensor - """Shape: `(batch_size, image_feature_size, hidden_size)` + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ @@ -132,7 +135,18 @@ def input_processor_for_llava(ctx: InputContext, llm_inputs: LLMInputs): hf_config = ctx.get_hf_config(LlavaConfig) vision_config = hf_config.vision_config - image_feature_size = get_max_llava_image_tokens(ctx) + image_data = multi_modal_data["image"] + if isinstance(image_data, Image.Image): + image_feature_size = get_max_llava_image_tokens(ctx) + elif is_list_of(image_data, Image.Image): + image_feature_size = [get_max_llava_image_tokens(ctx) + ] * len(image_data) + elif isinstance(image_data, torch.Tensor): + num_images, image_feature_size, hidden_size = image_data.shape + elif is_list_of(image_data, torch.Tensor): + image_feature_size = [item.shape[1] for item in image_data] + else: + raise TypeError(f"Invalid image type: {type(image_data)}") if isinstance(vision_config, CLIPVisionConfig): return input_processor_for_clip( @@ -229,29 +243,24 @@ def _parse_and_validate_image_input( return None if pixel_values is not None: - if not isinstance(pixel_values, torch.Tensor): + if not isinstance(pixel_values, (torch.Tensor, list)): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - # Remove the N dimension until multiple images are supported. - pixel_values = pixel_values.squeeze(1) - return LlavaImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values(pixel_values), + data=self._validate_pixel_values( + flatten_bn(pixel_values, concat=True)), ) if image_embeds is not None: - if not isinstance(image_embeds, torch.Tensor): + if not isinstance(image_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") - # Remove the N dimension until multiple images are supported. - image_embeds = image_embeds.squeeze(1) - return LlavaImageEmbeddingInputs( type="image_embeds", - data=image_embeds, + data=flatten_bn(image_embeds, concat=True), ) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 3a87242954114..c6bd46dd7eda9 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -15,10 +15,11 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, @@ -29,7 +30,7 @@ from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_siglip_image_feature_size, get_siglip_patch_grid_length, input_processor_for_siglip) -from .utils import (filter_weights, init_vllm_registered_model, +from .utils import (filter_weights, flatten_bn, init_vllm_registered_model, merge_multimodal_embeddings) logger = init_logger(__name__) @@ -47,15 +48,16 @@ class LlavaNextImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: Union[torch.Tensor, List[torch.Tensor]] """ - Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - Note that `num_patches` may be different for each batch, in which case - the data is passed as a list instead of a batched tensor. + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. """ image_sizes: NotRequired[torch.Tensor] """ - Shape: `(batch_size, 2)` + Shape: `(batch_size * num_images, 2)` This should be in `(height, width)` format. """ @@ -64,7 +66,7 @@ class LlavaNextImagePixelInputs(TypedDict): class LlavaNextImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: torch.Tensor - """Shape: `(batch_size, image_feature_size, hidden_size)` + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ @@ -232,7 +234,9 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): for img in image_data ] elif isinstance(image_data, torch.Tensor): - image_feature_size = image_data.shape[0] + num_images, image_feature_size, hidden_size = image_data.shape + elif is_list_of(image_data, torch.Tensor): + image_feature_size = [item.shape[1] for item in image_data] else: raise TypeError(f"Invalid image type: {type(image_data)}") @@ -315,10 +319,19 @@ def __init__(self, torch.empty(config.text_config.hidden_size)) def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: - if list(data.shape[1:]) != [2]: - raise ValueError( - f"The expected image sizes shape is batch dimension plus " - f"{[2]}. You supplied {data.shape}.") + expected_dims = (2, ) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + f"The expected shape of image sizes per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) return data @@ -335,7 +348,7 @@ def _validate_shape(d: torch.Tensor): if actual_dims != expected_dims: expected_expr = ("num_patches", *map(str, expected_dims)) raise ValueError( - "The expected shape of pixel values in each batch element " + "The expected shape of pixel values per image per batch " f"is {expected_expr}. You supplied {tuple(d.shape)}.") for d in data: @@ -357,22 +370,15 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - if not isinstance(image_sizes, torch.Tensor): + if not isinstance(image_sizes, (torch.Tensor, list)): raise ValueError("Incorrect type of image sizes. " f"Got type: {type(image_sizes)}") - # Remove the N dimension until multiple images are supported. - if isinstance(pixel_values, torch.Tensor): - pixel_values = pixel_values.squeeze(1) - else: - pixel_values = [t.squeeze(0) for t in pixel_values] - - image_sizes = image_sizes.squeeze(1) - return LlavaNextImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values(pixel_values), - image_sizes=self._validate_image_sizes(image_sizes), + data=self._validate_pixel_values(flatten_bn(pixel_values)), + image_sizes=self._validate_image_sizes( + flatten_bn(image_sizes, concat=True)), ) if image_embeds is not None: @@ -380,12 +386,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image embeds. " f"Got type: {type(image_embeds)}") - # Remove the N dimension until multiple images are supported. - image_embeds = image_embeds.squeeze(1) - return LlavaNextImageEmbeddingInputs( type="image_embeds", - data=image_embeds, + data=flatten_bn(image_embeds), ) raise AssertionError("This line should be unreachable.") diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 55d42952cd0cc..619a5cd00d6b6 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -4,11 +4,11 @@ import torch.nn as nn from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.medusa import MedusaConfig diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index ff42bdefe0269..a135118bc748e 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -44,13 +44,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 6a3d5422e0ce4..f8be9490ee55d 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -26,11 +26,9 @@ from array import array from functools import partial from typing import (Any, Callable, Iterable, List, Mapping, Optional, Tuple, - TypedDict, Union) + TypedDict) -import numpy as np import torch -import torch.nn.functional as F import torch.types from PIL import Image from torch import nn @@ -44,7 +42,9 @@ from vllm.model_executor.layers.linear import ReplicatedLinear from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.resampler import (Resampler2, + get_2d_sincos_pos_embed) +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -57,7 +57,7 @@ from vllm.multimodal.image import cached_get_image_processor from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, - SamplerOutput, SequenceData) + SequenceData) from .idefics2_vision_model import Idefics2VisionTransformer @@ -98,101 +98,6 @@ class MiniCPMVImagePixelInputs(TypedDict): DEFAULT_LN = partial(nn.LayerNorm, eps=1e-6) -def get_abs_pos(abs_pos: torch.Tensor, tgt_size: torch.Tensor): - # abs_pos: L, C - # tgt_size: (H, W) - # return: M, C - src_size = int(math.sqrt(abs_pos.size(0))) - # tgt_size = int(math.sqrt(tgt_size)) - dtype = abs_pos.dtype - - return (F.interpolate( - abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), - size=(tgt_size[0], tgt_size[1]), - mode="bicubic", - align_corners=False, - ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)) - - -# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 -def get_2d_sincos_pos_embed( - embed_dim: int, - grid_size: Union[int, Tuple[int, int]], - cls_token: bool = False, - version: Tuple[int, int] = (2, 0), -): - """ - grid_size: int of the grid height and width - return: - pos_embed: [grid_size*grid_size, embed_dim] or - [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) - """ - if isinstance(grid_size, int): - grid_h_size, grid_w_size = grid_size, grid_size - else: - grid_h_size, grid_w_size = grid_size[0], grid_size[1] - - grid_h = np.arange(grid_h_size, dtype=np.float32) - grid_w = np.arange(grid_w_size, dtype=np.float32) - grid = np.meshgrid(grid_w, grid_h) # here w goes first - grid = np.stack(grid, axis=0) - - if version == (2, 0): - grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) - if cls_token: - pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], - axis=0) - else: - pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) - return pos_embed - - -def get_2d_sincos_pos_embed_from_grid(embed_dim: int, - grid: np.ndarray, - version: Tuple[int, int] = (2, 0)): - assert embed_dim % 2 == 0 - - # use half of dimensions to encode grid_h - emb_h = get_1d_sincos_pos_embed_from_grid( - embed_dim // 2, grid[0], version) # (H*W, D/2) or (H, W, D/2) - emb_w = get_1d_sincos_pos_embed_from_grid( - embed_dim // 2, grid[1], version) # (H*W, D/2) or (H, W, D/2) - - if version == (2, 0): - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) - else: - emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) - return emb - - -def get_1d_sincos_pos_embed_from_grid(embed_dim: int, - pos: np.ndarray, - version: Tuple[int, int] = (2, 0)): - """ - embed_dim: output dimension for each position - pos: a list of positions to be encoded: size (M,) / (H, W) - out: (M, D) / (H, W, D) - """ - assert embed_dim % 2 == 0 - omega = np.arange(embed_dim // 2, dtype=np.float32) - omega /= embed_dim / 2.0 - omega = 1.0 / 10000**omega # (D/2,) - - if version == (2, 0): - pos = pos.reshape(-1) # (M,) - out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) - emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) - else: - out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product - emb_sin = np.sin(out) # (H, W, D/2) - emb_cos = np.cos(out) # (H, W, D/2) - emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) - return emb - - class BaseResampler(nn.Module): """ A 2D perceiver-resampler network with one cross attention layers by @@ -245,62 +150,6 @@ def _repeat(self, query, N: int): return query.unsqueeze(1).repeat(1, N, 1) -class Resampler2(BaseResampler): - - def __init__( - self, - grid_size: int, - embed_dim: int, - num_heads: int, - kv_dim: Optional[int] = None, - norm_layer: Callable[[int], nn.LayerNorm] = DEFAULT_LN, - adaptive: bool = False, - ) -> None: - super().__init__(grid_size**2, embed_dim, num_heads, kv_dim, - norm_layer) - - self.adaptive = adaptive - pos_embed_arr = get_2d_sincos_pos_embed(embed_dim, - grid_size, - version=(2, 0)) - self.pos_embed = nn.Parameter( - torch.from_numpy(pos_embed_arr).float()).requires_grad_(False) - - self.apply(self._init_weights) - - def forward( - self, - x: torch.Tensor, - tgt_sizes: torch.Tensor, - attn_mask: Optional[torch.Tensor] = None, - ): - if self.adaptive: - pos_embed_arr = get_2d_sincos_pos_embed(self.embed_dim, - tgt_sizes, - version=(2, 0)) - pos_embed = torch.from_numpy(pos_embed_arr).to(device=x.device, - dtype=x.dtype) - else: - pos_embed = get_abs_pos(self.pos_embed, tgt_sizes) - - x, _ = self.kv_proj(x) - x = self.ln_kv(x).permute(1, 0, 2) - - N = x.shape[1] - q = self.ln_q(self.query) - out = self.attn( - self._repeat(q, N) + self.pos_embed.unsqueeze(1), - x + pos_embed.unsqueeze(1), - x, - attn_mask=attn_mask, - )[0] - x = out.permute(1, 0, 2) - - x = self.ln_post(x) - x = x @ self.proj - return x - - class Resampler2_5(BaseResampler): def __init__( @@ -782,7 +631,8 @@ def init_resampler(self, embed_dim: int, vision_dim: int) -> nn.Module: num_heads=embed_dim // 128, grid_size=int(math.sqrt(self.config.query_num)), kv_dim=vision_dim, - adaptive=True, + adaptive=False, + do_post_projection=True, ) return resampler diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 413783ba4b259..10cbfcf6432b3 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -39,13 +39,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers @@ -435,7 +435,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): @@ -454,6 +455,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -464,7 +468,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): break else: # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: + if ((name.endswith(".bias") or name.endswith("_bias")) + and name not in params_dict): continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 8bdd52b343175..68471f6ac77d1 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -42,12 +42,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class MixtralMLP(nn.Module): diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index 9b96ecb78a3c9..42ccd01298169 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -6,11 +6,10 @@ from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import MLPSpeculatorConfig SQRT2 = 2**0.5 diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 1a8e514a7ae83..0fcbf06e1a060 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -17,12 +17,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.mpt import MPTConfig diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index 7d92a1ffe55df..e9ff12de2094e 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -37,13 +37,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronConfig from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 8de124cd034dc..97749725dd132 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -38,12 +38,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class OlmoAttention(nn.Module): diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index c0d2d537e731f..88d2bcb9f0c9d 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -34,12 +34,12 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class OPTLearnedPositionalEmbedding(nn.Embedding): diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index fab35f0b882a7..b01ce87adfa46 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -21,12 +21,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class OrionMLP(nn.Module): diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 0700f0c29d708..5fd39b5e35be6 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -1,3 +1,4 @@ +import itertools from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -11,36 +12,32 @@ from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.gemma import GemmaModel +from vllm.model_executor.models.gemma import GemmaForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsMultiModal from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens) -from .utils import merge_multimodal_embeddings +from .utils import filter_weights, merge_multimodal_embeddings logger = init_logger(__name__) -_KEYS_TO_MODIFY_MAPPING = { - "language_model.model": "language_model", -} - class PaliGemmaImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: torch.Tensor - """Shape: (batch_size, num_channels, height, width)""" + """Shape: `(batch_size * num_images, num_channels, height, width)`""" class PaliGemmaImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: torch.Tensor - """Shape: `(batch_size, image_feature_size, hidden_size)` + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ @@ -145,15 +142,14 @@ def __init__(self, self.config = config self.multimodal_config = multimodal_config - # TODO(ywang96): Port over SiglipVisionModel & TP self.vision_tower = SiglipVisionModel(config.vision_config) self.multi_modal_projector = PaliGemmaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, projection_dim=config.vision_config.projection_dim) self.quant_config = quant_config - self.language_model = GemmaModel(config.text_config, cache_config, - quant_config) + self.language_model = GemmaForCausalLM(config.text_config, + cache_config, quant_config) self.unpadded_vocab_size = config.text_config.vocab_size logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, @@ -253,7 +249,8 @@ def forward(self, vision_embeddings = vision_embeddings * (self.config.hidden_size** -0.5) - inputs_embeds = self.language_model.get_input_embeddings(input_ids) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, vision_embeddings, @@ -263,90 +260,47 @@ def forward(self, else: inputs_embeds = None - hidden_states = self.language_model(input_ids, - positions, - kv_caches, - attn_metadata, - None, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + None, + inputs_embeds=inputs_embeds) return hidden_states - # Copied from vllm/model_executor/models/gemma.py def compute_logits( self, hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.language_model.embed_tokens, - hidden_states, sampling_metadata) - return logits + return self.language_model.compute_logits(hidden_states, + sampling_metadata) - # Copied from vllm/model_executor/models/gemma.py def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) - # Adapted from vllm/model_executor/models/gemma.py def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("qkv_proj", "q_proj", "q"), - ("qkv_proj", "k_proj", "k"), - ("qkv_proj", "v_proj", "v"), - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - params_dict = dict(self.named_parameters()) - loaded_params = set() - for name, loaded_weight in weights: - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - use_default_weight_loading = False - if "vision" in name: - if self.vision_tower is not None: - # We only do sharding for language model and - # not vision model for now. - use_default_weight_loading = True - else: - for (param_name, shard_name, - shard_id) in stacked_params_mapping: - if shard_name not in name: - continue - name = name.replace(shard_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # lm_head is not used in vllm as it is tied with - # embed_token. To prevent errors, skip loading - # lm_head.weight. - if "lm_head.weight" in name: - continue - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - use_default_weight_loading = True - - if use_default_weight_loading: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - loaded_params.add(name) - - unloaded_params = params_dict.keys() - loaded_params - if unloaded_params: - logger.warning( - "Some weights are not initialized from checkpoints: %s", - unloaded_params) + # prepare weight iterators for components + vit_weights, mlp_weights, llm_weights = itertools.tee(weights, 3) + + # load vision tower + vit_weights = filter_weights(vit_weights, "vision_tower") + self.vision_tower.load_weights(vit_weights) + + # load mlp projector + mlp_weights = filter_weights(mlp_weights, "multi_modal_projector") + mlp_params_dict = dict(self.multi_modal_projector.named_parameters()) + for name, loaded_weight in mlp_weights: + param = mlp_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load llm backbone + llm_weights = filter_weights(llm_weights, "language_model") + self.language_model.load_weights(llm_weights) diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 3300939c7b102..f8fc1cd8ef1f0 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -37,12 +37,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class PersimmonMLP(nn.Module): diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index f31b5162aac96..15c21cfa2d8a8 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -52,12 +52,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index df01bfa3d8e6e..afc6fe9844ad6 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -16,12 +16,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors def load_column_parallel_weight(param: torch.nn.Parameter, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 61f1d73976379..6f17f571ccaea 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -31,7 +31,7 @@ from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.clip import CLIPVisionModel @@ -39,12 +39,12 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import is_list_of from .clip import dummy_image_for_clip, dummy_seq_data_for_clip from .interfaces import SupportsMultiModal -from .utils import merge_multimodal_embeddings +from .utils import flatten_bn, merge_multimodal_embeddings logger = init_logger(__name__) @@ -71,19 +71,37 @@ projection_dim=768) +def _init_img_processor(hf_config: PretrainedConfig): + clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG + layer_idx = hf_config.img_processor.get('layer_idx', -2) + + # Initialize the CLIP only up to the required feature layer + if layer_idx < 0: + num_hidden_layers = clip_config.num_hidden_layers + \ + layer_idx + 1 + else: + num_hidden_layers = layer_idx + 1 + + img_processor = CLIPVisionModel( + clip_config, num_hidden_layers_override=num_hidden_layers) + + return img_processor + + class Phi3VImagePixelInputs(TypedDict): type: Literal["pixel_values"] data: Union[torch.Tensor, List[torch.Tensor]] """ - Shape: `(batch_size, 1 + num_patches, num_channels, height, width)` + Shape: + `(batch_size * num_images, 1 + num_patches, num_channels, height, width)` - Note that `num_patches` may be different for each batch, in which case - the data is passed as a list instead of a batched tensor. + Note that `num_patches` may be different per batch and image, + in which case the data is passed as a list instead of a batched tensor. """ image_sizes: torch.Tensor """ - Shape: `(batch_size, 2)` + Shape: `(batch_size * num_images, 2)` This should be in `(height, width)` format. """ @@ -92,7 +110,7 @@ class Phi3VImagePixelInputs(TypedDict): class Phi3VImageEmbeddingInputs(TypedDict): type: Literal["image_embeds"] data: Union[torch.Tensor, List[torch.Tensor]] - """Shape: `(batch_size, image_feature_size, hidden_size)` + """Shape: `(batch_size * num_images, image_feature_size, hidden_size)` `hidden_size` must match the hidden size of language model backbone. """ @@ -138,18 +156,8 @@ def __init__(self, config: PretrainedConfig) -> None: hidden_size = config.n_embd if hasattr( config, 'n_embd') else config.hidden_size - clip_config = CLIP_VIT_LARGE_PATCH14_336_CONFIG - self.layer_idx = config.img_processor.get('layer_idx', -2) - - # Initialize the CLIP only up to the required feature layer - if self.layer_idx < 0: - num_hidden_layers = clip_config.num_hidden_layers + \ - self.layer_idx + 1 - else: - num_hidden_layers = self.layer_idx + 1 + self.img_processor = _init_img_processor(config) - self.img_processor = CLIPVisionModel( - clip_config, num_hidden_layers_override=num_hidden_layers) image_dim_out = config.img_processor['image_dim_out'] self.num_img_tokens = config.img_processor['num_img_tokens'] @@ -416,7 +424,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): input_width=w, input_height=h)) elif isinstance(image_data, torch.Tensor): - image_feature_size = image_data.shape[0] + num_images, image_feature_size, hidden_size = image_data.shape + elif is_list_of(image_data, torch.Tensor): + image_feature_size = [item.shape[1] for item in image_data] else: raise TypeError(f"Invalid image type: {type(image_data)}") @@ -511,10 +521,19 @@ def __init__(self, self.sampler = Sampler() def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: - if list(data.shape[1:]) != [2]: - raise ValueError( - f"The expected shape of image sizes is batch dimension plus " - f"{[2]}. You supplied {tuple(data.shape)}.") + expected_dims = (2, ) + + def _validate_shape(d: torch.Tensor): + actual_dims = tuple(d.shape) + + if actual_dims != expected_dims: + expected_expr = str(expected_dims) + raise ValueError( + f"The expected shape of image sizes per image per batch " + f"is {expected_expr}. You supplied {tuple(d.shape)}.") + + for d in data: + _validate_shape(d) return data @@ -531,7 +550,7 @@ def _validate_shape(d: torch.Tensor): if actual_dims != expected_dims: expected_expr = ("num_patches", *map(str, expected_dims)) raise ValueError( - "The expected shape of pixel values in each batch element " + "The expected shape of pixel values per image per batch " f"is {expected_expr}. You supplied {tuple(d.shape)}.") for d in data: @@ -556,30 +575,24 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") - if not isinstance(image_sizes, torch.Tensor): + if not isinstance(image_sizes, (torch.Tensor, list)): raise ValueError("Incorrect type of image sizes. " f"Got type: {type(image_sizes)}") - # Merge the B and N dimensions. - if isinstance(pixel_values, torch.Tensor): - pixel_values = pixel_values.flatten(0, 1) - else: - pixel_values = torch.cat(pixel_values) - - image_sizes = image_sizes.flatten(0, 1) - return Phi3VImagePixelInputs( type="pixel_values", - data=self._validate_pixel_values(pixel_values), - image_sizes=self._validate_image_sizes(image_sizes)) + data=self._validate_pixel_values(flatten_bn(pixel_values)), + image_sizes=self._validate_image_sizes( + flatten_bn(image_sizes, concat=True))) if image_embeds is not None: if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + return Phi3VImageEmbeddingInputs( type="image_embeds", - data=image_embeds, + data=flatten_bn(image_embeds), ) raise AssertionError("This line should be unreachable.") @@ -652,23 +665,27 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] + + # TODO(ChristopherCho): This is a temporary fix to load + # the vision weights with CLIPVisionModel.load_weights() + vision_weights = [] params_dict = dict(self.named_parameters()) for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - # post_layernorm is not needed in CLIPVisionModel - if "vision_model.post_layernorm" in name: + # Skip loading the img_processor weights since they are + # loaded separately. + if "vision_embed_tokens.img_processor" in name: + vision_weights.append((name, loaded_weight)) continue + for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): if key_to_modify in name: name = name.replace(key_to_modify, new_key) for (param_name, weight_name, shard_id) in stacked_params_mapping: - # We only do sharding for language model - # and not vision model for now. - if "vision_embed_tokens" in name and self.vision_embed_tokens: - continue if weight_name not in name: continue + param = params_dict[name.replace(weight_name, param_name)] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -682,3 +699,11 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + # We use regex to extract the sub-module name + # from "model.vision_embed_tokens.img_processor.*" + vision_weights = [ + (re.search(r"vision_embed_tokens\.img_processor\.(.*)", + n).group(1), w) for n, w in vision_weights + ] + self.vision_embed_tokens.img_processor.load_weights(vision_weights) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py new file mode 100644 index 0000000000000..25bc0590c745c --- /dev/null +++ b/vllm/model_executor/models/phimoe.py @@ -0,0 +1,620 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only PhiMoE model.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .interfaces import SupportsLoRA + + +class PhiMoEConfig(PretrainedConfig): + + model_type = "phimoe" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=14336, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=8, + hidden_act="silu", + max_position_embeddings=4096 * 32, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=1e6, + sliding_window=None, + attention_dropout=0.0, + num_experts_per_tok=2, + num_local_experts=16, + output_router_logits=False, + router_aux_loss_coef=0.001, + router_jitter_noise=0.0, + attention_bias=False, + lm_head_bias=False, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.sliding_window = sliding_window + self.attention_bias = attention_bias + self.lm_head_bias = lm_head_bias + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.attention_dropout = attention_dropout + + self.num_experts_per_tok = num_experts_per_tok + self.num_local_experts = num_local_experts + self.output_router_logits = output_router_logits + self.router_aux_loss_coef = router_aux_loss_coef + self.router_jitter_noise = router_jitter_noise + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + +class mp(torch.autograd.Function): + + @staticmethod + def forward( + ctx, + scores: torch.Tensor, + multiplier: torch.Tensor, + selected_experts: torch.Tensor, + masked_gates: torch.Tensor, + mask_for_one: torch.Tensor, + ): + ctx.save_for_backward(multiplier, selected_experts, masked_gates) + return multiplier * mask_for_one + + @staticmethod + def backward( + ctx, + grad_at_output: torch.Tensor, + ): + multiplier, selected_experts, masked_gates = ctx.saved_tensors + + grad_at_output = grad_at_output * multiplier + + grad_at_scores_expaned = masked_gates * grad_at_output.mul(-1) + grad_at_scores_expaned.scatter_add_( + dim=-1, + index=selected_experts, + src=grad_at_output, + ) + + return ( + grad_at_scores_expaned, + None, + None, + None, + None, + ) + + +def sparsemixer(scores, jitter_eps=0.01): + ################ first expert ################ + + with torch.no_grad(): + # compute mask for sparsity + mask_logits_threshold, max_ind = scores.max(dim=-1, keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ( + (mask_logits_threshold - scores) / factor) > (2 * jitter_eps) + + # apply mask + masked_gates = scores.masked_fill(mask_logits_threshold, float("-inf")) + selected_experts = max_ind + + # compute scores for gradients + masked_gates = torch.softmax(masked_gates, dim=-1) + multiplier_o = masked_gates.gather(dim=-1, index=selected_experts) + + multiplier = multiplier_o + + # masked out first expert + masked_scores = torch.scatter( + scores, + -1, + selected_experts, + float("-inf"), + ) + with torch.no_grad(): + # compute mask for sparsity + mask_logits_threshold, max_ind = masked_scores.max(dim=-1, + keepdim=True) + factor = scores.abs().clamp(min=mask_logits_threshold) + mask_logits_threshold = ( + (mask_logits_threshold - scores) / factor) > (2 * jitter_eps) + + # apply mask + masked_gates_top2 = masked_scores.masked_fill(mask_logits_threshold, + float("-inf")) + selected_experts_top2 = max_ind + # compute scores for gradients + masked_gates_top2 = torch.softmax(masked_gates_top2, dim=-1) + multiplier_top2 = masked_gates_top2.gather(dim=-1, + index=selected_experts_top2) + + multiplier = torch.concat((multiplier, multiplier_top2), dim=-1) + selected_experts = torch.concat((selected_experts, selected_experts_top2), + dim=-1) + + return ( + multiplier, + selected_experts, + ) + + +def phimoe_routing_function( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert topk == 2, "Only top-2 routing is supported" + assert renormalize is False, "Renormalization is not supported" + + topk_weights, topk_ids = sparsemixer(gating_output) + return topk_weights, topk_ids + + +class PhiMoE(nn.Module): + """A tensor-parallel MoE implementation for PhiMoE that shards each expert + across all ranks. + + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__( + self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + ): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear( + hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + ) + + self.experts = FusedMoE( + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=False, + quant_config=quant_config, + tp_size=tp_size, + custom_routing_function=phimoe_routing_function) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. + orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class PhiMoEAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[dict] = None, + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=None, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=True, + quant_config=None, + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + rope_scaling=self.rope_scaling, + ) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class PhiMoEDecoderLayer(nn.Module): + + def __init__( + self, + config: PhiMoEConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = PhiMoEAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=config.rope_scaling, + ) + self.block_sparse_moe = PhiMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + ) + self.input_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: + residual = hidden_states + + # Self Attention + hidden_states = self.input_layernorm(hidden_states) + + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = hidden_states + residual + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + + hidden_states = hidden_states + residual + return hidden_states, residual + + +class PhiMoEModel(nn.Module): + + def __init__( + self, + config: PhiMoEConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + lora_vocab = ((lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0) + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.layers = nn.ModuleList([ + PhiMoEDecoderLayer(config, cache_config, quant_config=quant_config) + for _ in range(config.num_hidden_layers) + ]) + self.norm = nn.LayerNorm(config.hidden_size, + eps=config.rms_norm_eps, + elementwise_affine=True) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + hidden_states = self.embed_tokens(input_ids) + residual = None + for i in range(len(self.layers)): + layer = self.layers[i] + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i], attn_metadata, + residual) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class PhiMoEForCausalLM(nn.Module, SupportsLoRA): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: PhiMoEConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = PhiMoEModel(config, + cache_config, + quant_config, + lora_config=lora_config) + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size), + quant_config=None, + bias=True, + ) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata) + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + + expert_params_mapping = FusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="w1", + ckpt_down_proj_name="w2", + ckpt_up_proj_name="w3", + num_experts=self.config.num_local_experts) + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + weight_name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index b7d017d5f3ea6..a726ec10984c0 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -4,36 +4,402 @@ # Copyright (c) Alibaba Cloud. # LICENSE: https://huggingface.co/Qwen/Qwen-7B/blob/main/LICENSE """Inference-only QWen model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +import math +import re +from array import array +from functools import partial +from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping, + Optional, Tuple, TypedDict, Union) + +import numpy as np import torch +from PIL import Image from torch import nn +from torchvision import transforms +from torchvision.transforms import InterpolationMode from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata -from vllm.config import CacheConfig +from vllm.config import CacheConfig, MultiModalConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size -from vllm.model_executor.layers.activation import SiluAndMul +from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs +from vllm.logger import init_logger +from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput -from vllm.utils import print_warning_once +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.utils import cached_get_tokenizer +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) + +from .utils import flatten_bn, is_pp_missing_parameter, make_layers + +logger = init_logger(__name__) + +# NOTE: Qwen models have a few other special tags, e.g., ref, bbox, quad; +# for the time being, these tags are not considered as special at encoding +# time. This may change as VLLMs multimodal API changes in the future. +IMG_START = "" +IMG_END = "" +IMG_PAD = "" +# Image context is fixed at 256 for all images +MAX_QWEN_IMG_TOKENS = 256 +# Image normalization params +CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) +CLIP_STD = (0.26862954, 0.26130258, 0.27577711) + + +class QwenImagePixelInputs(TypedDict): + type: Literal["pixel_values"] + data: torch.Tensor + """ + Shape: `(batch_size * num_images, 3, image_size, image_size)` + + Note that image_size is the value in the vision config to which we resize + the image to in the normalization transform. Currently multi-image support + can only be leveraged by passing image embeddings directly. + """ + + +class QwenImageEmbeddingInputs(TypedDict): + type: Literal["image_embeds"] + data: torch.Tensor + """Shape: `(batch_size * num_images, 256, hidden_size)` + + `hidden_size` must match the hidden size of the language model backbone + and is stored in the visual config of the model if we have one. + """ + + +QwenImageInputs = Union[QwenImagePixelInputs, QwenImageEmbeddingInputs] + + +class VisualAttention(nn.Module): + """self-attention layer class. + Self-attention layer takes input with size [s, b, h] + and returns output of the same size. + """ + + def __init__( + self, + embed_dim: int, + num_heads: int, + bias: bool = True, + kdim: Optional[int] = None, + vdim: Optional[int] = None, + ): + super().__init__() + self.embed_dim = embed_dim + self.kdim = kdim if kdim is not None else embed_dim + self.vdim = vdim if vdim is not None else embed_dim + self._qkv_same_embed_dim = self.kdim == embed_dim \ + and self.vdim == embed_dim + + self.num_heads = num_heads + + # Per attention head and per partition values. + assert embed_dim % num_heads == 0 + self.hidden_size_per_attention_head = embed_dim // num_heads + self.num_attention_heads_per_partition = num_heads + self.hidden_size_per_partition = embed_dim + + # Strided linear layer. + assert self._qkv_same_embed_dim, \ + 'Visual Attention implementation only supports self-attention' + self.in_proj = nn.Linear(embed_dim, 3 * embed_dim) + self.out_proj = nn.Linear(embed_dim, embed_dim) + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # query/key/value: [sq, b, h] + sq, b, _ = x.size() + mixed_x_layer = self.in_proj(x) + + # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] + new_tensor_shape = mixed_x_layer.size()[:-1] + \ + (self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head) + mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) + + # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] + query_layer, key_layer, value_layer = mixed_x_layer.split( + self.hidden_size_per_attention_head, dim=-1) + + # [sq, b, np, hn] -> [sq, b * np, hn] + query_layer = query_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + # [sk, b, np, hn] -> [sk, b * np, hn] + key_layer = key_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + q_scaled = query_layer / self.norm_factor + if attn_mask is not None: + attention_probs = torch.baddbmm(attn_mask, q_scaled, + key_layer.transpose(-2, -1)) + else: + attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) + attention_probs = attention_probs.softmax(dim=-1) + + value_layer = value_layer.view( + sq, b * self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head).transpose(0, 1) + + # matmul: [b * np, sq, hn] + context_layer = torch.bmm(attention_probs, value_layer) -from .utils import is_pp_missing_parameter, make_layers + # change view [b, np, sq, hn] + context_layer = context_layer.view( + b, self.num_attention_heads_per_partition, sq, + self.hidden_size_per_attention_head) + + # [b, np, sq, hn] --> [sq, b, np, hn] + context_layer = context_layer.permute(2, 0, 1, 3).contiguous() + + # [sq, b, np, hn] --> [sq, b, hp] + new_context_layer_shape = context_layer.size()[:-2] + \ + (self.hidden_size_per_partition,) + context_layer = context_layer.view(*new_context_layer_shape) + + output = self.out_proj(context_layer) + + return output + + +class QwenVMLP(nn.Module): + """MLP for the visual component of the Qwen model.""" + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.c_fc = ColumnParallelLinear(hidden_size, + intermediate_size, + bias=True, + quant_config=quant_config) + self.act_fn = get_act_fn("gelu", quant_config, intermediate_size) + self.c_proj = RowParallelLinear( + intermediate_size, + hidden_size, + bias=True, + quant_config=quant_config, + ) + + def forward(self, x): + x, _ = self.c_fc(x) + x = self.act_fn(x) + x, _ = self.c_proj(x) + return x + + +class VisualAttentionBlock(nn.Module): + + def __init__( + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + norm_layer: Callable = nn.LayerNorm, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + + self.ln_1 = norm_layer(d_model) + self.ln_2 = norm_layer(d_model) + mlp_width = int(d_model * mlp_ratio) + self.attn = VisualAttention(d_model, n_head) + self.mlp = QwenVMLP( + hidden_size=d_model, + intermediate_size=mlp_width, + quant_config=quant_config, + ) + + def attention( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None + return self.attn(x, attn_mask=attn_mask) + + def forward( + self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + x = x + self.attention(self.ln_1(x), attn_mask=attn_mask) + x = x + self.mlp(self.ln_2(x)) + return x + + +class TransformerBlock(nn.Module): + + def __init__( + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + norm_layer: Callable = nn.LayerNorm, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__() + self.width = width + self.layers = layers + + self.resblocks = nn.ModuleList([ + VisualAttentionBlock(width, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config) + for _ in range(layers) + ]) + + def get_cast_dtype(self) -> torch.dtype: + return self.resblocks[0].mlp.c_fc.weight.dtype + + def get_cast_device(self) -> torch.device: + return self.resblocks[0].mlp.c_fc.weight.device + + def forward(self, + x: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + for r in self.resblocks: + x = r(x, attn_mask=attn_mask) + return x + + +class VisionTransformer(nn.Module): + + def __init__(self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + image_start_id: int = 151857, + quant_config: Optional[QuantizationConfig] = None, + **kwargs): + super().__init__() + image_height, image_width = self.image_size = (image_size, image_size) + patch_height, patch_width = self.patch_size = (patch_size, patch_size) + self.grid_size = (image_height // patch_height, + image_width // patch_width) + self.output_dim = output_dim + self.conv1 = nn.Conv2d(in_channels=3, + out_channels=width, + kernel_size=patch_size, + stride=patch_size, + bias=False) + + # class embeddings and positional embeddings + scale = width**-0.5 + self.positional_embedding = nn.Parameter(scale * + torch.randn(256, width)) + + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.ln_pre = norm_layer(width) + self.transformer = TransformerBlock(width, + layers, + heads, + mlp_ratio, + norm_layer=norm_layer, + quant_config=quant_config) + + self.attn_pool = Resampler2( + grid_size=int(math.sqrt(n_queries)), + embed_dim=output_dim, + num_heads=output_dim // 128, + kv_dim=width, + norm_layer=norm_layer, + adaptive=False, + do_post_projection=False, + ).to( + device=self.positional_embedding.device, + dtype=self.positional_embedding.dtype, + ) + + self.ln_post = norm_layer(output_dim) + self.proj = nn.Parameter( + (output_dim**-0.5) * torch.randn(output_dim, output_dim)) + self.image_start_id = image_start_id + self.image_end_id = image_start_id + 1 + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.to( + dtype=self.transformer.get_cast_dtype(), + device=self.transformer.get_cast_device(), + ) + + # to patches + x = self.conv1(x) # shape = [*, width, grid, grid] + x = x.reshape(x.shape[0], x.shape[1], + -1) # shape = [*, width, grid ** 2] + x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] + + x = x + get_abs_pos(self.positional_embedding, int(math.sqrt( + x.size(1)))) + + x = self.ln_pre(x) + + x = x.permute(1, 0, 2) # NLD -> LND + x = self.transformer(x) + x = x.permute(1, 0, 2) # LND -> NLD + + x = self.attn_pool(x) + x = self.ln_post(x) + x = x @ self.proj + + return x + + def get_image_positions(self, + input_ids: torch.Tensor) -> Optional[torch.Tensor]: + """Given the input IDs, extracts start/stop points corresponding to + images. + + args: + Returns: + Optional torch tensor corresponding to start/stop pairs of images. + """ + if torch.any(input_ids == self.image_start_id): + bos_pos = torch.where(input_ids == self.image_start_id) + eos_pos = torch.where(input_ids == self.image_end_id) + return torch.stack((bos_pos[0], eos_pos[0]), dim=1) + return None class QWenMLP(nn.Module): + """MLP for the language component of the Qwen model, which contains a + MergedColumnParallelLinear merging 2 outputs via silu activation.""" def __init__( self, @@ -56,7 +422,7 @@ def __init__( "Only silu is supported for now.") self.act_fn = SiluAndMul() - def forward(self, x): + def forward(self, x: torch.Tensor) -> torch.Tensor: gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.c_proj(x) @@ -203,6 +569,9 @@ def __init__( lambda prefix: QWenBlock(config, cache_config, quant_config), prefix=f"{prefix}.h") self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.visual = VisionTransformer(**config.visual, + quant_config=quant_config) if hasattr( + config, "visual") else None def forward( self, @@ -211,9 +580,33 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], + pixel_values: Optional[QwenImageInputs], ) -> torch.Tensor: + img_pos = None + # If pixel / visual embeddings are provided, this is a visual model + if pixel_values is not None and self.visual is not None: + if pixel_values["type"] != "image_embeds": + image_embeds = self.visual(pixel_values["data"]) + else: + image_embeds = pixel_values["data"] + + # features should be of shape (# images, 256, hidden_dim) + img_pos = self.visual.get_image_positions(input_ids) + if isinstance( + img_pos, + np.ndarray) and img_pos.shape[0] != image_embeds.shape[0]: + raise ValueError( + f"Number of placeholders: {img_pos.shape[0]} " + f"does not match number of images {image_embeds.shape[0]}." + ) + if get_pp_group().is_first_rank: hidden_states = self.wte(input_ids) + # Merge the image embeddings into the hidden states if actually have + # visual features and the corresponding image tokens + if img_pos is not None: + for idx, (img_bos, img_eos) in enumerate(img_pos): + hidden_states[img_bos + 1:img_eos] = image_embeds[idx] residual = None else: assert intermediate_tensors is not None @@ -237,16 +630,241 @@ def forward( return hidden_states -class QWenLMHeadModel(nn.Module): +def get_image_text(image_num: int, padding: bool) -> str: + """Retrieves a placeholder text that when tokenized, will be expanded with + image pads. + + Args: + image_num: The number of the image that we want a text prompt for. + Images should be indexed starting at 1. + padding: Whether or not padding should be manually added. + + Returns: + Text placeholder prompt for the image being considered. + """ + image_start = f"Picture {image_num}: {IMG_START}" + image_end = f"{IMG_END}\n" + if not padding: + return f"{image_start}{image_end}" + return f"{image_start}{MAX_QWEN_IMG_TOKENS * IMG_PAD}{image_end}" + + +def input_processor_for_qwen(ctx: InputContext, + llm_inputs: LLMInputs) -> LLMInputs: + """Processes the inputs, which may or may not be multimodal. + Multimodal inputs will only be processed if the model has a "visual" + component in its model config, otherwise they'll be ignored. + + Args: + ctx: Context of the loaded model. + llm_inputs: LLM inputs which may have a multi_modal_data attribute. + + Returns: + If the model is language only or not multimodal inputs were provided, + returns llm_inputs unmodified. Otherwise, processes the multimodal + images / image embeddings and adds the fixed-length image placeholders. + """ + multi_modal_data = llm_inputs.get("multi_modal_data") + + # Only process images if we have multimodal data and a visual config + hf_config = ctx.get_hf_config() + if (multi_modal_data is None or "image" not in multi_modal_data + or not hasattr(hf_config, "visual")): + return llm_inputs + + prompt = llm_inputs.get("prompt") + prompt_token_ids = llm_inputs["prompt_token_ids"] + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + image_data = multi_modal_data["image"] + if isinstance(image_data, torch.Tensor): + num_dims = len(image_data.shape) + if num_dims < 2 or num_dims > 3: + raise ValueError( + f"Expected img embeds to be have 3 dimensions, got {num_dims}") + num_images = 1 if num_dims == 2 else image_data.shape[0] + else: + # TODO - handle multiple image inputs once the API is solidified + num_images = 1 + + if prompt is None: + prompt = tokenizer.decode(prompt_token_ids) + + # Drops anything between / tags; encoding with the tokenizer + # will automatically add the image pads for the context. + new_prompt, num_matched_images = re.subn( + r"(Picture \d*: ).*?(<\/img>\n)", + r"\1\2", + prompt, + ) + + if num_matched_images != num_images: + logger.warning( + "Number of matched image placeholders %s doesn't match the number " + "of expected images %s; check your placeholder formatting.", + num_matched_images, num_images) + + new_prompt_token_ids = tokenizer.encode(new_prompt) + + return LLMInputs(prompt=new_prompt, + prompt_token_ids=new_prompt_token_ids, + multi_modal_data=multi_modal_data) + + +def input_mapper_for_qwen(ctx: InputContext, data: object) -> MultiModalInputs: + """Maps the input data to its MultiModalInputs (if any). + + Args: + ctx: Context of the loaded model. + data: data potentially containing image/image embeddings to be mapped + to pixel_values in .forward() for a visual QWenLMHeadModel model. + + Returns: + MultiModalInputs containing the stacked normalized images tensor or + image embeddings. + """ + # Early exit if we have provided an image to a language only Qwen model + hf_config = ctx.get_hf_config() + if not hasattr(hf_config, "visual"): + logger.warning( + "Images were provided but this model has no visual config; " + "multimodal inputs will not be forwarded to the model.") + return MultiModalInputs() + + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + + image_pair_tok = tokenizer.encode(IMG_START + IMG_END, + add_special_tokens=False, + return_tensors="pt").squeeze() + image_start_id = image_pair_tok[0] + image_end_id = image_pair_tok[-1] + if (image_start_id + 1) != image_end_id: + raise ValueError( + f"Found image end ID {image_end_id}, but expected {IMG_START} + 1") + if len(image_pair_tok) != (MAX_QWEN_IMG_TOKENS + 2): + raise ValueError( + f"Expected image context length of {MAX_QWEN_IMG_TOKENS}, " + f"but got {image_pair_tok - 2}") + + hf_config = ctx.get_hf_config() + image_size = hf_config.visual["image_size"] + img_emb_size = hf_config.visual["output_dim"] + + if isinstance(data, torch.Tensor): + # It's expected that our values have already been processed + # by the visual transformer; shape is expected to be: + # (# images, 256, hidden_size) + if len(data.shape) == 2: + # Assume only one image embed was provided; unsqueeze the extra dim + data = data.unsqueeze(0) + if len(data.shape) != 3 or data.shape[ + 1] != MAX_QWEN_IMG_TOKENS or data.shape[2] != img_emb_size: + raise ValueError( + "Expected image embeds to be a tensor of shape" + f"[# images, {MAX_QWEN_IMG_TOKENS}, {img_emb_size}], but " + f"received shape [{data.shape}]") + pixel_values = data + + else: + transform = build_normalization_transform(image_size) + # TODO - handle multiple image inputs once the API is solidified + transformed_images = [transform(data)] + pixel_values = torch.stack(transformed_images, dim=0) + return MultiModalInputs({"pixel_values": pixel_values}) + + +def build_normalization_transform(image_size: int) -> transforms.Compose: + """Builds a normalization transform which can be applied to one or + more input images from which we want to extract visual features. + + Args: + image_size: size of the image to be processed for visual embeddings. + + Returns: + Callable transform for normalizing and resizing one RGB image. + """ + return transforms.Compose([ + transforms.Resize((image_size, image_size), + interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=CLIP_MEAN, std=CLIP_STD), + ]) + + +def dummy_data_for_qwen( + ctx: InputContext, + seq_len: int, + mm_counts: Mapping[str, int], +) -> Tuple[SequenceData, Optional[Dict]]: + """Build dummy data for warming up Qwen models; this will only contain text + matching the defaults for VLLM unless the model has a visual config. + + Args: + ctx: Context of the loaded model. + seq_len: Number of tokens in the text sequence. + mm_counts: multimodal data counts. + + Returns: + Tuple containing sequential and multimodal data. + """ + hf_config = ctx.get_hf_config() + + # The presence of a visual config indicates this is a multimodal model. + # If we don't have it, the model is considered an LLM for warmup purposes. + if not hasattr(hf_config, "visual"): + seq_data = SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, [0] * seq_len)) + mm_data = None + return seq_data, mm_data + + # We have a visual component - use images to warm up + num_images = mm_counts["image"] + model_config = ctx.model_config + tokenizer = cached_get_tokenizer(model_config.tokenizer, + trust_remote_code=True) + + # Build the image prompts with no imgpads; the tokenizer will add img pads + image_prompt = ''.join( + [get_image_text(idx, False) for idx in range(1, num_images + 1)]) + toks = tokenizer.encode(image_prompt, add_special_tokens=False) + + # Make sure we actually get the fixed context size per tok padding + num_pads = toks.count(tokenizer.encode(IMG_PAD)[0]) + if num_pads != (num_images * MAX_QWEN_IMG_TOKENS): + raise ValueError( + f"Tokenized dummy data should encode {MAX_QWEN_IMG_TOKENS} pads" + f" per image, but got {num_pads} pads for {num_images} image(s)" + " in total. Are you using a qwen tokenizer?") + + # Ensure the number of tokens is at minimum the sequence length provided + if len(toks) < seq_len: + toks += [0] * (seq_len - len(toks)) + + # Build the input images; width/height doesn't actually matter here since + # the data will get resized and the # of tokens per image is constant + image = Image.new("RGB", (224, 224), color=0) + mm_data = {"image": image if num_images == 1 else [image] * num_images} + return SequenceData(array(VLLM_TOKEN_ID_ARRAY_TYPE, toks)), mm_data + + +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_qwen) +@MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) +@INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) +@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) +class QWenLMHeadModel(nn.Module, SupportsMultiModal): def __init__( self, config: PretrainedConfig, + multimodal_config: MultiModalConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config + self.multimodal_config = multimodal_config self.quant_config = quant_config self.transformer = QWenModel(config, cache_config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, @@ -257,16 +875,47 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + def _get_image_input_type( + self, + pixel_values: Optional[torch.Tensor]) -> Optional[QwenImageInputs]: + """Determines if the provided pixel_values are normalized pixel values + or image embeddings. + + Args: + pixel_values: Optional data to processed into visual embeddings. + + Returns: + None of the QwenImageInputs type used to determine whether or not + the visual transformer needs to process the pixel_values. + """ + if pixel_values is not None and self.transformer.visual is not None: + pixel_values = flatten_bn(pixel_values) + if len(pixel_values.shape) == 3 and pixel_values.shape[ + 1] == MAX_QWEN_IMG_TOKENS and pixel_values.shape[ + 2] == self.config.visual["output_dim"]: + return QwenImageEmbeddingInputs( + type="image_embeds", + data=pixel_values, + ) + else: + # If we have the wrong shape, assume we still need to process + return QwenImagePixelInputs( + type="pixel_values", + data=pixel_values, + ) + return None + + def forward(self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor: + pixel_values = self._get_image_input_type(pixel_values) hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + pixel_values) return hidden_states def make_empty_intermediate_tensors( @@ -328,15 +977,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue - # Skip loading visual weights to support Qwen-VL models - # in cases with text-only inputs - # TODO: add support for Qwen-VL - if (name not in params_dict - and name.startswith("transformer.visual.")): - print_warning_once( - "Only text inputs are allowed. Images won't be handled " - "until Qwen-VL models are fully supported.") - continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index b95987c16ebca..a64e08c422bc3 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -40,13 +40,13 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 6f838947fbf27..56129515ca8d1 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -45,12 +45,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once from .utils import is_pp_missing_parameter, make_layers diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 073f60bb3a056..13d09e4cd4c23 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -9,12 +9,10 @@ from PIL import Image from torch import nn from transformers import SiglipVisionConfig -from transformers.models.siglip.modeling_siglip import SiglipAttention -from vllm_flash_attn import flash_attn_func -from xformers.ops import memory_efficient_attention +from transformers.models.siglip.modeling_siglip import SiglipSdpaAttention from vllm.config import ModelConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -28,6 +26,12 @@ repeat_and_pad_placeholder_tokens) from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +try: + from xformers import ops as xops + USE_XFORMERS_OPS = True +except ImportError: + USE_XFORMERS_OPS = False + def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int: # Since interpolation is applied, the image size need not be divisible @@ -106,7 +110,7 @@ def input_processor_for_siglip( if isinstance(image_data, Image.Image): image_feature_size = get_siglip_image_feature_size(hf_config) elif isinstance(image_data, torch.Tensor): - image_feature_size = image_data.shape[0] + num_images, image_feature_size, hidden_size = image_data.shape else: raise TypeError(f"Invalid image type: {type(image_data)}") else: @@ -221,9 +225,7 @@ def forward(self, return embeddings -# NOTE: Not used - kept for later when we TP the ViT -# TODO(ChristopherCho): Implement TP version of Attention -class SiglipTPAttention(nn.Module): +class SiglipParallelAttention(nn.Module): def __init__( self, @@ -233,38 +235,30 @@ def __init__( super().__init__() self.config = config self.embed_dim = config.hidden_size - - tp_size = get_tensor_model_parallel_world_size() - self.total_num_heads = config.num_attention_heads - if self.total_num_heads % tp_size != 0: - raise ValueError( - f"Number of attention heads ({self.total_num_heads}) " - "must be divisible by the tensor model parallel size" - f" ({tp_size}).") - - self.num_heads = self.total_num_heads // tp_size - self.head_dim = self.embed_dim // self.total_num_heads - if self.head_dim * self.total_num_heads != self.embed_dim: + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: raise ValueError(f"embed_dim must be divisible by num_heads (got " "`embed_dim`: {self.embed_dim} and `num_heads`:" f" {self.num_heads}).") - self.qkv_size = self.num_heads * self.head_dim + self.scale = self.head_dim**-0.5 self.dropout = config.attention_dropout - self.qkv_proj = QKVParallelLinear( hidden_size=self.embed_dim, head_size=self.head_dim, - total_num_heads=self.total_num_heads, + total_num_heads=self.num_heads, quant_config=quant_config, ) + self.out_proj = RowParallelLinear( input_size=self.embed_dim, output_size=self.embed_dim, quant_config=quant_config, ) - self.attn_fn = self._basic_attention_forward + self.tp_size = get_tensor_model_parallel_world_size() + self.num_heads_per_partition = divide(self.num_heads, self.tp_size) def forward( self, @@ -274,161 +268,27 @@ def forward( batch_size, q_len, _ = hidden_states.size() qkv_states, _ = self.qkv_proj(hidden_states) - query_states, key_states, value_states = qkv_states.split( - [self.qkv_size] * 3, dim=-1) - - attn_output = self.attn_fn( - q=query_states, - k=key_states, - v=value_states, - batch_size=batch_size, - q_len=q_len, - ) - - attn_output, _ = self.out_proj(attn_output) - return attn_output - - def _basic_attention_forward(self, q, k, v, batch_size, q_len): - q = q.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - k = k.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - v = v.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - - k_v_seq_len = k.shape[-2] - attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale - - if attn_weights.size() != ( - batch_size, - self.num_heads, - q_len, - k_v_seq_len, - ): - raise ValueError( - "Attention weights should be of size " - f"{(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" - f" {attn_weights.size()}") - - # upcast attention to fp32 - attn_weights = nn.functional.softmax(attn_weights, - dim=-1, - dtype=torch.float32).to(q.dtype) - attn_weights = nn.functional.dropout(attn_weights, - p=self.dropout, - training=self.training) - attn_output = torch.matmul(attn_weights, v) - - if attn_output.size() != ( - batch_size, - self.num_heads, - q_len, - self.head_dim, - ): - raise ValueError( - "`attn_output` should be of size " - f"{(batch_size, self.num_heads, q_len, self.head_dim)}, but is" - f" {attn_output.size()}") - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) - - return attn_output - - -# NOTE: Not used - kept for later when we TP the ViT -# TODO(ChristopherCho): flash_attn_func is not working properly. -# It constantly throws a CUDA error. -class SiglipFlashAttention2(SiglipTPAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.attn_fn = self._flash_attention_forward - - # Ported from https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/siglip/modeling_siglip.py#L449 - # and https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/modeling_flash_attention_utils.py#L133 - def _flash_attention_forward(self, q, k, v, batch_size, q_len, *args, - **kwargs): - """Implements the multihead softmax attention. - Arguments - --------- - q, k, v: The tensor containing the - query, key, and value. (B, S, H, D) - """ - - q = q.view(batch_size, q_len, self.num_heads, self.head_dim) - k = k.view(batch_size, q_len, self.num_heads, self.head_dim) - v = v.view(batch_size, q_len, self.num_heads, self.head_dim) - - attn_output = flash_attn_func( - q, - k, - v, - dropout_p=self.dropout, - causal=False, - ) - - attn_output = attn_output.reshape(batch_size, q_len, - self.embed_dim).contiguous() - - return attn_output - - -# NOTE: Not used - kept for later when we TP the ViT -class SiglipSdpaAttention(SiglipTPAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.is_causal = False - self.attn_fn = self._sdpa_attention_forward - - def _sdpa_attention_forward(self, q, k, v, batch_size, q_len): - q = q.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - k = k.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) - v = v.view(batch_size, q_len, self.num_heads, - self.head_dim).transpose(1, 2) + query_states, key_states, value_states = qkv_states.chunk(3, dim=-1) - attn_output = torch.nn.functional.scaled_dot_product_attention( - q, k, v, dropout_p=self.dropout, is_causal=False, scale=self.scale) + query_states = query_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + key_states = key_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) + value_states = value_states.view(batch_size, q_len, + self.num_heads_per_partition, + self.head_dim) - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.view(batch_size, q_len, self.embed_dim) + out = xops.memory_efficient_attention_forward(query_states, + key_states, + value_states, + p=self.dropout, + scale=self.scale) + out = out.view(batch_size, q_len, -1) + attn_output, _ = self.out_proj(out) - return attn_output - - -# NOTE: Not used - kept for later when we TP the ViT -class SiglipxFormersAttention(SiglipTPAttention): - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.attn_fn = self._xformers_attention_forward - - def _xformers_attention_forward(self, q, k, v, batch_size, q_len): - q = q.view(batch_size, q_len, self.num_heads, self.head_dim) - k = k.view(batch_size, q_len, self.num_heads, self.head_dim) - v = v.view(batch_size, q_len, self.num_heads, self.head_dim) - - attn_output = memory_efficient_attention(q, - k, - v, - p=0.0, - scale=self.scale) - attn_output = attn_output.reshape(batch_size, q_len, - self.embed_dim).contiguous() - - return attn_output - - -# NOTE: Not used - kept for later when we TP the ViT -SIGLIP_ATTENTION_CLASSES = { - "eager": SiglipTPAttention, - "flash_attention_2": SiglipFlashAttention2, - "sdpa": SiglipSdpaAttention, - "xformers": SiglipxFormersAttention, -} + return attn_output, None class SiglipMLP(nn.Module): @@ -473,8 +333,14 @@ def __init__( super().__init__() self.embed_dim = config.hidden_size - # TODO(ChristopherCho): use TP'ed Attention block - self.self_attn = SiglipAttention(config) + num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + if USE_XFORMERS_OPS and num_heads % tp_size == 0: + self.self_attn = SiglipParallelAttention(config, + quant_config=quant_config) + else: + self.self_attn = SiglipSdpaAttention(config) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) self.mlp = SiglipMLP( @@ -577,14 +443,27 @@ def __init__( self.config = config embed_dim = config.hidden_size + if (num_hidden_layers_override is None + or num_hidden_layers_override == config.num_hidden_layers): + self.need_post_layernorm = True + elif num_hidden_layers_override > config.num_hidden_layers: + raise ValueError( + "num_hidden_layers_override cannot be greater than " + "num_hidden_layers") + else: + self.need_post_layernorm = False + self.embeddings = SiglipVisionEmbeddings(config) self.encoder = SiglipEncoder( config, quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, ) - self.post_layernorm = nn.LayerNorm(embed_dim, - eps=config.layer_norm_eps) + if self.need_post_layernorm: + self.post_layernorm = nn.LayerNorm(embed_dim, + eps=config.layer_norm_eps) + else: + self.post_layernorm = nn.Identity() self.use_head = (True if not hasattr(config, "vision_use_head") else config.vision_use_head) if self.use_head: @@ -604,7 +483,6 @@ def forward( encoder_outputs = self.encoder(inputs_embeds=hidden_states) last_hidden_state = self.post_layernorm(encoder_outputs) - # TODO: add this back when pooled_output is used in inference # if self.use_head: # pooled_output = self.head(last_hidden_state) @@ -623,12 +501,20 @@ def __init__( num_hidden_layers_override: Optional[int] = None, ): super().__init__() + num_heads = config.num_attention_heads + tp_size = get_tensor_model_parallel_world_size() + self.shard_weight = USE_XFORMERS_OPS and num_heads % tp_size == 0 + self.vision_model = SiglipVisionTransformer( config, quant_config, num_hidden_layers_override=num_hidden_layers_override, ) + @property + def need_post_layernorm(self): + return self.vision_model.need_post_layernorm + def get_input_embeddings(self) -> nn.Module: return self.vision_model.embeddings.patch_embedding @@ -643,17 +529,37 @@ def forward( ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] if self.shard_weight else [] params_dict = dict(self.named_parameters()) layer_count = len(self.vision_model.encoder.layers) for name, loaded_weight in weights: + # post_layernorm is optional in SiglipVisionModel + if ("vision_model.post_layernorm" in name + and not self.need_post_layernorm): + continue + # omit layers when num_hidden_layers_override is set if "vision_model.encoder.layers." in name: layer_idx = int(name.split(".")[3]) if layer_idx >= layer_count: continue - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index decbf89d27c7c..6236426dcd4e1 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -36,12 +36,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class StablelmMLP(nn.Module): diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index d1b1d210b727c..d3a3a83c8437f 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -35,12 +35,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors class Starcoder2Attention(nn.Module): diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index c81c2fd114eb8..416fabda831a2 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -8,7 +8,6 @@ from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union, cast) -import librosa import numpy as np import torch import torch.utils.checkpoint @@ -27,17 +26,18 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.interfaces import SupportsMultiModal -from vllm.model_executor.models.utils import (filter_weights, +from vllm.model_executor.models.utils import (filter_weights, flatten_bn, init_vllm_registered_model, merge_multimodal_embeddings) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.base import MultiModalInputs, NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SamplerOutput, SequenceData +from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData from vllm.transformers_utils.configs.ultravox import UltravoxConfig _AUDIO_PLACEHOLDER_TOKEN = 128002 @@ -48,13 +48,14 @@ class UltravoxAudioFeatureInputs(TypedDict): type: Literal["audio_features"] - data: Union[torch.Tensor, List[torch.Tensor]] - """Shape: `(batch_size, 80, M)""" + data: NestedTensors + """Shape: `(batch_size, num_audios, 80, M)""" class UltravoxAudioEmbeddingInputs(TypedDict): type: Literal["audio_embeds"] - data: torch.Tensor + data: NestedTensors + """Shape: `(batch_size, num_audios, audio_feature_size, hidden_size)""" UltravoxAudioInputs = Union[UltravoxAudioFeatureInputs, @@ -85,27 +86,41 @@ def dummy_data_for_ultravox( audio_count = mm_counts["audio"] - audio_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [ - _AUDIO_PLACEHOLDER_TOKEN - ]) * get_ultravox_max_audio_tokens(ctx) * audio_count + audio_placeholder = array( + VLLM_TOKEN_ID_ARRAY_TYPE, + [_AUDIO_PLACEHOLDER_TOKEN]) * get_ultravox_max_audio_tokens(ctx) + + # Add a separator between each chunk. + audio_token_ids = (audio_placeholder + + array(VLLM_TOKEN_ID_ARRAY_TYPE, [0])) * audio_count other_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE, [0]) * (seq_len - len(audio_token_ids)) audio_and_sr = (np.array([0.0] * feature_extractor.chunk_length), 1) - mm_dict = { - "audio": - audio_and_sr if audio_count == 1 else [audio_and_sr] * audio_count - } + mm_dict = {"audio": [audio_and_sr] * audio_count} return (SequenceData(audio_token_ids + other_token_ids), mm_dict) def input_mapper_for_ultravox(ctx: InputContext, data: object): - if isinstance(data, tuple): - (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], data) + if not isinstance(data, list): + data = [data] + + audio_features = [] + for audio_input in data: + if not isinstance(audio_input, tuple): + raise NotImplementedError( + f"Unsupported data type: {type(audio_input)}") + + (audio, sr) = cast(Tuple[np.ndarray, Union[float, int]], audio_input) feature_extractor = whisper_feature_extractor(ctx) if sr != feature_extractor.sampling_rate: + try: + import librosa + except ImportError: + raise ImportError( + "Please install vllm[audio] for audio support.") from None audio = librosa.resample(audio, orig_sr=sr, target_sr=feature_extractor.sampling_rate) @@ -116,15 +131,14 @@ def input_mapper_for_ultravox(ctx: InputContext, data: object): # Not enough audio; pad it. audio = np.pad(audio, (0, minimum_audio_length - len(audio))) - return MultiModalInputs({ - "audio_features": - feature_extractor(audio, - sampling_rate=sr, - padding="longest", - return_tensors="pt")["input_features"] - }) + single_audio_features = feature_extractor( + audio, sampling_rate=sr, padding="longest", + return_tensors="pt")["input_features"] - raise NotImplementedError(f"Unsupported data type: {type(data)}") + # Remove the batch dimension because we're wrapping it in a list. + audio_features.append(single_audio_features.squeeze(0)) + + return MultiModalInputs({"audio_features": audio_features}) def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): @@ -133,25 +147,31 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): return llm_inputs feature_extractor = whisper_feature_extractor(ctx) - audio_data, sample_rate = multi_modal_data["audio"] - - audio_length = audio_data.shape[0] - if sample_rate != feature_extractor.sampling_rate: - # Account for resampling. - adjustment = feature_extractor.sampling_rate / sample_rate - audio_length = math.ceil(adjustment * audio_length) - - feature_extractor_output_length = math.ceil( - (audio_length - - (feature_extractor.hop_length - 1)) / feature_extractor.hop_length) - - uv_config = ctx.get_hf_config(UltravoxConfig) - audio_num_tokens = min( - max( - 1, - math.ceil(feature_extractor_output_length / - (uv_config.stack_factor * 2))), - get_ultravox_max_audio_tokens(ctx)) + audios = multi_modal_data["audio"] + if not isinstance(audios, list): + audios = [audios] + + audio_token_counts = [] + for audio_data, sample_rate in audios: + audio_length = audio_data.shape[0] + if sample_rate != feature_extractor.sampling_rate: + # Account for resampling. + adjustment = feature_extractor.sampling_rate / sample_rate + audio_length = math.ceil(adjustment * audio_length) + + feature_extractor_output_length = math.ceil( + (audio_length - (feature_extractor.hop_length - 1)) / + feature_extractor.hop_length) + + uv_config = ctx.get_hf_config(UltravoxConfig) + audio_num_tokens = min( + max( + 1, + math.ceil(feature_extractor_output_length / + (uv_config.stack_factor * 2))), + get_ultravox_max_audio_tokens(ctx)) + audio_token_counts.append(audio_num_tokens) + tokenizer = cached_get_tokenizer(ctx.model_config.tokenizer) new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( @@ -159,7 +179,7 @@ def input_processor_for_ultravox(ctx: InputContext, llm_inputs: LLMInputs): llm_inputs.get("prompt"), llm_inputs["prompt_token_ids"], placeholder_token_id=_AUDIO_PLACEHOLDER_TOKEN, - repeat_count=audio_num_tokens, + repeat_count=audio_token_counts, ) # NOTE: Create a defensive copy of the original inputs @@ -333,45 +353,52 @@ def _parse_and_validate_audio_input( raise ValueError("Incorrect type of audio features. " f"Got type: {type(audio_features)}") - # Remove the N dimension until multiple audios are supported. - if isinstance(audio_features, torch.Tensor): - audio_features = audio_features.squeeze(1) - else: - audio_features = [t.squeeze(0) for t in audio_features] - return UltravoxAudioFeatureInputs(type="audio_features", data=audio_features) if audio_embeds is not None: - if not isinstance(audio_embeds, torch.Tensor): + if not isinstance(audio_embeds, (torch.Tensor, list)): raise ValueError("Incorrect type of audio embeds. " f"Got type: {type(audio_embeds)}") - # Remove the N dimension until multiple audios are supported. - audio_embeds = audio_embeds.squeeze(1) - return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) raise AssertionError("This line should be unreachable.") def _process_audio_input( - self, audio_input: UltravoxAudioInputs - ) -> Union[torch.Tensor, List[torch.Tensor]]: + self, audio_input: UltravoxAudioInputs) -> NestedTensors: if audio_input["type"] == "audio_embeds": return audio_input["data"] audio_features = audio_input["data"] - if isinstance(audio_features, list): - # TODO: Batch these through the encoder/projector instead of - # serializing them. - return [ - self._audio_features_to_embeddings( - features.unsqueeze(0)).squeeze(0) - for features in audio_features - ] - else: - return self._audio_features_to_embeddings(audio_features) + if isinstance(audio_features, torch.Tensor): + # Combine the B and N dimensions for the encoder/projector + flattened = flatten_bn(audio_features) + flattened_embeddings = self._audio_features_to_embeddings( + flattened) + + # Restore the original dimensions + embeddings = flattened_embeddings.unflatten( + 0, audio_features.shape[:2]) + return embeddings + + result = [] + # TODO: Batch heterogeneous tensors through the encoder/projector + for audio_features_item in audio_features: + if isinstance(audio_features_item, torch.Tensor): + result.append( + self._audio_features_to_embeddings(audio_features_item)) + else: + embeddings = [ + # Add a batch dimension to embed it, then remove it. + self._audio_features_to_embeddings(tensor.unsqueeze(0) + ).squeeze(0) + for tensor in audio_features_item + ] + result.append(embeddings) + + return result def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], @@ -388,7 +415,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, with the `input_ids`. Args: - input_features: A batch of audio inputs, [1, 80, M]. + audio_features: A batch of audio inputs [B, N, 80, M]. """ audio_input = self._parse_and_validate_audio_input(**kwargs) if audio_input is not None: diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 00026b7ebe2e1..8b80dda96db49 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,6 +1,6 @@ -from typing import Dict, Iterable, List, Optional, Protocol, Tuple +from typing import (Dict, Iterable, List, Literal, Optional, Protocol, Tuple, + Union, overload) -import numpy as np import torch import torch.nn as nn from torch.func import functional_call @@ -12,6 +12,7 @@ from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.models import ModelRegistry from vllm.multimodal.base import NestedTensors +from vllm.sequence import IntermediateTensors from vllm.utils import is_pin_memory_available @@ -55,14 +56,53 @@ def init_vllm_registered_model( ) +@overload +def flatten_bn(x: torch.Tensor) -> torch.Tensor: + ... + + +@overload +def flatten_bn(x: List[torch.Tensor]) -> List[torch.Tensor]: + ... + + +@overload +def flatten_bn( + x: Union[List[torch.Tensor], torch.Tensor], + *, + concat: Literal[True], +) -> torch.Tensor: + ... + + +def flatten_bn( + x: Union[List[torch.Tensor], torch.Tensor], + *, + concat: bool = False, +) -> Union[List[torch.Tensor], torch.Tensor]: + """ + Flatten the ``B`` and ``N`` dimensions of batched multimodal inputs. + + The input tensor should have shape ``(B, N, ...)```. + """ + if isinstance(x, torch.Tensor): + return x.flatten(0, 1) + + if concat: + return torch.cat(x) + + return [x_n for x_b in x for x_n in x_b] + + def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor: """ - Recursively concatenates NestedTensors along any heterogeneously sized - dimensions. + Recursively flattens and concatenates NestedTensors on all but the last + dimension. """ if isinstance(embeddings, torch.Tensor): - return embeddings + # Flatten all but the last dimension. + return embeddings.flatten(0, -2) return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings)) @@ -93,18 +133,17 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, This updates ``inputs_embeds`` in place. """ mask = (input_ids == placeholder_token_id) - num_expected_tokens = mask.sum() + num_expected_tokens = mask.sum().item() + assert isinstance(num_expected_tokens, int) flattened = _flatten_embeddings(multimodal_embeddings) - *dims, embed_dim = flattened.shape - num_multimodal_embeddings = np.prod(dims) - if num_multimodal_embeddings != num_expected_tokens: + if flattened.shape[0] != num_expected_tokens: expr = _embedding_count_expression(multimodal_embeddings) raise ValueError( - f"Attempted to assign {expr} = {num_multimodal_embeddings} " + f"Attempted to assign {expr} = {flattened.shape[0]} " f"multimodal tokens to {num_expected_tokens} placeholders") - inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim) + inputs_embeds[mask] = flattened return inputs_embeds @@ -241,3 +280,18 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: if name.startswith(missing_layer_name): return True return False + + +def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): + + def make_empty_intermediate_tensors( + batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + key: torch.zeros((batch_size, hidden_size), + dtype=dtype, + device=device) + for key in keys + }) + + return make_empty_intermediate_tensors diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index c0bafa9367e43..24cc3728f85e4 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -38,12 +38,12 @@ from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.sequence import IntermediateTensors from .interfaces import SupportsLoRA diff --git a/vllm/model_executor/parameter.py b/vllm/model_executor/parameter.py index 326b6ae8fee64..9ffb339ffeab3 100644 --- a/vllm/model_executor/parameter.py +++ b/vllm/model_executor/parameter.py @@ -1,3 +1,4 @@ +from fractions import Fraction from typing import Callable, Optional, Union import torch @@ -257,7 +258,7 @@ class PackedColumnParameter(_ColumnvLLMParameter): """ def __init__(self, - packed_factor: int, + packed_factor: Union[int, Fraction], packed_dim: int, marlin_tile_size: Optional[int] = None, **kwargs): @@ -298,7 +299,7 @@ class PackedvLLMParameter(ModelWeightParameter): """ def __init__(self, - packed_factor: int, + packed_factor: Union[int, Fraction], packed_dim: int, marlin_tile_size: Optional[int] = None, **kwargs): diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index f26e3292c264d..17ef9938d0572 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -18,7 +18,7 @@ logger = init_logger(__name__) -NestedTensors = Union[List["NestedTensors"], torch.Tensor] +NestedTensors = Union[List["NestedTensors"], List[torch.Tensor], torch.Tensor] """ Uses a list instead of a tensor if the dimensions of each element do not match. """ @@ -54,14 +54,14 @@ def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: return nested_tensors stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] - if is_list_of(stacked, list): - # Do not stack nested lists + if not is_list_of(stacked, torch.Tensor, check="all"): + # Only tensors (not lists) can be stacked. return stacked tensors_ = cast(List[torch.Tensor], stacked) if any(t.shape != tensors_[0].shape for t in tensors_): # The tensors have incompatible shapes and can't be stacked. - return stacked + return tensors_ return torch.stack(tensors_) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 989b2e1a814c9..b76b765bc677a 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -1,11 +1,9 @@ import base64 from functools import lru_cache from io import BytesIO -from typing import List, Optional, Tuple, TypeVar, Union +from typing import Any, List, Optional, Tuple, TypeVar, Union -import librosa import numpy as np -import soundfile from PIL import Image from vllm.connections import global_http_connection @@ -73,10 +71,22 @@ async def async_fetch_image(image_url: str, return image.convert(image_mode) +def try_import_audio_packages() -> Tuple[Any, Any]: + try: + import librosa + import soundfile + except ImportError: + raise ImportError( + "Please install vllm[audio] for audio support.") from None + return librosa, soundfile + + def fetch_audio(audio_url: str) -> Tuple[np.ndarray, Union[int, float]]: """ Load audio from a URL. """ + librosa, _ = try_import_audio_packages() + if audio_url.startswith("http"): audio_bytes = global_http_connection.get_bytes( audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT) @@ -95,6 +105,8 @@ async def async_fetch_audio( """ Asynchronously fetch audio from a URL. """ + librosa, _ = try_import_audio_packages() + if audio_url.startswith("http"): audio_bytes = await global_http_connection.async_get_bytes( audio_url, timeout=VLLM_AUDIO_FETCH_TIMEOUT) @@ -108,6 +120,16 @@ async def async_fetch_audio( return librosa.load(BytesIO(audio_bytes), sr=None) +def get_and_parse_audio(audio_url: str) -> MultiModalDataDict: + audio, sr = fetch_audio(audio_url) + return {"audio": (audio, sr)} + + +def get_and_parse_image(image_url: str) -> MultiModalDataDict: + image = fetch_image(image_url) + return {"image": image} + + async def async_get_and_parse_audio(audio_url: str) -> MultiModalDataDict: audio, sr = await async_fetch_audio(audio_url) return {"audio": (audio, sr)} @@ -123,6 +145,8 @@ def encode_audio_base64( sampling_rate: int, ) -> str: """Encode audio as base64.""" + _, soundfile = try_import_audio_packages() + buffered = BytesIO() soundfile.write(buffered, audio, sampling_rate, format="WAV") diff --git a/vllm/scripts.py b/vllm/scripts.py index a9ddfcf864133..e557961a335bf 100644 --- a/vllm/scripts.py +++ b/vllm/scripts.py @@ -125,6 +125,15 @@ def main(): serve_parser.add_argument("model_tag", type=str, help="The model tag to serve") + serve_parser.add_argument( + "--config", + type=str, + default='', + required=False, + help="Read CLI options from a config file." + "Must be a YAML with the following options:" + "https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html#command-line-arguments-for-the-server" + ) serve_parser = make_arg_parser(serve_parser) serve_parser.set_defaults(dispatch_function=serve) diff --git a/vllm/sequence.py b/vllm/sequence.py index 3125acc6fd535..a5ebf152ce776 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -1060,76 +1060,6 @@ def __repr__(self) -> str: return f"IntermediateTensors(tensors={self.tensors})" -class SamplerOutput( - msgspec.Struct, - omit_defaults=True, # type: ignore[call-arg] - array_like=True): # type: ignore[call-arg] - """For each sequence group, we generate a list of SequenceOutput object, - each of which contains one possible candidate for the next token. - - This data structure implements methods, so it can be used like a list, but - also has optional fields for device tensors. - """ - - outputs: List[CompletionSequenceGroupOutput] - - # On-device tensor containing probabilities of each token. - sampled_token_probs: Optional[torch.Tensor] = None - - # On-device tensor containing the logprobs of each token. - logprobs: Optional["torch.Tensor"] = None - - # On-device tensor containing the sampled token ids. - sampled_token_ids: Optional[torch.Tensor] = None - # CPU tensor containing the sampled token ids. Used during multi-step to - # return the sampled token ids from last rank to AsyncLLMEngine to be - # 'broadcasted' to all other PP ranks for next step. - sampled_token_ids_cpu: Optional[torch.Tensor] = None - - # Spec decode metrics populated by workers. - spec_decode_worker_metrics: Optional[SpecDecodeWorkerMetrics] = None - - # Optional last hidden states from the model. - hidden_states: Optional[torch.Tensor] = None - - # Optional prefill hidden states from the model - # (used for models like EAGLE). - prefill_hidden_states: Optional[torch.Tensor] = None - - # Time taken in the forward pass for this across all workers - model_forward_time: Optional[float] = None - - # Time taken in the model execute function. This will include model forward, - # block/sync across workers, cpu-gpu sync time and sampling time. - model_execute_time: Optional[float] = None - - def __getitem__(self, idx: int): - return self.outputs[idx] - - def __setitem__(self, idx: int, value): - self.outputs[idx] = value - - def __len__(self): - return len(self.outputs) - - def __eq__(self, other: object): - return isinstance(other, - self.__class__) and self.outputs == other.outputs - - def __repr__(self) -> str: - """Show the shape of a tensor instead of its values to reduce noise. - """ - sampled_token_probs_repr = ("None" if self.sampled_token_probs is None - else self.sampled_token_probs.shape) - sampled_token_ids_repr = ("None" if self.sampled_token_ids is None else - self.sampled_token_ids.shape) - return ( - f"SamplerOutput(outputs={self.outputs}, " - f"sampled_token_probs={sampled_token_probs_repr}, " - f"sampled_token_ids={sampled_token_ids_repr}, " - f"spec_decode_worker_metrics={self.spec_decode_worker_metrics})") - - class PoolerOutput( msgspec.Struct, omit_defaults=True, # type: ignore[call-arg] diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 8a691d65aaa06..b2204e8b27afd 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -5,8 +5,9 @@ import torch from vllm import SamplingParams +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, ExecuteModelRequest, - SamplerOutput, SequenceData, SequenceGroupMetadata, + SequenceData, SequenceGroupMetadata, get_all_seq_ids) from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index aedf0a83da07d..6e35e40294381 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -3,6 +3,7 @@ import torch from vllm import _custom_ops as ops +from vllm.model_executor.layers.sampler import SamplerOutput try: from vllm.attention.backends.flash_attn import FlashAttentionMetadata @@ -16,8 +17,7 @@ PromptAdapterConfig, SchedulerConfig) from vllm.logger import init_logger from vllm.multimodal import MultiModalInputs -from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, - SamplerOutput) +from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.worker.model_runner import (ModelInputForGPUWithSamplingMetadata, ModelRunner) diff --git a/vllm/spec_decode/medusa_worker.py b/vllm/spec_decode/medusa_worker.py index d1809e49c2a8f..0d233f393cb8c 100644 --- a/vllm/spec_decode/medusa_worker.py +++ b/vllm/spec_decode/medusa_worker.py @@ -4,8 +4,8 @@ import torch from vllm.model_executor import SamplingMetadata -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, - SequenceGroupMetadata) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer diff --git a/vllm/spec_decode/mlp_speculator_worker.py b/vllm/spec_decode/mlp_speculator_worker.py index 76e444387816f..fc41bb82ea340 100644 --- a/vllm/spec_decode/mlp_speculator_worker.py +++ b/vllm/spec_decode/mlp_speculator_worker.py @@ -3,8 +3,8 @@ import torch from vllm.model_executor import SamplingMetadata -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, - SequenceGroupMetadata) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 2dfbacfb7b759..4b53fbe056c47 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -4,8 +4,9 @@ import torch -from vllm.sequence import (ExecuteModelRequest, HiddenStates, SamplerOutput, - SequenceData, SequenceGroupMetadata) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import (ExecuteModelRequest, HiddenStates, SequenceData, + SequenceGroupMetadata) from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) diff --git a/vllm/spec_decode/ngram_worker.py b/vllm/spec_decode/ngram_worker.py index 806480b5c892f..36e5e1774aa0d 100644 --- a/vllm/spec_decode/ngram_worker.py +++ b/vllm/spec_decode/ngram_worker.py @@ -3,7 +3,8 @@ import torch -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.proposer_worker_base import NonLLMProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer diff --git a/vllm/spec_decode/proposer_worker_base.py b/vllm/spec_decode/proposer_worker_base.py index efb8ee25ba2f9..28a537593f26d 100644 --- a/vllm/spec_decode/proposer_worker_base.py +++ b/vllm/spec_decode/proposer_worker_base.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod from typing import List, Optional, Set, Tuple -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposer from vllm.worker.worker_base import LoraNotSupportedWorkerBase diff --git a/vllm/spec_decode/smaller_tp_proposer_worker.py b/vllm/spec_decode/smaller_tp_proposer_worker.py index 215ede52fb812..8896b7dbc6b8a 100644 --- a/vllm/spec_decode/smaller_tp_proposer_worker.py +++ b/vllm/spec_decode/smaller_tp_proposer_worker.py @@ -6,7 +6,8 @@ init_model_parallel_group, patch_tensor_parallel_group) from vllm.logger import init_logger -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.spec_decode.interfaces import SpeculativeProposals from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index 9b1f21fcb4920..91f0a98c7bc38 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -8,12 +8,13 @@ from vllm.distributed.communication_op import broadcast_tensor_dict from vllm.logger import init_logger from vllm.model_executor.layers.rejection_sampler import RejectionSampler +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.layers.spec_decode_base_sampler import ( SpecDecodeBaseSampler, SpecDecodeStochasticBaseSampler) from vllm.model_executor.layers.typical_acceptance_sampler import ( TypicalAcceptanceSampler) from vllm.sequence import (CompletionSequenceGroupOutput, ExecuteModelRequest, - HiddenStates, SamplerOutput, SequenceGroupMetadata, + HiddenStates, SequenceGroupMetadata, get_all_seq_ids, get_all_seq_ids_and_request_ids) from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner @@ -624,8 +625,8 @@ def _verify_tokens( seq_group_metadata_list, proposal_lens_list) original_indices = spec_indices + non_spec_indices - # Get probabilities of target model, excluding bonus token. - proposal_verifier_probs = proposal_scores.probs[spec_indices, :-1] + # Get probabilities of target model, including bonus tokens. + proposal_verifier_probs = proposal_scores.probs[spec_indices] # Get non-speculative sampled tokens from target model. non_spec_token_ids = proposal_scores.token_ids[non_spec_indices] @@ -650,13 +651,12 @@ def _verify_tokens( } accepted_token_ids = self.spec_decode_sampler( - target_probs=proposal_verifier_probs, + target_with_bonus_probs=proposal_verifier_probs, bonus_token_ids=bonus_token_ids, draft_probs=proposal_probs, draft_token_ids=proposal_token_ids, **sampler_extra_kwargs, ) - # Append output tokens from non-speculative sequences to # the accepted token ids tensor. non_spec_token_ids = non_spec_token_ids.expand(-1, max_proposal_len + diff --git a/vllm/spec_decode/top1_proposer.py b/vllm/spec_decode/top1_proposer.py index aa993e539b6d3..f6a52a516075d 100644 --- a/vllm/spec_decode/top1_proposer.py +++ b/vllm/spec_decode/top1_proposer.py @@ -2,8 +2,8 @@ import torch -from vllm.sequence import (ExecuteModelRequest, SamplerOutput, - SequenceGroupMetadata) +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeProposer) from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase diff --git a/vllm/spec_decode/util.py b/vllm/spec_decode/util.py index d18ee47e23a5c..54e718bc49017 100644 --- a/vllm/spec_decode/util.py +++ b/vllm/spec_decode/util.py @@ -4,9 +4,9 @@ import torch +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import (CompletionSequenceGroupOutput, Logprob, - SamplerOutput, SequenceGroupMetadata, - SequenceOutput) + SequenceGroupMetadata, SequenceOutput) SeqId = int @@ -43,8 +43,8 @@ def get_sampled_token_logprobs( sampled_token_ids, ] expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand( -1, -1, vocab_size) - sampled_token_ids_ranks = (logprob_tensor >= - expanded_selected_logprobs).sum(-1) + sampled_token_ids_ranks = (logprob_tensor > + expanded_selected_logprobs).sum(-1).add_(1) return sampled_token_ids_ranks, selected_logprobs diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index c2276b075c1dd..13fcf6b918603 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -1,27 +1,38 @@ import contextlib +import enum +import json from pathlib import Path from typing import Any, Dict, Optional, Type, Union +from huggingface_hub import file_exists, hf_hub_download from transformers import GenerationConfig, PretrainedConfig from transformers.models.auto.image_processing_auto import ( get_image_processor_config) from transformers.models.auto.modeling_auto import ( MODEL_FOR_CAUSAL_LM_MAPPING_NAMES) +from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger +# yapf conflicts with isort for this block +# yapf: disable from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig, - EAGLEConfig, InternVLChatConfig, + EAGLEConfig, ExaoneConfig, + GraniteConfig, InternVLChatConfig, JAISConfig, MedusaConfig, MLPSpeculatorConfig, MPTConfig, NemotronConfig, RWConfig, UltravoxConfig) +# yapf: enable +from vllm.transformers_utils.utils import check_gguf_file if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig else: from transformers import AutoConfig +MISTRAL_CONFIG_NAME = "params.json" + logger = init_logger(__name__) _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = { @@ -34,9 +45,13 @@ "mlp_speculator": MLPSpeculatorConfig, "medusa": MedusaConfig, "eagle": EAGLEConfig, + "exaone": ExaoneConfig, "internvl_chat": InternVLChatConfig, "nemotron": NemotronConfig, "ultravox": UltravoxConfig, + # Granite can be removed from here once we have upgraded to + # transformers 4.45+ + "granite": GraniteConfig, } for name, cls in _CONFIG_REGISTRY.items(): @@ -44,6 +59,20 @@ AutoConfig.register(name, cls) +class ConfigFormat(str, enum.Enum): + AUTO = "auto" + HF = "hf" + MISTRAL = "mistral" + + +def file_or_path_exists(model: Union[str, Path], config_name, revision, + token) -> bool: + if Path(model).exists(): + return (Path(model) / config_name).is_file() + + return file_exists(model, HF_CONFIG_NAME, revision=revision, token=token) + + def get_config( model: Union[str, Path], trust_remote_code: bool, @@ -51,38 +80,68 @@ def get_config( code_revision: Optional[str] = None, rope_scaling: Optional[dict] = None, rope_theta: Optional[float] = None, + config_format: ConfigFormat = ConfigFormat.AUTO, **kwargs, ) -> PretrainedConfig: - # Separate model folder from file path for GGUF models - is_gguf = Path(model).is_file() and Path(model).suffix == ".gguf" + + is_gguf = check_gguf_file(model) if is_gguf: kwargs["gguf_file"] = Path(model).name model = Path(model).parent - try: - config = AutoConfig.from_pretrained( - model, - trust_remote_code=trust_remote_code, - revision=revision, - code_revision=code_revision, - **kwargs) - except ValueError as e: - if (not trust_remote_code and - "requires you to execute the configuration file" in str(e)): - err_msg = ( - "Failed to load the model config. If the model is a custom " - "model not yet available in the HuggingFace transformers " - "library, consider setting `trust_remote_code=True` in LLM " - "or using the `--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e + if config_format == ConfigFormat.AUTO: + if is_gguf or file_or_path_exists(model, + HF_CONFIG_NAME, + revision=revision, + token=kwargs.get("token")): + config_format = ConfigFormat.HF + elif file_or_path_exists(model, + MISTRAL_CONFIG_NAME, + revision=revision, + token=kwargs.get("token")): + config_format = ConfigFormat.MISTRAL else: - raise e - if config.model_type in _CONFIG_REGISTRY: - config_class = _CONFIG_REGISTRY[config.model_type] - config = config_class.from_pretrained(model, - revision=revision, - code_revision=code_revision) + raise ValueError(f"No supported config format found in {model}") + + if config_format == ConfigFormat.HF: + config_dict, _ = PretrainedConfig.get_config_dict( + model, revision=revision, code_revision=code_revision, **kwargs) + + # Use custom model class if it's in our registry + model_type = config_dict.get("model_type") + if model_type in _CONFIG_REGISTRY: + config_class = _CONFIG_REGISTRY[model_type] + config = config_class.from_pretrained(model, + revision=revision, + code_revision=code_revision) + else: + try: + config = AutoConfig.from_pretrained( + model, + trust_remote_code=trust_remote_code, + revision=revision, + code_revision=code_revision, + **kwargs, + ) + except ValueError as e: + if (not trust_remote_code + and "requires you to execute the configuration file" + in str(e)): + err_msg = ( + "Failed to load the model config. If the model " + "is a custom model not yet available in the " + "HuggingFace transformers library, consider setting " + "`trust_remote_code=True` in LLM or using the " + "`--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + + elif config_format == ConfigFormat.MISTRAL: + config = load_params_config(model, revision) + else: + raise ValueError(f"Unsupported config format: {config_format}") # Special architecture mapping check for GGUF models if is_gguf: @@ -92,30 +151,87 @@ def get_config( model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type] config.update({"architectures": [model_type]}) - for key, value in [("rope_scaling", rope_scaling), - ("rope_theta", rope_theta)]: + for key, value in [ + ("rope_scaling", rope_scaling), + ("rope_theta", rope_theta), + ]: if value is not None: - logger.info("Updating %s from %r to %r", key, - getattr(config, key, None), value) + logger.info( + "Updating %s from %r to %r", + key, + getattr(config, key, None), + value, + ) config.update({key: value}) return config +def load_params_config(model, revision) -> PretrainedConfig: + # This function loads a params.json config which + # should be used when loading models in mistral format + + config_file_name = "params.json" + + config_path = Path(model) / config_file_name + + if not config_path.is_file(): + config_path = Path( + hf_hub_download(model, config_file_name, revision=revision)) + + with open(config_path, "r") as file: + config_dict = json.load(file) + + config_mapping = { + "dim": "hidden_size", + "norm_eps": "rms_norm_eps", + "n_kv_heads": "num_key_value_heads", + "n_layers": "num_hidden_layers", + "n_heads": "num_attention_heads", + "hidden_dim": "intermediate_size", + } + + def recurse_elems(elem: Any): + if isinstance(elem, dict): + config_dict = {} + for key, value in elem.items(): + key = config_mapping.get(key, key) + config_dict[key] = recurse_elems(value) + return PretrainedConfig(**config_dict) + else: + return elem + + config_dict["model_type"] = config_dict.get("model_type", "transformer") + config_dict["hidden_act"] = config_dict.get("activation", "silu") + config_dict["tie_word_embeddings"] = config_dict.get( + "tie_embeddings", False) + + if config_dict["model_type"] == "transformer": + if "moe" in config_dict: + config_dict["architectures"] = ["MixtralForCausalLM"] + else: + config_dict["architectures"] = ["MistralForCausalLM"] + + return recurse_elems(config_dict) + + def get_hf_image_processor_config( model: Union[str, Path], revision: Optional[str] = None, **kwargs, ) -> Dict[str, Any]: + # ModelScope does not provide an interface for image_processor + if VLLM_USE_MODELSCOPE: + return dict() # Separate model folder from file path for GGUF models - if Path(model).is_file() and Path(model).suffix == ".gguf": + if check_gguf_file(model): model = Path(model).parent return get_image_processor_config(model, revision=revision, **kwargs) def get_hf_text_config(config: PretrainedConfig): """Get the "sub" config relevant to llm for multi modal models. - No op for pure text models. + No op for pure text models. """ if hasattr(config, "text_config"): # The code operates under the assumption that text_config should have diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index dc2fd6a859e3c..8381c5227584e 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -1,10 +1,12 @@ from vllm.transformers_utils.configs.chatglm import ChatGLMConfig from vllm.transformers_utils.configs.dbrx import DbrxConfig from vllm.transformers_utils.configs.eagle import EAGLEConfig +from vllm.transformers_utils.configs.exaone import ExaoneConfig # RWConfig is for the original tiiuae/falcon-40b(-instruct) and # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the # `FalconConfig` class from the official HuggingFace transformers library. from vllm.transformers_utils.configs.falcon import RWConfig +from vllm.transformers_utils.configs.granite import GraniteConfig from vllm.transformers_utils.configs.internvl import InternVLChatConfig from vllm.transformers_utils.configs.jais import JAISConfig from vllm.transformers_utils.configs.medusa import MedusaConfig @@ -22,7 +24,11 @@ "JAISConfig", "MedusaConfig", "EAGLEConfig", + "ExaoneConfig", "MLPSpeculatorConfig", "NemotronConfig", "UltravoxConfig", + # Granite can be removed from here once we have upgraded to + # transformers 4.45+ + "GraniteConfig", ] diff --git a/vllm/transformers_utils/configs/exaone.py b/vllm/transformers_utils/configs/exaone.py new file mode 100644 index 0000000000000..805b8ad930039 --- /dev/null +++ b/vllm/transformers_utils/configs/exaone.py @@ -0,0 +1,190 @@ +# coding=utf-8 +# Copied from +# https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/blob/main/configuration_exaone.py +# Copyright 2021 The LG AI Research EXAONE Lab. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Exaone model configuration""" + +from typing import Dict + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +EXAONE_PRETRAINED_CONFIG_ARCHIVE_MAP: Dict[str, str] = {} + + +class ExaoneConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a :class: + `~transformers.ExaoneModel`. It is used to instantiate a GPT Lingvo model + according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the Exaone + + Configuration objects inherit from :class:`~transformers.PretrainedConfig` + and can be used to control the model outputs. Read the documentation from : + class:`~transformers.PretrainedConfig` for more information. + + Args: + vocab_size (:obj:`int`, `optional`, defaults to 50257): + Vocabulary size of the GPT Lingvo model. Defines the number of + different tokens that can be represented by the :obj:`inputs_ids` + passed when calling :class:`~transformers.ExaoneModel`. Vocabulary + size of the model. + Defines the different tokens that can be represented by the + `inputs_ids` passed to the forward method of :class: + `~transformers.EXAONEModel`. + hidden_size (:obj:`int`, `optional`, defaults to 2048): + Dimensionality of the encoder layers and the pooler layer. + num_layers (:obj:`int`, `optional`, defaults to 24): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the + Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to + implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi + Head Attention (MHA), if `num_key_value_heads=1 the model will use + Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, + each group key and value head should be constructed by meanpooling + all the original heads within that group. For more details checkout + [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not + specified, will default to `num_attention_heads`. + rotary_pct (`float`, *optional*, defaults to 0.25): + percentage of hidden dimensions to allocate to rotary embeddings + intermediate_size (:obj:`int`, `optional`, defaults to 8192): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in + the Transformer encoder. + activation_function (:obj:`str` or :obj:`function`, `optional`, + defaults to :obj:`"gelu_new"`): + The non-linear activation function (function or string) in the + encoder and pooler. If string, :obj:`"gelu"`, :obj:`"relu"`, + :obj:`"selu"` and :obj:`"gelu_new"` are supported. + embed_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout probabilitiy for all fully connected layers in the + embeddings, encoder, and pooler. + attention_dropout (:obj:`float`, `optional`, defaults to 0.0): + The dropout ratio for the attention probabilities. + max_position_embeddings (:obj:`int`, `optional`, defaults to 2048): + The maximum sequence length that this model might ever be used with. + Typically set this to something large just in case + (e.g., 512 or 1024 or 2048). + type_vocab_size (:obj:`int`, `optional`, defaults to 2): + The vocabulary size of the :obj:`token_type_ids` passed when calling + :class:`~transformers.EXAONEModel`. + initializer_range (:obj:`float`, `optional`, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-5): + The epsilon used by the layer normalization layers. + use_cache (:obj:`bool`, `optional`, defaults to :obj:`True`): + Whether or not the model should return the last key/values + attentions (not used by all models). + Only relevant if ``config.is_decoder=True``. + gradient_checkpointing (:obj:`bool`, `optional`, + defaults to :obj:`False`): + If True, use gradient checkpointing to save memory at the expense + of slower backward pass. + Example:: + + >>> from transformers import ExoneModel, ExaoneConfig + + >>> # Initializing a EXAONE configuration + >>> configuration = ExaoneConfig() + + >>> # Initializing a model from configuration + >>> model = ExoneModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + """ + + model_type = "exaone" + keys_to_ignore_at_inference = ["past_key_values"] + attribute_map = {"num_hidden_layers": "num_layers"} + + def __init__( + self, + vocab_size=102400, + max_position_embeddings=2048, + hidden_size=2048, + num_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + intermediate_size=None, + activation_function="silu", + rotary_pct=0.25, + resid_dropout=0.0, + embed_dropout=0.0, + attention_dropout=0.0, + layer_norm_epsilon=1e-6, + initializer_range=0.02, + use_cache=True, + bos_token_id=0, + eos_token_id=2, + tie_word_embeddings=True, + **kwargs, + ): + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.num_layers = num_layers + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_layers + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + self.num_key_value_heads = num_key_value_heads + if intermediate_size: + self.intermediate_size = intermediate_size + else: + self.intermediate_size = hidden_size * 4 + self.activation_function = activation_function + self.resid_dropout = resid_dropout + self.embed_dropout = embed_dropout + self.attention_dropout = attention_dropout + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.use_cache = use_cache + self.rotary_pct = rotary_pct + + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + + self.use_logit_cap = kwargs.pop("use_logit_cap", False) + self.ln_no_scale = kwargs.pop("ln_no_scale", False) + self.use_gated = kwargs.pop("use_gated", False) + self.use_emb_norm = kwargs.pop("use_emb_norm", False) + self.use_rotary_pos = kwargs.pop("use_rotary_pos", False) + self.rotary_type = kwargs.pop("rotary_type", None) + self.scaling_factor = kwargs.pop("scaling_factor", 1) + self.use_absolute_pos = kwargs.pop("use_absolute_pos", True) + self.use_extra_logit = kwargs.pop("use_extra_logit", True) + self.rotary_expand_length = kwargs.pop("rotary_expand_length", None) + self.rotary_base = kwargs.pop("rotary_base", 10000.0) + self.use_qkv_fuse = kwargs.pop("use_qkv_fuse", False) + self.rescale_before_lm_head = kwargs.pop("rescale_before_lm_head", + (rotary_pct == 0.25)) + if self.use_rotary_pos: + self.use_absolute_pos = False diff --git a/vllm/transformers_utils/configs/granite.py b/vllm/transformers_utils/configs/granite.py new file mode 100644 index 0000000000000..c12838be5d385 --- /dev/null +++ b/vllm/transformers_utils/configs/granite.py @@ -0,0 +1,199 @@ +# coding=utf-8 +# Copyright 2024 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Granite model configuration""" + +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import rope_config_validation +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class GraniteConfig(PretrainedConfig): + r""" + This is the configuration class to store the configuration of + a [`GraniteModel`]. It is used to instantiate an Granite + model according to the specified arguments, defining the model architecture. + Instantiating a configuration with the defaults will yield a similar + configuration to that of the Granite-3B. + + Configuration objects inherit from [`PretrainedConfig`] and can be used to + control the model outputs. Read the documentation from [`PretrainedConfig`] + for more information. + + + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Granite model. Defines the number of + different tokens that can be represented by the `inputs_ids` + passed when calling [`GraniteModel`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the + Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to + implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi + Head Attention (MHA), if `num_key_value_heads=1` the model will use + Multi Query Attention (MQA) otherwise GQA is used. When converting + a multi-head checkpoint to a GQA checkpoint, each group key and + value head should be constructed by meanpooling all the original + heads within that group. For more details checkout + [this paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not + specified, will default to `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the + decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for + initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values + attentions (not used by all models). Only relevant if + `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE + embeddings. Currently supports two scaling strategies: linear and + dynamic. Their scaling factor must be a float greater than 1. The + expected format is + `{"type": strategy name, "factor": scaling factor}`. + When using this flag, don't update `max_position_embeddings` to + the expected new maximum. See the following thread for more + information on how these scaling strategies behave: + https://www.reddit.com/r/LocalLLaMA/comments/14mrgpr/dynamically_scaled_rope_further_increases/. + This is an experimental feature, subject to breaking API changes + in future versions. + attention_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output + projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + mlp_bias (`bool`, *optional*, defaults to `False`): + Whether to use a bias in up_proj, down_proj and gate_proj layers + in the MLP layers. + embedding_multiplier (`float`, *optional*, defaults to 1.0): + embedding multiplier + logits_scaling (`float`, *optional*, defaults to 1.0): + divisor for output logits + residual_multiplier (`float`, *optional*, defaults to 1.0): + residual multiplier + attention_multiplier (`float`, *optional*, defaults to 1.0): + attention multiplier + + ```python + >>> from transformers import GraniteModel, GraniteConfig + + >>> # Initializing a Granite granite-3b style configuration + >>> configuration = GraniteConfig() + + >>> # Initializing a model from the granite-7b style configuration + >>> model = GraniteModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "granite" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=32000, + hidden_size=4096, + intermediate_size=11008, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + hidden_act="silu", + max_position_embeddings=2048, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=1, + eos_token_id=2, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + mlp_bias=False, + embedding_multiplier=1.0, + logits_scaling=1.0, + residual_multiplier=1.0, + attention_multiplier=1.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + self.mlp_bias = mlp_bias + + self.embedding_multiplier = embedding_multiplier + self.logits_scaling = logits_scaling + self.residual_multiplier = residual_multiplier + self.attention_multiplier = attention_multiplier + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + rope_config_validation(self) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 2866975850db3..f9fb8d1e103b7 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -12,6 +12,7 @@ from vllm.lora.request import LoRARequest from vllm.transformers_utils.tokenizers import (BaichuanTokenizer, MistralTokenizer) +from vllm.transformers_utils.utils import check_gguf_file from vllm.utils import make_async logger = init_logger(__name__) @@ -96,8 +97,7 @@ def get_tokenizer( kwargs["truncation_side"] = "left" # Separate model folder from file path for GGUF models - is_gguf = Path(tokenizer_name).is_file() and Path( - tokenizer_name).suffix == ".gguf" + is_gguf = check_gguf_file(tokenizer_name) if is_gguf: kwargs["gguf_file"] = Path(tokenizer_name).name tokenizer_name = Path(tokenizer_name).parent diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py index 23ecfc0af6be4..533a86b787325 100644 --- a/vllm/transformers_utils/tokenizers/mistral.py +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -52,12 +52,13 @@ def __init__(self, tokenizer: PublicMistralTokenizer) -> None: assert isinstance(self.tokenizer, (Tekkenizer, SentencePieceTokenizer)), type( self.tokenizer) - self._is_tekken = isinstance(self.tokenizer, Tekkenizer) - if self._is_tekken: + if (is_tekken := isinstance(self.tokenizer, Tekkenizer)): # Make sure special tokens will not raise self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE + self._is_tekken = is_tekken + # the following attributes are set to fit VLLM's design self.is_fast = True self.chat_template = True diff --git a/vllm/transformers_utils/utils.py b/vllm/transformers_utils/utils.py new file mode 100644 index 0000000000000..7a9041b04fbb9 --- /dev/null +++ b/vllm/transformers_utils/utils.py @@ -0,0 +1,16 @@ +from os import PathLike +from pathlib import Path +from typing import Union + + +def check_gguf_file(model: Union[str, PathLike]) -> bool: + """Check if the file is a GGUF model.""" + model = Path(model) + if not model.is_file(): + return False + elif model.suffix == ".gguf": + return True + + with open(model, "rb") as f: + header = f.read(4) + return header == b"GGUF" diff --git a/vllm/utils.py b/vllm/utils.py index dab8e5fe04359..a22081ebe8df0 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -25,6 +25,7 @@ import psutil import torch import torch.types +import yaml from packaging.version import Version from typing_extensions import ParamSpec, TypeIs, assert_never @@ -1093,6 +1094,9 @@ def parse_args(self, args=None, namespace=None): if args is None: args = sys.argv[1:] + if '--config' in args: + args = FlexibleArgumentParser._pull_args_from_config(args) + # Convert underscores to dashes and vice versa in argument names processed_args = [] for arg in args: @@ -1109,6 +1113,103 @@ def parse_args(self, args=None, namespace=None): return super().parse_args(processed_args, namespace) + @staticmethod + def _pull_args_from_config(args: List[str]) -> List[str]: + """Method to pull arguments specified in the config file + into the command-line args variable. + + The arguments in config file will be inserted between + the argument list. + + example: + ```yaml + port: 12323 + tensor-parallel-size: 4 + ``` + ```python + $: vllm {serve,chat,complete} "facebook/opt-12B" \ + --config config.yaml -tp 2 + $: args = [ + "serve,chat,complete", + "facebook/opt-12B", + '--config', 'config.yaml', + '-tp', '2' + ] + $: args = [ + "serve,chat,complete", + "facebook/opt-12B", + '--port', '12323', + '--tensor-parallel-size', '4', + '-tp', '2' + ] + ``` + + Please note how the config args are inserted after the sub command. + this way the order of priorities is maintained when these are args + parsed by super(). + """ + assert args.count( + '--config') <= 1, "More than one config file specified!" + + index = args.index('--config') + if index == len(args) - 1: + raise ValueError("No config file specified! \ + Please check your command-line arguments.") + + file_path = args[index + 1] + + config_args = FlexibleArgumentParser._load_config_file(file_path) + + # 0th index is for {serve,chat,complete} + # followed by config args + # followed by rest of cli args. + # maintaining this order will enforce the precedence + # of cli > config > defaults + args = [args[0]] + config_args + args[1:index] + args[index + 2:] + + return args + + @staticmethod + def _load_config_file(file_path: str) -> List[str]: + """Loads a yaml file and returns the key value pairs as a + flattened list with argparse like pattern + ```yaml + port: 12323 + tensor-parallel-size: 4 + ``` + returns: + processed_args: list[str] = [ + '--port': '12323', + '--tensor-parallel-size': '4' + ] + + """ + + extension: str = file_path.split('.')[-1] + if extension not in ('yaml', 'yml'): + raise ValueError( + "Config file must be of a yaml/yml type.\ + %s supplied", extension) + + # only expecting a flat dictionary of atomic types + processed_args: List[str] = [] + + config: Dict[str, Union[int, str]] = {} + try: + with open(file_path, 'r') as config_file: + config = yaml.safe_load(config_file) + except Exception as ex: + logger.error( + "Unable to read the config file at %s. \ + Make sure path is correct", file_path) + raise ex + + for key, value in config.items(): + processed_args.append('--' + key) + processed_args.append(str(value)) + + return processed_args + async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, **kwargs): @@ -1123,3 +1224,28 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, def supports_dynamo() -> bool: base_torch_version = Version(Version(torch.__version__).base_version) return base_torch_version >= Version("2.4.0") + + +class AtomicCounter: + """An atomic, thread-safe counter""" + + def __init__(self, initial=0): + """Initialize a new atomic counter to given initial value""" + self._value = initial + self._lock = threading.Lock() + + def inc(self, num=1): + """Atomically increment the counter by num and return the new value""" + with self._lock: + self._value += num + return self._value + + def dec(self, num=1): + """Atomically decrement the counter by num and return the new value""" + with self._lock: + self._value -= num + return self._value + + @property + def value(self): + return self._value diff --git a/vllm/version.py b/vllm/version.py index 052eb76b5873c..039f6369b8ed5 100644 --- a/vllm/version.py +++ b/vllm/version.py @@ -9,4 +9,4 @@ stacklevel=2) __commit__ = "COMMIT_HASH_PLACEHOLDER" -__version__ = "0.5.5" +__version__ = "0.6.0" diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index f69afa4c43149..7205b1a7beb8d 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -10,11 +10,11 @@ SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 5c700229660c0..d6189d82d51d9 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -16,9 +16,10 @@ from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, PoolerOutput, SamplerOutput, +from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceGroupMetadata) from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.worker.model_runner import (GPUModelRunnerBase, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index f556e4ea117ae..74f7d4e0860d3 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -21,6 +21,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) +from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture from vllm.inputs import INPUT_REGISTRY, InputRegistry @@ -29,6 +30,7 @@ from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.model_executor import SamplingMetadata, SamplingMetadataCache +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.model_executor.models.interfaces import (supports_lora, @@ -41,8 +43,7 @@ from vllm.prompt_adapter.worker_manager import ( LRUCacheWorkerPromptAdapterManager) from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, flatten_2d_lists, is_hip, is_pin_memory_available, supports_dynamo) @@ -60,10 +61,14 @@ LORA_WARMUP_RANK = 8 _BATCH_SIZE_ALIGNMENT = 8 -# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. +# all the token sizes that **can** be captured by cudagraph. +# they can be arbitrarily large. +# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192. +# the actual sizes to capture will be determined by the model, +# depending on the model's max_num_seqs. # NOTE: _get_graph_batch_size needs to be updated if this list is changed. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ - _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33) + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025) ] _NUM_WARMUP_ITERS = 2 @@ -92,6 +97,8 @@ class ModelInputForGPU(ModelRunnerInputBase): finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 async_callback: Optional[Callable] = None + seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None + scheduler_outputs: Optional[SchedulerOutputs] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -501,23 +508,48 @@ def _compute_for_prefix_cache_hit( and self.sliding_window is None and inter_data.is_prompt) inter_data.prefix_cache_hit = prefix_cache_hit - if self.chunked_prefill_enabled and prefix_cache_hit: - raise RuntimeError( - "chunked prefill cannot be used with prefix caching now.") - - # If prefix cache is hit, advance context length to bypass - # hit blocks. Accordingly, input tokens, position and query length - # have to be updated. - if prefix_cache_hit: - assert computed_block_nums is not None - context_len = len(computed_block_nums) * self.block_size + + if not prefix_cache_hit: + return + + assert computed_block_nums is not None + # The cache hit prompt tokens in this sequence. Note that + # this may be larger than the sequence length if chunked + # prefill is enabled. + prefix_cache_len = len(computed_block_nums) * self.block_size + # The number of so far computed prompt tokens in this sequence. + context_len = inter_data.context_lens[seq_idx] + # The total number of prompt tokens in this sequence. + # When chunked prefill is enabled, this is the token number of + # computed chunks + current chunk. + seq_len = inter_data.seq_lens[seq_idx] + if prefix_cache_len <= context_len: + # We already passed the cache hit region, + # so do normal computation. + pass + elif context_len < prefix_cache_len < seq_len: + # Partial hit. Compute the missing part. + uncomputed_start = prefix_cache_len - context_len inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ - seq_idx][context_len:] + seq_idx][uncomputed_start:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ - seq_idx][context_len:] + seq_idx][uncomputed_start:] + context_len = prefix_cache_len + inter_data.context_lens[seq_idx] = context_len inter_data.query_lens[ seq_idx] = inter_data.seq_lens[seq_idx] - context_len + elif seq_len <= prefix_cache_len: + # Full hit. Only compute the last token to avoid + # erroneous behavior. FIXME: Ideally we should directly + # mark all tokens as computed in the scheduler and do not + # schedule this sequence, so this case should not happen. + inter_data.input_tokens[seq_idx] = inter_data.input_tokens[ + seq_idx][-1:] + inter_data.input_positions[seq_idx] = inter_data.input_positions[ + seq_idx][-1:] + inter_data.query_lens[seq_idx] = 1 + inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 def _compute_for_sliding_window(self, inter_data: InterDataForSeqGroup, seq_idx: int, @@ -634,7 +666,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): def _use_captured_graph(self, batch_size: int, max_decode_seq_len: int) -> bool: return (self.decode_only and not self.runner.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] + and batch_size <= self.runner.max_batchsize_to_capture and max_decode_seq_len <= self.runner.max_seq_len_to_capture) def build(self) -> ModelInputForGPU: @@ -820,6 +852,8 @@ def __init__( self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture + self.max_batchsize_to_capture = _get_max_graph_batch_size( + self.scheduler_config.max_num_seqs) self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ {} for _ in range(self.parallel_config.pipeline_parallel_size) @@ -837,7 +871,7 @@ def __init__( # The shape of the cached block table will be # (max batch size to capture, max context len to capture / block size). self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()), + (self.max_batchsize_to_capture, self.get_max_block_per_batch()), dtype=np.int32) num_attn_heads = self.model_config.get_num_attention_heads( self.parallel_config) @@ -1098,10 +1132,6 @@ def profile_run(self) -> None: device=self.device) self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() - - # reset and discard the guard and compiled bytecode for profiling runs - torch._dynamo.reset() - return def remove_all_loras(self): @@ -1196,7 +1226,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: start_time = time.perf_counter() # Prepare dummy inputs. These will be reused for all batch sizes. - max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) + max_batch_size = self.max_batchsize_to_capture input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda() input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda() @@ -1224,8 +1254,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: None ] * self.parallel_config.pipeline_parallel_size - graph_batch_size = _get_graph_batch_size( - self.scheduler_config.max_num_seqs) + graph_batch_size = self.max_batchsize_to_capture batch_size_capture_list = [ bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size ] @@ -1651,3 +1680,22 @@ def _get_graph_batch_size(batch_size: int) -> int: else: return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) + + +def _get_max_graph_batch_size(max_num_seqs: int) -> int: + """ + max_num_seqs: Maximum number of sequences in a batch. + _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture. + + pad the max_num_seqs if necessary by calling _get_graph_batch_size, + which will deal with some edge cases like 1, 2, 4. + + if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size. + if not, it means the padded size is larger than the largest size in + _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE. + """ + padded_size = _get_graph_batch_size(max_num_seqs) + if padded_size in _BATCH_SIZES_TO_CAPTURE: + return padded_size + assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] + return _BATCH_SIZES_TO_CAPTURE[-1] diff --git a/vllm/worker/model_runner_base.py b/vllm/worker/model_runner_base.py index 90c39407d7266..f8fd9d801d289 100644 --- a/vllm/worker/model_runner_base.py +++ b/vllm/worker/model_runner_base.py @@ -5,9 +5,9 @@ import torch +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata if TYPE_CHECKING: from vllm.attention import AttentionMetadata diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index 521205eca05af..b13cf39bd846e 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -1,5 +1,8 @@ +import dataclasses +import functools from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Union) try: from vllm.attention.backends.flash_attn import FlashAttentionMetadata @@ -13,9 +16,13 @@ from vllm import _custom_ops as ops from vllm.distributed import get_pp_group from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import (PromptLogprobs, SampleLogprobs, + SamplerOutput, + SamplingMetadata, get_logprobs, + get_pythonized_sample_results) from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SamplerOutput, SequenceGroupMetadata, - SequenceOutput) + Logprob, SequenceGroupMetadata, SequenceOutput) +from vllm.utils import PyObjectCache from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( @@ -31,6 +38,29 @@ logger = init_logger(__name__) +def seq_output_builder(): + return SequenceOutput( + 0, 0, + {0: Logprob(logprob=float('inf'), rank=None, decoded_token=None)}) + + +def completion_seq_group_output_builder(): + return CompletionSequenceGroupOutput([], None) + + +# Used by pythonization to reduce python object allocations +class PythonizationCache: + + def __init__(self): + self.cached_seq_output = PyObjectCache(seq_output_builder) + self.cached_completion_seq_group_output = PyObjectCache( + completion_seq_group_output_builder) + + def reset(self): + self.cached_seq_output.reset() + self.cached_completion_seq_group_output.reset() + + @dataclass class ModelOutput: """The output of a single model forward pass. @@ -51,6 +81,9 @@ class ModelOutput: sampler_output_ready_event: torch.cuda.Event sampled_token_ids: Optional[torch.Tensor] = None pythonized: bool = False + # On-device tensor containing the logprobs of each token. + logprobs: Optional["torch.Tensor"] = None + pythonization_cache: Optional[PythonizationCache] = None def pythonize(self, input_metadata: "StatefulModelInput", copy_stream: torch.cuda.Stream, @@ -76,7 +109,9 @@ def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", blocking: bool) -> bool: """ If blocking is set, will block until the forward pass for the output is - ready and pythonize the output. + ready and pythonize the output. Upon completing Pythonization, erases + self.logprobs (note that a non-blocking call that is performed when + the sampler output is not yet ready, will not erase self.logprobs.) """ assert self.sampled_token_ids is not None if not blocking and not self.sampler_output_ready_event.query(): @@ -87,7 +122,16 @@ def _pythonize_sampler_output(self, input_metadata: "StatefulModelInput", with torch.cuda.stream(copy_stream): _pythonize_sampler_output(input_metadata, self.sampler_output, pinned_sampled_token_buffer, - self.sampled_token_ids) + self.sampled_token_ids, self.logprobs, + self.pythonization_cache) + + # Erase the logprobs GPU-side tensor. + # Note that although _pythonize_sampler_output() runs in its + # own CUDA stream, nonetheless _pythonize_sampler_output() + # cannot return until Pythonization is complete; therefore + # we know that by the time the CPU reaches this point, + # `self.logprobs` is no longer needed. + self.logprobs = None return True @@ -191,6 +235,8 @@ def __init__(self, base_model_runner: GPUModelRunnerBase, *args, **kwargs): self._copy_stream = torch.cuda.Stream() self.pinned_sampled_token_ids: Optional[torch.Tensor] = None + self.pythonization_cache = PythonizationCache() + def make_model_input_from_broadcasted_tensor_dict( self, tensor_dict: Dict[str, Any]) -> StatefulModelInput: model_input = (StatefulModelInput.from_broadcasted_tensor_dict( @@ -215,6 +261,81 @@ def prepare_model_input( ) return model_input + def _async_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: Callable): + # Proceed with pythonization and output_proc in order. + # Stop on the first one that fails to pythonize + output_proc_callback() + + cont = True + for model_output in model_input.cached_outputs: + if not model_output.pythonized: + model_output.maybe_pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + if model_output.pythonized: + ctx = output_proc_callback.keywords["ctx"] + ctx.append_output( + outputs=[model_output.sampler_output], + seq_group_metadata_list=ctx.seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False) + + output_proc_callback() + else: + cont = False + + if not cont: + break + + def _final_process_outputs(self, model_input: StatefulModelInput, + output_proc_callback: Optional[Callable]): + assert model_input.frozen_model_input is not None + + has_async_callback = output_proc_callback is not None + + outputs = [] + for output_id in range(len(model_input.cached_outputs)): + output = model_input.cached_outputs[output_id] + is_last_step = output_id == len(model_input.cached_outputs) - 1 + + # For non-async case: + # -- We simply add the outputs + # For async case: + # -- Invoke callback, pythonize, add to callback queue and repeat + # -- For last output, just add to callback queue + if has_async_callback: + assert output_proc_callback is not None + + # Invoke callback before pythonize (to overlap with GPU) + output_proc_callback() + + # Pythonize + if not output.pythonized: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + + # For non last step, add to callback queue to chain + # callbacks=>pythonize pairs (for GPU overlap) + if not is_last_step: + ctx = output_proc_callback.keywords[ # type: ignore + "ctx"] # type: ignore + ctx.append_output( + outputs=[output.sampler_output], + seq_group_metadata_list=ctx. + seq_group_metadata_list, + scheduler_outputs=ctx.scheduler_outputs, + is_async=False, + is_last_step=False) + else: + outputs.append(output.sampler_output) + else: + output.pythonize(model_input, self._copy_stream, + self.pinned_sampled_token_ids) + outputs.append(output.sampler_output) + + return outputs + @torch.inference_mode() def execute_model( self, @@ -271,6 +392,20 @@ def execute_model( model_input = self._advance_step( model_input, model_input.cached_outputs[-1].sampler_output) + output_proc_callback = None + if frozen_model_input.async_callback is not None: + output_proc_callback = frozen_model_input.async_callback + assert output_proc_callback is not None + async_callback = functools.partial( + self._async_process_outputs, + model_input=model_input, + output_proc_callback=output_proc_callback) + + frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=async_callback) + assert frozen_model_input is not None + # Execute the model output = self._base_model_runner.execute_model(frozen_model_input, kv_caches, @@ -294,16 +429,23 @@ def execute_model( 0].sampled_token_ids.cpu() model_input.cached_outputs.append( ModelOutput(output[0], output_ready_event, - output[0].sampled_token_ids, False)) - # make sure we dont try to serialize any GPU tensors + output[0].sampled_token_ids, False, + output[0].logprobs, self.pythonization_cache)) + + # These GPU tensors are not required by multi-step; + # erase them to ensure they are not pythonized or + # transferred to CPU output[0].sampled_token_ids = None output[0].sampled_token_probs = None output[0].logprobs = None + # Pythonize the output if CPU is ahead and the previous step is # ready. - for model_output in model_input.cached_outputs: - model_output.maybe_pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) + if frozen_model_input.async_callback is None: + for model_output in model_input.cached_outputs: + model_output.maybe_pythonize(model_input, + self._copy_stream, + self.pinned_sampled_token_ids) model_input.current_step += 1 @@ -316,11 +458,9 @@ def execute_model( # Pythonize the output and block if needed since it is the last step if model_input.is_last_step: - outputs = [] - for output in model_input.cached_outputs: - output.pythonize(model_input, self._copy_stream, - self.pinned_sampled_token_ids) - outputs.append(output.sampler_output) + outputs = self._final_process_outputs(model_input, + output_proc_callback) + self.pythonization_cache.reset() return outputs # should be [SamplerOutput] @@ -409,12 +549,76 @@ def vocab_size(self) -> int: return self._base_model_runner.vocab_size -def _pythonize_sampler_output(model_input: StatefulModelInput, - output: SamplerOutput, - pinned_sampled_token_buffer: torch.Tensor, - sampled_token_ids: torch.Tensor) -> None: +DeferredLogprobsReturnType = Tuple[Optional[List[Optional[PromptLogprobs]]], + Optional[List[SampleLogprobs]]] + + +def deferred_pythonize_logprobs( + output: SamplerOutput, + sampling_metadata: SamplingMetadata, + logprobs_tensor: Optional[torch.Tensor], +) -> DeferredLogprobsReturnType: + """Perform deferred logprob Pythonization. + + 1. Pythonize GPU-side sampler result tensors into CPU-side sampler result. + 2. Pythonize GPU-side logprobs tensor into CPU-side logprobs lists, + utilizing the Pythonized sampler result computed in step 1. + + These deferred computations are not required for single-step scheduling + or the `profile_run()` phase of multi-step scheduling. + + Args: + output: sampler output (under deferred Pythonization) + sampling_metadata + + Returns: + prompt_logprobs (CPU), sample_logprobs (CPU) + """ + + # - Deferred pythonization of sample result + sampler_result = get_pythonized_sample_results( + output.deferred_sample_results_args) + + # - Erase the GPU-side deferred sample_result + # computation args to ensure it is never + # pythonized or transferred to CPU + output.deferred_sample_results_args = None + + # - Deferred pythonization of logprobs + ( + prompt_logprobs, + sample_logprobs, + ) = get_logprobs(logprobs_tensor, sampling_metadata, sampler_result) + assert len(prompt_logprobs) == len(sampling_metadata.seq_groups) + assert len(sample_logprobs) == len(sampling_metadata.seq_groups) + + return prompt_logprobs, sample_logprobs + + +def _pythonize_sampler_output( + model_input: StatefulModelInput, + output: SamplerOutput, + pinned_sampled_token_buffer: torch.Tensor, + sampled_token_ids: torch.Tensor, + logprobs_tensor: Optional[torch.Tensor], + cache: Optional[PythonizationCache], +) -> None: """ This function is only called when the output tensors are ready. - See ModelOutput + See :class:`ModelOutput`. + + Modifies `output.outputs` and `pinned_sampled_token_buffer` in-place, + adding a Pythonized output data structure + (:class:`CompletionSequenceGroupOutput`) for each :class:`SequenceGroup`. + + Args: + model_input + output: sampler output + pinned_sampled_token_token_buffer: CPU-side pinned memory + (receives copy of + GPU-side token buffer.) + sampled_token_ids: GPU-side token buffer + logprobs_tensor: GPU-side tensor containing + logprobs computed during sampling """ assert model_input.frozen_model_input is not None @@ -434,20 +638,107 @@ def _pythonize_sampler_output(model_input: StatefulModelInput, sampling_metadata = frozen_model_input.sampling_metadata - for (seq_group, sample_result) in zip(sampling_metadata.seq_groups, - samples_list): - seq_ids = seq_group.seq_ids - next_token_ids = sample_result - parent_ids = [0] - seq_outputs: List[SequenceOutput] = [] + skip_sampler_cpu_output = ( + frozen_model_input.sampling_metadata.skip_sampler_cpu_output) + + # We are guaranteed output tensors are ready, so it is safe to + # pythonize the sampler output & obtain CPU-side logprobs. + # + # However this computation may be skipped entirely + # if no pythonization was deferred. + seq_groups = sampling_metadata.seq_groups + logprobs_are_requested = any([ + sg.sampling_params.logprobs is not None + or sg.sampling_params.prompt_logprobs is not None for sg in seq_groups + ]) + do_pythonize_logprobs = (skip_sampler_cpu_output + and logprobs_are_requested) + ( + prompt_logprobs, + sample_logprobs, + ) = (deferred_pythonize_logprobs(output, sampling_metadata, + logprobs_tensor) + if do_pythonize_logprobs else (None, None)) + + for sgdx, (seq_group, + sample_result) in enumerate(zip(seq_groups, samples_list)): if seq_group.sampling_params.logits_processors: assert len(seq_group.sampling_params.logits_processors) == 0, ( "Logits Processors are not supported in multi-step decoding") - for parent_id, next_token_id in zip(parent_ids, next_token_ids): - # TODO(will): support logprobs - # Hard coded logprob - seq_outputs.append( - SequenceOutput(seq_ids[parent_id], next_token_id, - {next_token_id: Logprob(logprob=-1)})) - output.outputs.append(CompletionSequenceGroupOutput(seq_outputs, None)) + + if do_pythonize_logprobs: + assert prompt_logprobs is not None + assert sample_logprobs is not None + + ( + group_prompt_logprobs, + group_sample_logprobs, + ) = ( # Utilize deferred pythonization results + prompt_logprobs[sgdx], + sample_logprobs[sgdx], + ) + elif logprobs_are_requested: + ( + group_prompt_logprobs, + group_sample_logprobs, + ) = ( + # profile_run: use already-computed logprobs + output.outputs[sgdx].prompt_logprobs, + [sample.logprobs for sample in output.outputs[sgdx].samples]) + + seq_ids = seq_group.seq_ids + next_token_ids = sample_result + parent_ids = [0] + + if cache is not None: + completion_seq_group_output: CompletionSequenceGroupOutput = \ + cache.cached_completion_seq_group_output.get_object() + completion_seq_group_output.samples.clear() + seq_outputs: List[ + SequenceOutput] = completion_seq_group_output.samples + else: + seq_outputs = [] + + for tdx, (parent_id, + next_token_id) in enumerate(zip(parent_ids, next_token_ids)): + if cache is not None: + seq_output: SequenceOutput = cache.cached_seq_output.get_object( + ) + seq_output.parent_seq_id = seq_ids[parent_id] + seq_output.output_token = next_token_id + + if logprobs_are_requested: + seq_output.logprobs = group_sample_logprobs[tdx] + else: + logprobs = next(iter(seq_output.logprobs.values())) + seq_output.logprobs.clear() + + logprobs.logprob = float('inf') + logprobs.rank = None + logprobs.decoded_token = None + + seq_output.logprobs[next_token_id] = logprobs + + seq_outputs.append(seq_output) + + else: + seq_outputs.append( + SequenceOutput(seq_ids[parent_id], next_token_id, + (group_sample_logprobs[tdx] + if logprobs_are_requested else { + next_token_id: + Logprob(logprob=float('inf'), + rank=None, + decoded_token=None) + }))) + if cache is not None: + completion_seq_group_output.prompt_logprobs = \ + group_prompt_logprobs if logprobs_are_requested else None + output.outputs.append(completion_seq_group_output) + else: + output.outputs.append( + CompletionSequenceGroupOutput( + seq_outputs, (group_prompt_logprobs + if logprobs_are_requested else None))) + assert len(output.outputs) > 0 diff --git a/vllm/worker/multi_step_worker.py b/vllm/worker/multi_step_worker.py index 2ed77dd698f5c..562285f828cc7 100644 --- a/vllm/worker/multi_step_worker.py +++ b/vllm/worker/multi_step_worker.py @@ -1,10 +1,12 @@ +import dataclasses from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import torch from vllm.distributed import broadcast_tensor_dict, get_pp_group -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.worker.model_runner_base import BroadcastableModelInput from vllm.worker.multi_step_model_runner import (MultiStepModelRunner, StatefulModelInput) @@ -61,6 +63,11 @@ def _get_driver_input_and_broadcast( execute_model_req.seq_group_metadata_list, execute_model_req.virtual_engine, execute_model_req.finished_requests_ids)) + + if execute_model_req.async_callback: + model_input.frozen_model_input = dataclasses.replace( # type: ignore + model_input.frozen_model_input, + async_callback=execute_model_req.async_callback) else: # on subsequent steps we reuse the worker input and model input multi_step_state = self.multi_step_states[virtual_engine] diff --git a/vllm/worker/neuron_model_runner.py b/vllm/worker/neuron_model_runner.py index 4f3fed2dbd723..0cf7445d4388d 100644 --- a/vllm/worker/neuron_model_runner.py +++ b/vllm/worker/neuron_model_runner.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from importlib.util import find_spec from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch @@ -8,11 +9,11 @@ SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.neuron import get_neuron_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import is_pin_memory_available, make_tensor_with_pad from vllm.worker.model_runner_base import ModelRunnerBase, ModelRunnerInputBase @@ -76,9 +77,14 @@ def __init__( self.model: nn.Module # initialize after load_model. def load_model(self) -> None: - self.model = get_neuron_model(self.model_config, - parallel_config=self.parallel_config, - scheduler_config=self.scheduler_config) + if find_spec("transformers_neuronx") is not None: + self.model = get_neuron_model( + self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config) + else: + raise NotImplementedError( + "Supports only Transformer-NeuronX based models.") def _prepare_prompt( self, diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index a1d09a2f9e53e..f335e4e32efd4 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -11,10 +11,11 @@ SchedulerConfig) from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.openvino import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) -from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.sequence import SequenceGroupMetadata logger = init_logger(__name__) diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index c47f9acc4423d..36339e175d7bb 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -14,7 +14,8 @@ init_distributed_environment) from vllm.logger import init_logger from vllm.model_executor import set_random_seed -from vllm.sequence import ExecuteModelRequest, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.sequence import ExecuteModelRequest from vllm.worker.openvino_model_runner import OpenVINOModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 01daa64b5a32f..db306bc743d3a 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -1,6 +1,7 @@ import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, Union +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, + Type, Union) from unittest.mock import patch import numpy as np @@ -10,14 +11,15 @@ import torch_xla.runtime as xr from vllm.attention import AttentionMetadata, get_attn_backend +from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, ModelConfig, ParallelConfig, SchedulerConfig) from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors, - Logprob, SamplerOutput, SequenceGroupMetadata, - SequenceOutput) + Logprob, SequenceGroupMetadata, SequenceOutput) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, @@ -50,6 +52,7 @@ class ModelInputForTPU(ModelRunnerInputBase): best_of: List[int] seq_groups: List[List[int]] virtual_engine: int = 0 + async_callback: Optional[Callable] = None def as_broadcastable_tensor_dict( self) -> Dict[str, Union[int, torch.Tensor]]: @@ -144,11 +147,7 @@ def load_model(self) -> None: ) model = model.eval() xm.wait_device_ops() - model = ModelWrapper(model) - self.model = torch.compile(model, - backend="openxla", - fullgraph=True, - dynamic=False) + self.model = ModelWrapper(model) def _dummy_run( self, @@ -235,8 +234,15 @@ def _dummy_run( torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(p, 0) # Dummy run. - self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, - num_samples, kv_caches) + self.model(token_ids, + position_ids, + attn_metadata, + input_lens, + t, + p, + num_samples, + kv_caches, + is_prompt=is_prompt) def warmup_model( self, @@ -530,7 +536,7 @@ def _execute_model(*args): if getattr(arg, "context_lens", None) is not None: arg.context_lens = arg.context_lens.to(self.device) new_args.append(arg) - return self.model(*new_args) + return self.model(*new_args, is_prompt=is_prompt) num_prefills = model_input.attn_metadata.num_prefills is_prompt = num_prefills > 0 @@ -558,6 +564,8 @@ def _execute_model(*args): model_input.attn_metadata, model_input.input_lens[i:i + 1], model_input.t[i:i + 1], model_input.p[i:i + 1], model_input.num_samples, kv_caches) + if i == 0 and model_input.async_callback is not None: + model_input.async_callback() # Retrieve the outputs to CPU. next_token_ids += output_token_ids.cpu().tolist() start_idx = end_idx @@ -568,6 +576,8 @@ def _execute_model(*args): model_input.attn_metadata, model_input.input_lens, model_input.t, model_input.p, model_input.num_samples, kv_caches) + if model_input.async_callback is not None: + model_input.async_callback() # Retrieve the outputs to CPU. next_token_ids = output_token_ids.cpu().tolist() @@ -591,7 +601,7 @@ def _execute_model(*args): batch_idx += 1 else: for seq_id in seq_ids: - next_token_id = next_token_ids[batch_idx][0] + next_token_id = next_token_ids[batch_idx] seq_outputs.append( SequenceOutput(seq_id, next_token_id, {next_token_id: zero_logprob})) @@ -601,11 +611,32 @@ def _execute_model(*args): return [SamplerOutput(sampler_outputs)] -class ModelWrapper(nn.Module): +class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): def __init__(self, model: nn.Module): - super().__init__() self.model = model + compiled_callable = torch.compile(self.forward, + backend="openxla", + fullgraph=True, + dynamic=False) + super().__init__(compiled_callable) + + def __call__(self, *args, is_prompt: bool, **kwargs): + if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: + # not fully compiled yet, or not using the custom dispatcher, + # let PyTorch handle it + return self.compiled_callable(*args, **kwargs) + # the 3 compiled codes are: + # 0: for profiling + # 1: for prompt + # 2: for decode + # dispatch to the compiled code directly, skip PyTorch + if is_prompt: + with self.dispatch_to_code(1): + return self.forward(*args, **kwargs) + else: + with self.dispatch_to_code(2): + return self.forward(*args, **kwargs) def forward( self, @@ -691,6 +722,9 @@ def forward( sampled_token_ids = torch.multinomial(probs, num_samples, replacement=True) + if num_samples == 1: + argmax_token_ids = argmax_token_ids.squeeze(dim=-1) + sampled_token_ids = sampled_token_ids.squeeze(dim=-1) next_token_ids = torch.where(t != 0, sampled_token_ids, argmax_token_ids) return next_token_ids diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 320b15d3604bc..9e0c522cee453 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -102,8 +102,9 @@ def init_device(self) -> None: # NOTE(woosuk): Set per-rank cache path since different ranks # can have slightly different XLA graphs. world_size = self.parallel_config.world_size + rank = xr.global_ordinal() per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH, - f"tp{world_size}_rank{self.rank}") + f"tp{world_size}_rank{rank}") xr.initialize_cache(per_rank_path, readonly=False) def load_model(self): @@ -143,10 +144,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_cpu_blocks = int(self.cache_config.swap_space_bytes // block_size_bytes) num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8. - - # reset and discard the guard and compiled bytecode for profiling runs - torch._dynamo.reset() - return num_tpu_blocks, num_cpu_blocks def initialize_cache( diff --git a/vllm/worker/utils.py b/vllm/worker/utils.py index 79c48896469e8..d73023e8e1724 100644 --- a/vllm/worker/utils.py +++ b/vllm/worker/utils.py @@ -39,7 +39,7 @@ def assert_enc_dec_mr_supported_scenario( raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_PP']) - if enc_dec_mr.model_config.multimodal_config is not None: + if enc_dec_mr.model_config.is_multimodal_model: raise NotImplementedError( STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_MM']) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 7ed609c3b447c..0ff559a9af53e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -17,12 +17,12 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.tensorizer import TensorizerConfig from vllm.platforms import current_platform from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, - SamplerOutput, SequenceGroupMetadata, - SequenceGroupMetadataDelta) + SequenceGroupMetadata, SequenceGroupMetadataDelta) from vllm.worker.cache_engine import CacheEngine from vllm.worker.embedding_model_runner import EmbeddingModelRunner from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 012043673b094..6ba4f272315ce 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -11,9 +11,9 @@ from vllm.distributed import broadcast_tensor_dict, get_pp_group, get_tp_group from vllm.logger import init_logger from vllm.lora.request import LoRARequest +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform -from vllm.sequence import (ExecuteModelRequest, IntermediateTensors, - SamplerOutput) +from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import (enable_trace_function_call_for_thread, update_environment_variables) from vllm.worker.model_runner_base import (BroadcastableModelInput, diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 3894658a095f3..f9037625d4af9 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -15,12 +15,12 @@ from vllm.distributed import get_pp_group from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs, MultiModalRegistry) from vllm.sampling_params import SamplingParams -from vllm.sequence import (IntermediateTensors, SamplerOutput, - SequenceGroupMetadata) +from vllm.sequence import IntermediateTensors, SequenceGroupMetadata from vllm.utils import CudaMemoryProfiler, make_tensor_with_pad from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata from vllm.worker.model_runner_base import (