From ec964bde37f8615c3870c2f00a8e971a338ddc29 Mon Sep 17 00:00:00 2001 From: bikash119 Date: Fri, 4 Oct 2024 06:21:13 +0000 Subject: [PATCH 01/10] First commit: sglang support --- src/distilabel/llms/sglang.py | 211 ++++++++++++++++++++++++++++++++++ 1 file changed, 211 insertions(+) create mode 100644 src/distilabel/llms/sglang.py diff --git a/src/distilabel/llms/sglang.py b/src/distilabel/llms/sglang.py new file mode 100644 index 000000000..e78ba8d88 --- /dev/null +++ b/src/distilabel/llms/sglang.py @@ -0,0 +1,211 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, +) + +from pydantic import Field, PrivateAttr, validate_call + +from distilabel.llms.base import LLM +from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin +from distilabel.llms.typing import GenerateOutput +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType + +if TYPE_CHECKING: + from openai import OpenAI # noqa + + from distilabel.steps.tasks.typing import StandardInput + + +class SGLang(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): + """`SGLang` library LLM implementation. + + Attributes: + model (str): The model Hugging Face Hub repo id or a path to a directory containing the + model weights and configuration files. + tokenizer_path (Optional[str]): Path to the tokenizer. If None, the default tokenizer for + the model will be used. + tokenizer_mode (str): Mode for tokenizer initialization. Default is "auto". + skip_tokenizer_init (bool): Whether to skip tokenizer initialization. Default is False. + load_format (str): Format for loading the model. Default is "auto". + dtype (str): Data type for model parameters. Default is "auto". + kv_cache_dtype (str): Data type for key-value cache. Default is "auto". + trust_remote_code (bool): Whether to trust remote code when loading the model. Default is True. + context_length (Optional[int]): Maximum context length for the model. If None, uses the + model's default. + quantization (Optional[str]): Quantization method to use. If None, no quantization is applied. + served_model_name (Optional[str]): Name of the served model if using a model server. + chat_template (Optional[str]): Custom chat template to use for formatting inputs. + is_embedding (bool): Whether the model is used for embeddings. Default is False. + + Runtime parameters: + - extra_kwargs: Additional dictionary of keyword arguments that will be passed to the + SGLang class. + - structured_output: The structured output format to use across all the generations. + - log_level: The log level to use for the SGLang server. 
+ + Examples: + Generate text: + + ```python + from distilabel.llms import SGLang + + llm = SGLang(model="your-model-name") + llm.load() + + output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]]) + ``` + """ + + model: str + dtype: str = "auto" + trust_remote_code: bool = False + quantization: Optional[str] = None + revision: Optional[str] = None + + tokenizer_path: Optional[str] = None + tokenizer_mode: str = "auto" + tokenizer_revision: Optional[str] = None + skip_tokenizer_init: bool = False + chat_template: Optional[str] = None + + load_format: str = "auto" + kv_cache_dtype: str = "auto" + context_length: Optional[int] = None + served_model_name: Optional[str] = None + is_embedding: bool = False + + seed: int = 0 + + extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field( + default_factory=dict, + description="Additional dictionary of keyword arguments that will be passed to the" + " `SGLang` class.", + ) + structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field( + default=None, + description="The structured output format to use across all the generations.", + ) + log_level: Optional[RuntimeParameter[str]] = Field( + default="error", + description="The log level to use for the SGLang server.", + ) + + _model: Any = PrivateAttr(None) + _tokenizer: Any = PrivateAttr(None) + + def load(self) -> None: + """ + Loads the SGLang model using either path or Huggingface repository id. + Additionally, this method also sets the `chat_template` for the tokenizer, so as to properly + parse the list of OpenAI formatted inputs using the expected format by the model, otherwise, the + default value is ChatML format, unless explicitly provided. + """ + super().load() + CudaDevicePlacementMixin.load(self) + + try: + from sglang.srt.server import Runtime + except ImportError as ie: + raise ImportError( + 'SGLang is not installed. Please install it using `pip install "sglang[all]"`.' + " Also, install FlashInfer CUDA kernels using:\n" + "`pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/`" + ) from ie + + self._model = Runtime( + model_path=self.model, + dtype=self.dtype, + trust_remote_code=self.trust_remote_code, + quantization=self.quantization, + revision=self.revision, + tokenizer=self.tokenizer, + tokenizer_mode=self.tokenizer_mode, + tokenizer_revision=self.tokenizer_revision, + skip_tokenizer_init=self.skip_tokenizer_init, + load_format=self.load_format, + kv_cache_dtype=self.kv_cache_dtype, + context_length=self.context_length, + served_model_name=self.served_model_name, + is_embedding=self.is_embedding, + seed=self.seed, + **self.extra_kwargs, + ) + + self._tokenizer = self._model.get_tokenizer() # type: ignore + if self.chat_template is not None: + self._tokenizer.chat_template = self.chat_template # type: ignore + + if self.structured_output: + self._structured_output_logits_processor = self._prepare_structured_output( + self.structured_output + ) + + def unload(self) -> None: + """Unloads the SGLang model.""" + self._model = None + self._tokenizer = None + CudaDevicePlacementMixin.unload(self) + super().unload() + + @property + def model_name(self) -> str: + """Returns the model name used for the LLM.""" + return self.served_model_name + + def prepare_input(self, input: "StandardInput") -> str: + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. 
+ """ + if self._tokenizer.chat_template is None: + return input[0]["content"] + + prompt: str = ( + self._tokenizer.apply_chat_template( + input, # type: ignore + tokenize=False, + add_generation_prompt=True, # type: ignore + ) + if input + else "" + ) + return super().apply_magpie_pre_query_template(prompt, input) + + @validate_call + def generate( + self, + inputs: List[FormattedInput], + num_generations: int = 1, + max_new_tokens: int = 128, + # Add other relevant parameters here + ) -> List[GenerateOutput]: + """Generates responses for each input.""" + # Implement generation logic here + pass + + +# You can add a ClientSGLang class here if needed, similar to ClientvLLM From ad7bac04c1833f7d725c99ed3830dc5e391293f6 Mon Sep 17 00:00:00 2001 From: Bikash Date: Sat, 5 Oct 2024 04:58:14 +0000 Subject: [PATCH 02/10] Test case added --- src/distilabel/llms/__init__.py | 2 + src/distilabel/llms/sglang.py | 196 +++++++++++++++++++++++++++++--- tests/unit/llms/test_sglang.py | 41 +++++++ 3 files changed, 223 insertions(+), 16 deletions(-) create mode 100644 tests/unit/llms/test_sglang.py diff --git a/src/distilabel/llms/__init__.py b/src/distilabel/llms/__init__.py index 526d6b1fa..4b65548be 100644 --- a/src/distilabel/llms/__init__.py +++ b/src/distilabel/llms/__init__.py @@ -26,6 +26,7 @@ from distilabel.llms.moa import MixtureOfAgentsLLM from distilabel.llms.ollama import OllamaLLM from distilabel.llms.openai import OpenAILLM +from distilabel.llms.sglang import SGLang from distilabel.llms.together import TogetherLLM from distilabel.llms.typing import GenerateOutput, HiddenState from distilabel.llms.vertexai import VertexAILLM @@ -54,4 +55,5 @@ "VertexAILLM", "ClientvLLM", "vLLM", + "SGLang", ] diff --git a/src/distilabel/llms/sglang.py b/src/distilabel/llms/sglang.py index e78ba8d88..49d57b8f6 100644 --- a/src/distilabel/llms/sglang.py +++ b/src/distilabel/llms/sglang.py @@ -15,11 +15,14 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Dict, List, Optional, + Union, ) +import numpy as np from pydantic import Field, PrivateAttr, validate_call from distilabel.llms.base import LLM @@ -30,10 +33,18 @@ from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType if TYPE_CHECKING: - from openai import OpenAI # noqa + from sglang.srt.server import Runtime + from transformers import PreTrainedTokenizer from distilabel.steps.tasks.typing import StandardInput +LogitsProcessorFn = Union[ + Callable[[List[int], Any], Any], + Callable[[List[int], List[int], Any], Any], +] + +LogitsProcessors = List[LogitsProcessorFn] + class SGLang(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): """`SGLang` library LLM implementation. 
@@ -79,11 +90,9 @@ class SGLang(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): dtype: str = "auto" trust_remote_code: bool = False quantization: Optional[str] = None - revision: Optional[str] = None - tokenizer_path: Optional[str] = None + tokenizer: Optional[str] = None tokenizer_mode: str = "auto" - tokenizer_revision: Optional[str] = None skip_tokenizer_init: bool = False chat_template: Optional[str] = None @@ -109,8 +118,8 @@ class SGLang(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): description="The log level to use for the SGLang server.", ) - _model: Any = PrivateAttr(None) - _tokenizer: Any = PrivateAttr(None) + _model: "Runtime" = PrivateAttr(None) + _tokenizer: "PreTrainedTokenizer" = PrivateAttr(None) def load(self) -> None: """ @@ -126,7 +135,7 @@ def load(self) -> None: from sglang.srt.server import Runtime except ImportError as ie: raise ImportError( - 'SGLang is not installed. Please install it using `pip install "sglang[all]"`.' + '`SGLang` is not installed. Please install it using `pip install "sglang[all]"`.' " Also, install FlashInfer CUDA kernels using:\n" "`pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/`" ) from ie @@ -136,17 +145,15 @@ def load(self) -> None: dtype=self.dtype, trust_remote_code=self.trust_remote_code, quantization=self.quantization, - revision=self.revision, - tokenizer=self.tokenizer, + tokenizer_path=self.tokenizer, tokenizer_mode=self.tokenizer_mode, - tokenizer_revision=self.tokenizer_revision, skip_tokenizer_init=self.skip_tokenizer_init, load_format=self.load_format, kv_cache_dtype=self.kv_cache_dtype, context_length=self.context_length, served_model_name=self.served_model_name, is_embedding=self.is_embedding, - seed=self.seed, + random_seed=self.seed, **self.extra_kwargs, ) @@ -196,16 +203,173 @@ def prepare_input(self, input: "StandardInput") -> str: return super().apply_magpie_pre_query_template(prompt, input) @validate_call - def generate( + def generate( # type: ignore self, inputs: List[FormattedInput], num_generations: int = 1, max_new_tokens: int = 128, - # Add other relevant parameters here + presence_penalty: float = 0.0, + frequency_penalty: float = 0.0, + repetition_penalty: float = 1.0, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + stop: Optional[List[str]] = None, + stop_token_ids: Optional[List[int]] = None, + include_stop_str_in_output: bool = False, + logits_processors: Optional[LogitsProcessors] = None, + extra_sampling_params: Optional[Dict[str, Any]] = None, ) -> List[GenerateOutput]: - """Generates responses for each input.""" - # Implement generation logic here - pass + """Generates `num_generations` responses for each input. + + Args: + inputs: a list of inputs in chat format to generate responses for. + num_generations: the number of generations to create per input. Defaults to + `1`. + max_new_tokens: the maximum number of new tokens that the model will generate. + Defaults to `128`. + presence_penalty: the presence penalty to use for the generation. Defaults to + `0.0`. + frequency_penalty: the repetition penalty to use for the generation. Defaults + to `0.0`. + repetition_penalty: the repetition penalty to use for the generation Defaults to + `1.0`. + temperature: the temperature to use for the generation. Defaults to `0.1`. + top_p: the top-p value to use for the generation. Defaults to `1.0`. + top_k: the top-k value to use for the generation. Defaults to `0`. + min_p: the minimum probability to use for the generation. 
Defaults to `0.0`. + stop: a list of strings that will be used to stop the generation when found. + Defaults to `None`. + stop_token_ids: a list of token ids that will be used to stop the generation + when found. Defaults to `None`. + include_stop_str_in_output: whether to include the stop string in the output. + Defaults to `False`. + logits_processors: a list of functions to process the logits before sampling. + Defaults to `None`. + extra_sampling_params: dictionary with additional arguments to be passed to + the `SamplingParams` class from `vllm`. + + Returns: + A list of lists of strings containing the generated responses for each input. + """ + + if not logits_processors: + logits_processors = [] + + if extra_sampling_params is None: + extra_sampling_params = {} + + structured_output = None + + if isinstance(inputs[0], tuple): + prepared_batches, sorted_indices = self._prepare_batches(inputs) + else: + # Simulate a batch without the structured output content + prepared_batches = [([self.prepare_input(input) for input in inputs], None)] + sorted_indices = None + + # Case in which we have a single structured output for the dataset + if self._structured_output_logits_processor: + logits_processors.append(self._structured_output_logits_processor) + + batched_outputs = [] + + for prepared_inputs, structured_output in prepared_batches: + if structured_output: + logits_processors.append( + self._prepare_structured_output(structured_output) + ) + + sampling_params = {"max_new_tokens": 128} + + self._model.generate( + prepared_inputs, + sampling_params, + use_tqdm=False, # type: ignore + ) + batch_outputs = self._model.add_request( + prepared_inputs, + sampling_params, + use_tqdm=False, # type: ignore + ) + + batched_outputs += [ + [output.text for output in outputs.outputs] for outputs in batch_outputs + ] + + # If logits_processor is set, we need to sort the outputs back to the original order + # (would be needed only if we have multiple structured outputs in the dataset) + if sorted_indices is not None: + batched_outputs = _sort_batches( + batched_outputs, sorted_indices, num_generations=num_generations + ) + return batched_outputs + + def _prepare_structured_output( + self, structured_output: Optional[OutlinesStructuredOutputType] = None + ) -> Union[Callable, None]: + """Creates the appropriate function to filter tokens to generate structured outputs. + + Args: + structured_output: the configuration dict to prepare the structured output. + + Returns: + The callable that will be used to guide the generation of the model. + """ + from distilabel.steps.tasks.structured_outputs.outlines import ( + prepare_guided_output, + ) + + result = prepare_guided_output(structured_output, "vllm", self._model) + if (schema := result.get("schema")) and self.structured_output: + self.structured_output["schema"] = schema + return result["processor"] + + +def _sort_batches( + batches: List[List[FormattedInput]], indices: List[int], num_generations: int = 1 +) -> List[str]: + """Helper function to sort back the mini-batches generated by the model. + + It must take into account the number of `num_generations` to repeat the indices + accordingly. + + Args: + batches: The mini-batches generated by the model. + indices: The indices that would sort the mini-batches back to the original order. + num_generations: The number of generations requested to vLLM. Defaults to 1. + + Returns: + Sorted batched_outputs. 
+ """ + batch_sizes = [len(batch) for batch in batches] + flattened_batches = np.array([b for batch in batches for b in batch]) + sorted_batches = np.take_along_axis( + flattened_batches, + np.argsort(np.repeat(indices, num_generations)), + axis=0, + ).tolist() + sorted_batches = _batchify(sorted_batches, batch_sizes) + return sorted_batches + + +def _batchify(sorted_batches: List[str], batch_sizes: List[int]) -> List[List[str]]: + """Helper function to regenerate the sorted batches from the flattened sorted ones. + + Args: + sorted_batches: Output obtained from the `_sort_batches` function. + batch_sizes: The batch sizes to be used to split the sorted batches. + + Returns: + Batched sorted batches in the original shape. + """ + batches = [] + idx = 0 + for bs in batch_sizes: + batches.append(sorted_batches[idx : idx + bs]) + idx += bs + return batches # You can add a ClientSGLang class here if needed, similar to ClientvLLM diff --git a/tests/unit/llms/test_sglang.py b/tests/unit/llms/test_sglang.py new file mode 100644 index 000000000..6b7388995 --- /dev/null +++ b/tests/unit/llms/test_sglang.py @@ -0,0 +1,41 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from distilabel.llms import SGLang + + +@pytest.fixture +def sglang_instance(): + return SGLang(model="test-model") + + +def test_sglang_init(): + """Test the initialization of SGLang class.""" + llm = SGLang(model="test-model") + assert llm.model == "test-model" + assert llm.dtype == "auto" + assert llm.quantization is None + + +def test_sglang_load(): + """Test the load method of SGLang class.""" + llm = SGLang(model="meta-llama/Llama-2-7b-chat-hf") + llm.load() + assert llm._model is not None + assert llm._tokenizer is not None + + +# Add more tests as needed for other methods and edge cases From c3a532237cd5b66beb4cb51f0051052d505e9638 Mon Sep 17 00:00:00 2001 From: Bikash Date: Sat, 5 Oct 2024 05:56:01 +0000 Subject: [PATCH 03/10] Align the testcases in line with test_vllm.py testcase --- src/distilabel/llms/sglang.py | 24 +++-- tests/unit/llms/test_sglang.py | 191 ++++++++++++++++++++++++++++++--- 2 files changed, 192 insertions(+), 23 deletions(-) diff --git a/src/distilabel/llms/sglang.py b/src/distilabel/llms/sglang.py index 49d57b8f6..90a35ba8f 100644 --- a/src/distilabel/llms/sglang.py +++ b/src/distilabel/llms/sglang.py @@ -253,6 +253,7 @@ def generate( # type: ignore Returns: A list of lists of strings containing the generated responses for each input. 
""" + from sglang.srt.sampling.sampling_params import SamplingParams if not logits_processors: logits_processors = [] @@ -281,14 +282,23 @@ def generate( # type: ignore self._prepare_structured_output(structured_output) ) - sampling_params = {"max_new_tokens": 128} - - self._model.generate( - prepared_inputs, - sampling_params, - use_tqdm=False, # type: ignore + sampling_params = SamplingParams( # type: ignore + n=num_generations, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + max_new_tokens=max_new_tokens, + stop=stop, + stop_token_ids=stop_token_ids, + include_stop_str_in_output=include_stop_str_in_output, + **extra_sampling_params, ) - batch_outputs = self._model.add_request( + + batch_outputs = self._model.generate( prepared_inputs, sampling_params, use_tqdm=False, # type: ignore diff --git a/tests/unit/llms/test_sglang.py b/tests/unit/llms/test_sglang.py index 6b7388995..e9bdc2301 100644 --- a/tests/unit/llms/test_sglang.py +++ b/tests/unit/llms/test_sglang.py @@ -12,30 +12,189 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List + +import numpy as np import pytest +from pydantic import BaseModel + +from distilabel.llms.sglang import SGLang, _sort_batches + + +class Character(BaseModel): + name: str + description: str + role: str + weapon: str + + +class Animal(BaseModel): + name: str + species: str + habitat: str + diet: str + + +SAMPLE_DATA = [ + [ + { + "instruction": [ + {"role": "user", "content": "Generate a character from a RPG game."} + ], + "structured_output": { + "format": "json", + "schema": Character.model_json_schema(), + }, + }, + { + "instruction": [ + { + "role": "user", + "content": "Generate an animal from a zoo.", + } + ], + "structured_output": { + "format": "json", + "schema": Animal.model_json_schema(), + }, + }, + { + "instruction": [{"role": "user", "content": "Repeated character"}], + "structured_output": { + "format": "json", + "schema": Character.model_json_schema(), + }, + }, + { + "instruction": [ + { + "role": "user", + "content": "What's the weather like today in Seattle in Celsius degrees?", + } + ], + "structured_output": { + "format": "regex", + "schema": "(\\d{1,2})°C", + }, + }, + { + "instruction": [{"role": "user", "content": "Other character"}], + "structured_output": { + "format": "json", + "schema": Character.model_json_schema(), + }, + }, + { + "instruction": [{"role": "user", "content": "repeated regex"}], + "structured_output": { + "format": "regex", + "schema": "(\\d{1,2})°C", + }, + }, + ] +] + + +class DummyTokenizer: + chat_template = None -from distilabel.llms import SGLang + def __init__(self) -> None: + pass + def apply_chat_template(self, input, **kwargs): + return input -@pytest.fixture -def sglang_instance(): - return SGLang(model="test-model") +class TestvLLM: + @pytest.mark.parametrize( + "num_generations, expected_sorted_batches", + [ + ( + 1, + [ + "Generate a character from a RPG game.", + "Generate an animal from a zoo.", + "Repeated character", + "What's the weather like today in Seattle in Celsius degrees?", + "Other character", + "repeated regex", + ], + ), + ( + 3, + np.repeat( + [ + "Generate a character from a RPG game.", + "Generate an animal from a zoo.", + "Repeated character", + "What's the weather like today in Seattle in Celsius degrees?", + "Other character", + "repeated regex", + ], + 3, + 
).tolist(), + ), + ], + ) + def test_prepare_batches_and_sort_back( + self, num_generations: int, expected_sorted_batches: List[str] + ): + formatted_inputs = [ + (item["instruction"], item["structured_output"]) + for row in SAMPLE_DATA + for item in row + ] + llm = SGLang(model="dummy") + llm._tokenizer = DummyTokenizer() + batches, indices = llm._prepare_batches(formatted_inputs) + # NOTE: We have to simulate calling self._model.generate(n=num_generations) and then sorting the results + num_generations_batches = [] + for batch in batches: + num_generations_batches.append( + (np.repeat(batch[0], num_generations).tolist(), batch[1]) + ) + batches = num_generations_batches + # Recreate as the output from batched_outputs += [[output.text for output in outputs.outputs] for outputs in batch_outputs] + batches = [batch for batch, _ in batches] + sorted_batches = _sort_batches( + batches, indices, num_generations=num_generations + ) -def test_sglang_init(): - """Test the initialization of SGLang class.""" - llm = SGLang(model="test-model") - assert llm.model == "test-model" - assert llm.dtype == "auto" - assert llm.quantization is None + assert sorted_batches == [ + np.repeat( + [ + "Generate a character from a RPG game.", + "Generate an animal from a zoo.", + "Repeated character", + ], + num_generations, + ).tolist(), + np.repeat( + ["What's the weather like today in Seattle in Celsius degrees?"], + num_generations, + ).tolist(), + np.repeat( + [ + "Other character", + "repeated regex", + ], + num_generations, + ).tolist(), + ] + def test_sglang_init(self): + """Test the initialization of SGLang class.""" + llm = SGLang(model="test-model") + assert llm.model == "test-model" + assert llm.dtype == "auto" + assert llm.quantization is None -def test_sglang_load(): - """Test the load method of SGLang class.""" - llm = SGLang(model="meta-llama/Llama-2-7b-chat-hf") - llm.load() - assert llm._model is not None - assert llm._tokenizer is not None + def test_sglang_load(self): + """Test the load method of SGLang class.""" + llm = SGLang(model="meta-llama/Llama-2-7b-chat-hf") + llm.load() + assert llm._model is not None + assert llm._tokenizer is not None # Add more tests as needed for other methods and edge cases From 4407d9aec48116f652f81148661fd7f71cac8ed2 Mon Sep 17 00:00:00 2001 From: Bikash Date: Sun, 6 Oct 2024 13:42:34 +0000 Subject: [PATCH 04/10] fixes --- src/distilabel/llms/sglang.py | 44 +++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/src/distilabel/llms/sglang.py b/src/distilabel/llms/sglang.py index 90a35ba8f..83e2e8c60 100644 --- a/src/distilabel/llms/sglang.py +++ b/src/distilabel/llms/sglang.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from typing import ( TYPE_CHECKING, Any, @@ -19,6 +20,7 @@ Dict, List, Optional, + Tuple, Union, ) @@ -202,6 +204,48 @@ def prepare_input(self, input: "StandardInput") -> str: ) return super().apply_magpie_pre_query_template(prompt, input) + def _prepare_batches( + self, inputs: List[FormattedInput] + ) -> Tuple[List[List[FormattedInput]], List[int]]: + """Prepares the inputs by grouping them by the structured output. + + When we generate structured outputs with schemas obtained from a dataset, we need to + prepare the data to try to send batches of inputs instead of single inputs to the model + to take advante of the engine. So we group the inputs by the structured output to be + passed in the `generate` method. 
+ + Args: + inputs: The batch of inputs passed to the generate method. As we expect to be generating + structured outputs, each element will be a tuple containing the instruction and the + structured output. + + Returns: + The prepared batches (sub-batches let's say) to be passed to the `generate` method. + Each new tuple will contain instead of the single instruction, a list of instructions + """ + instruction_order = {} + batches = {} + for i, (instruction, structured_output) in enumerate(inputs): + instruction = self.prepare_input(instruction) + instruction_order[instruction] = i + structured_output = json.dumps(structured_output) + if structured_output not in batches: + batches[structured_output] = [instruction] + else: + batches[structured_output].append(instruction) + + # Flatten the instructions in prepared_data + flat_instructions = [ + instruction for _, group in batches.items() for instruction in group + ] + # Generate the list of indices based on the original order + sorted_indices = [ + instruction_order[instruction] for instruction in flat_instructions + ] + return [ + (batch, json.loads(schema)) for schema, batch in batches.items() + ], sorted_indices + @validate_call def generate( # type: ignore self, From b530a23d0dd4d971933c88cfd2d2cf8f1ac0e4a0 Mon Sep 17 00:00:00 2001 From: Bikash Date: Sun, 6 Oct 2024 15:28:32 +0000 Subject: [PATCH 05/10] generate method added --- src/distilabel/llms/sglang.py | 13 ++++++------ tests/unit/llms/test_sglang.py | 36 ++++++++++++++++++++++++++++++++-- 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/src/distilabel/llms/sglang.py b/src/distilabel/llms/sglang.py index 83e2e8c60..768d51f05 100644 --- a/src/distilabel/llms/sglang.py +++ b/src/distilabel/llms/sglang.py @@ -122,6 +122,7 @@ class SGLang(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): _model: "Runtime" = PrivateAttr(None) _tokenizer: "PreTrainedTokenizer" = PrivateAttr(None) + _structured_output_logits_processor: Optional[Callable] = PrivateAttr(default=None) def load(self) -> None: """ @@ -170,6 +171,7 @@ def load(self) -> None: def unload(self) -> None: """Unloads the SGLang model.""" + self._model.shutdown() self._model = None self._tokenizer = None CudaDevicePlacementMixin.unload(self) @@ -261,7 +263,6 @@ def generate( # type: ignore min_p: float = 0.0, stop: Optional[List[str]] = None, stop_token_ids: Optional[List[int]] = None, - include_stop_str_in_output: bool = False, logits_processors: Optional[LogitsProcessors] = None, extra_sampling_params: Optional[Dict[str, Any]] = None, ) -> List[GenerateOutput]: @@ -314,6 +315,7 @@ def generate( # type: ignore prepared_batches = [([self.prepare_input(input) for input in inputs], None)] sorted_indices = None + # Case in which we have a single structured output for the dataset # Case in which we have a single structured output for the dataset if self._structured_output_logits_processor: logits_processors.append(self._structured_output_logits_processor) @@ -337,19 +339,18 @@ def generate( # type: ignore min_p=min_p, max_new_tokens=max_new_tokens, stop=stop, - stop_token_ids=stop_token_ids, - include_stop_str_in_output=include_stop_str_in_output, + stop_token_ids=[] if stop_token_ids is None else stop_token_ids, **extra_sampling_params, ) batch_outputs = self._model.generate( prepared_inputs, - sampling_params, - use_tqdm=False, # type: ignore + sampling_params.to_srt_kwargs(), ) batched_outputs += [ - [output.text for output in outputs.outputs] for outputs in batch_outputs + [output.text for output in 
outputs.outputs] + for outputs in json.loads(batch_outputs) ] # If logits_processor is set, we need to sort the outputs back to the original order diff --git a/tests/unit/llms/test_sglang.py b/tests/unit/llms/test_sglang.py index e9bdc2301..0742d5ad2 100644 --- a/tests/unit/llms/test_sglang.py +++ b/tests/unit/llms/test_sglang.py @@ -105,7 +105,7 @@ def apply_chat_template(self, input, **kwargs): return input -class TestvLLM: +class TestSGLang: @pytest.mark.parametrize( "num_generations, expected_sorted_batches", [ @@ -195,6 +195,38 @@ def test_sglang_load(self): llm.load() assert llm._model is not None assert llm._tokenizer is not None + llm.unload() + def test_sglang_generate(self): + """Test the generate method of SGLang class.""" + llm = SGLang(model="meta-llama/Llama-2-7b-chat-hf") + llm.load() + + # Mock the _model.generate method to avoid actual API calls + def mock_generate(inputs, sampling_params, use_tqdm): + from collections import namedtuple + + Output = namedtuple("Output", ["text"]) + Outputs = namedtuple("Outputs", ["outputs"]) + return [ + Outputs( + [ + Output(f"Generated text for {input}") + for _ in range(sampling_params.n) + ] + ) + for input in inputs + ] + + # llm.generate = mock_generate + + inputs = [ + [{"role": "user", "content": "Hello, how are you?"}], + [{"role": "user", "content": "What's the weather like today?"}], + ] + + outputs = llm.generate(inputs, num_generations=2, max_new_tokens=10) + assert len(outputs) == 2 # Two input prompts + llm.unload() -# Add more tests as needed for other methods and edge cases + # Add more tests as needed for other methods and edge cases From ad48fa318671636f59fe8fb48512dc073bfac93a Mon Sep 17 00:00:00 2001 From: Bikash Date: Sun, 6 Oct 2024 15:42:25 +0000 Subject: [PATCH 06/10] batched output supported for generate method --- src/distilabel/llms/sglang.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/distilabel/llms/sglang.py b/src/distilabel/llms/sglang.py index 768d51f05..ffccc3307 100644 --- a/src/distilabel/llms/sglang.py +++ b/src/distilabel/llms/sglang.py @@ -348,10 +348,7 @@ def generate( # type: ignore sampling_params.to_srt_kwargs(), ) - batched_outputs += [ - [output.text for output in outputs.outputs] - for outputs in json.loads(batch_outputs) - ] + batched_outputs += [output["text"] for output in json.loads(batch_outputs)] # If logits_processor is set, we need to sort the outputs back to the original order # (would be needed only if we have multiple structured outputs in the dataset) From 38ca6dbede1f63a7dcd8626123c502db7d7d5340 Mon Sep 17 00:00:00 2001 From: bikash119 Date: Mon, 7 Oct 2024 05:36:49 +0530 Subject: [PATCH 07/10] Dependencies added to pyproject.toml --- pyproject.toml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 44404c683..c3a1b4219 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -101,6 +101,14 @@ text-clustering = [ "scikit-learn >= 1.4.1", "matplotlib >= 3.8.3" # For the figure (even though it's optional) ] +sglang = [ + "sglang >= 0.3.2", + "flashinfer @ https://flashinfer.ai/whl/cu121/torch2.4/flashinfer-1.1.0-cp38-cp38-linux_x86_64.whl ; python_version == '3.8'", + "flashinfer @ https://flashinfer.ai/whl/cu121/torch2.4/flashinfer-1.1.0-cp39-cp39-linux_x86_64.whl ; python_version == '3.9'", + "flashinfer @ https://flashinfer.ai/whl/cu121/torch2.4/flashinfer-1.1.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'", + "flashinfer @ 
https://flashinfer.ai/whl/cu121/torch2.4/flashinfer-1.1.0-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'", + "flashinfer @ https://flashinfer.ai/whl/cu121/torch2.4/flashinfer-1.1.0-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'", +] # minhash minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"] From 41f0bcfb56506f37516d2d6cd2420e5cd2619b87 Mon Sep 17 00:00:00 2001 From: bikash Date: Sat, 26 Oct 2024 14:05:45 +0000 Subject: [PATCH 08/10] simplest working example --- pyproject.toml | 3 +- src/distilabel/models/__init__.py | 2 - src/distilabel/models/llms/__init__.py | 2 + src/distilabel/models/llms/sglang.py | 469 +++++-------------------- 4 files changed, 95 insertions(+), 381 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d1b337fe0..d0b6715a8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -103,7 +103,8 @@ text-clustering = [ "matplotlib >= 3.8.3", # For the figure (even though it's optional) ] sglang = [ - "sglang >= 0.3.2", + "sglang[all]", + "transformers >= 4.34.1", "flashinfer @https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.3/flashinfer-0.1.3+cu121torch2.4-cp310-cp310-linux_x86_64.whl#sha256=dc9ed41c47e65abc368b16b27cedf9391ba51a6bebdea3485808321958cc36c2", ] diff --git a/src/distilabel/models/__init__.py b/src/distilabel/models/__init__.py index 4bf520c3e..45807302f 100644 --- a/src/distilabel/models/__init__.py +++ b/src/distilabel/models/__init__.py @@ -31,7 +31,6 @@ from distilabel.models.llms.moa import MixtureOfAgentsLLM from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM -from distilabel.models.llms.sglang import SGLang from distilabel.models.llms.together import TogetherLLM from distilabel.models.llms.typing import GenerateOutput, HiddenState from distilabel.models.llms.vertexai import VertexAILLM @@ -64,5 +63,4 @@ "Embeddings", "SentenceTransformerEmbeddings", "vLLMEmbeddings", - "SGLang", ] diff --git a/src/distilabel/models/llms/__init__.py b/src/distilabel/models/llms/__init__.py index 2ae311983..fcc68b799 100644 --- a/src/distilabel/models/llms/__init__.py +++ b/src/distilabel/models/llms/__init__.py @@ -25,6 +25,7 @@ from distilabel.models.llms.moa import MixtureOfAgentsLLM from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM +from distilabel.models.llms.sglang import SGLangLLM from distilabel.models.llms.together import TogetherLLM from distilabel.models.llms.typing import GenerateOutput, HiddenState from distilabel.models.llms.vertexai import VertexAILLM @@ -54,4 +55,5 @@ "VertexAILLM", "ClientvLLM", "vLLM", + "SGLangLLM", ] diff --git a/src/distilabel/models/llms/sglang.py b/src/distilabel/models/llms/sglang.py index 5b8c4dfe2..3d8a21dc5 100644 --- a/src/distilabel/models/llms/sglang.py +++ b/src/distilabel/models/llms/sglang.py @@ -13,415 +13,128 @@ # limitations under the License. 
import json -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - List, - Optional, - Tuple, - Union, -) +from typing import TYPE_CHECKING, Any, List -import numpy as np -from pydantic import Field, PrivateAttr, validate_call +from pydantic import PrivateAttr, validate_call -from distilabel.mixins.runtime_parameters import RuntimeParameter -from distilabel.models.llms.base import LLM -from distilabel.models.llms.typing import GenerateOutput -from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin -from distilabel.models.mixins.magpie import MagpieChatTemplateMixin -from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType +from distilabel.models.llms import LLM +from distilabel.models.llms.typing import GenerateOutput, HiddenState +from distilabel.steps.tasks.typing import ChatType if TYPE_CHECKING: from sglang.srt.server import Runtime - from transformers import PreTrainedTokenizer - from distilabel.steps.tasks.typing import StandardInput -LogitsProcessorFn = Union[ - Callable[[List[int], Any], Any], - Callable[[List[int], List[int], Any], Any], -] +class SGLangLLM(LLM): + _runtime: "Runtime" = PrivateAttr(None) -LogitsProcessors = List[LogitsProcessorFn] - - -class SGLang(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): - """`SGLang` library LLM implementation. - - Attributes: - model (str): The model Hugging Face Hub repo id or a path to a directory containing the - model weights and configuration files. - tokenizer_path (Optional[str]): Path to the tokenizer. If None, the default tokenizer for - the model will be used. - tokenizer_mode (str): Mode for tokenizer initialization. Default is "auto". - skip_tokenizer_init (bool): Whether to skip tokenizer initialization. Default is False. - load_format (str): Format for loading the model. Default is "auto". - dtype (str): Data type for model parameters. Default is "auto". - kv_cache_dtype (str): Data type for key-value cache. Default is "auto". - trust_remote_code (bool): Whether to trust remote code when loading the model. Default is True. - context_length (Optional[int]): Maximum context length for the model. If None, uses the - model's default. - quantization (Optional[str]): Quantization method to use. If None, no quantization is applied. - served_model_name (Optional[str]): Name of the served model if using a model server. - chat_template (Optional[str]): Custom chat template to use for formatting inputs. - is_embedding (bool): Whether the model is used for embeddings. Default is False. - - Runtime parameters: - - extra_kwargs: Additional dictionary of keyword arguments that will be passed to the - SGLang class. - - structured_output: The structured output format to use across all the generations. - - log_level: The log level to use for the SGLang server. 
- - Examples: - Generate text: - - ```python - from distilabel.llms import SGLang - - llm = SGLang(model="your-model-name") - llm.load() - - output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]]) - ``` - """ - - model: str - dtype: str = "auto" - trust_remote_code: bool = False - quantization: Optional[str] = None - - tokenizer: Optional[str] = None - tokenizer_mode: str = "auto" - skip_tokenizer_init: bool = False - chat_template: Optional[str] = None - - load_format: str = "auto" - kv_cache_dtype: str = "auto" - context_length: Optional[int] = None - served_model_name: Optional[str] = None - is_embedding: bool = False - - seed: int = 0 - - extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field( - default_factory=dict, - description="Additional dictionary of keyword arguments that will be passed to the" - " `SGLang` class.", - ) - structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field( - default=None, - description="The structured output format to use across all the generations.", - ) - log_level: Optional[RuntimeParameter[str]] = Field( - default="error", - description="The log level to use for the SGLang server.", - ) - - _model: "Runtime" = PrivateAttr(None) - _tokenizer: "PreTrainedTokenizer" = PrivateAttr(None) - _structured_output_logits_processor: Optional[Callable] = PrivateAttr(default=None) + def __init__( + self, + model: str, + log_level: str = "error", + tensor_parallel_size: int = 1, + **kwargs, + ): + """Initialize SGLang Runtime LLM. - def load(self) -> None: - """ - Loads the SGLang model using either path or Huggingface repository id. - Additionally, this method also sets the `chat_template` for the tokenizer, so as to properly - parse the list of OpenAI formatted inputs using the expected format by the model, otherwise, the - default value is ChatML format, unless explicitly provided. + Args: + model: Model path or name + log_level: Logging level (default: "error") + tensor_parallel_size: Number of GPUs for tensor parallelism (default: 1) + **kwargs: Additional arguments passed to SGLang Runtime """ - super().load() - CudaDevicePlacementMixin.load(self) - - try: - from sglang.srt.server import Runtime - except ImportError as ie: - raise ImportError( - '`SGLang` is not installed. Please install it using `pip install "sglang[all]"`.' 
- " Also, install FlashInfer CUDA kernels using:\n" - "`pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/`" - ) from ie - - self._model = Runtime( - model_path=self.model, - dtype=self.dtype, - trust_remote_code=self.trust_remote_code, - quantization=self.quantization, - tokenizer_path=self.tokenizer, - tokenizer_mode=self.tokenizer_mode, - skip_tokenizer_init=self.skip_tokenizer_init, - load_format=self.load_format, - kv_cache_dtype=self.kv_cache_dtype, - context_length=self.context_length, - served_model_name=self.served_model_name, - is_embedding=self.is_embedding, - random_seed=self.seed, - **self.extra_kwargs, + super().__init__() + from sglang.srt.server import Runtime + + self._runtime = Runtime( + model_path=model, + log_level=log_level, + tp_size=tensor_parallel_size, + **kwargs, ) - - self._tokenizer = self._model.get_tokenizer() # type: ignore - if self.chat_template is not None: - self._tokenizer.chat_template = self.chat_template # type: ignore - - if self.structured_output: - self._structured_output_logits_processor = self._prepare_structured_output( - self.structured_output - ) - - def unload(self) -> None: - """Unloads the SGLang model.""" - self._model.shutdown() - self._model = None - self._tokenizer = None - CudaDevicePlacementMixin.unload(self) - super().unload() + self._model = model @property def model_name(self) -> str: - """Returns the model name used for the LLM.""" - return self.served_model_name - - def prepare_input(self, input: "StandardInput") -> str: - """Prepares the input (applying the chat template and tokenization) for the provided - input. - - Args: - input: the input list containing chat items. - - Returns: - The prompt to send to the LLM. - """ - if self._tokenizer.chat_template is None: - return input[0]["content"] - - prompt: str = ( - self._tokenizer.apply_chat_template( - input, # type: ignore - tokenize=False, - add_generation_prompt=True, # type: ignore - ) - if input - else "" - ) - return super().apply_magpie_pre_query_template(prompt, input) - - def _prepare_batches( - self, inputs: List[FormattedInput] - ) -> Tuple[List[List[FormattedInput]], List[int]]: - """Prepares the inputs by grouping them by the structured output. - - When we generate structured outputs with schemas obtained from a dataset, we need to - prepare the data to try to send batches of inputs instead of single inputs to the model - to take advante of the engine. So we group the inputs by the structured output to be - passed in the `generate` method. - - Args: - inputs: The batch of inputs passed to the generate method. As we expect to be generating - structured outputs, each element will be a tuple containing the instruction and the - structured output. - - Returns: - The prepared batches (sub-batches let's say) to be passed to the `generate` method. 
- Each new tuple will contain instead of the single instruction, a list of instructions - """ - instruction_order = {} - batches = {} - for i, (instruction, structured_output) in enumerate(inputs): - instruction = self.prepare_input(instruction) - instruction_order[instruction] = i - structured_output = json.dumps(structured_output) - if structured_output not in batches: - batches[structured_output] = [instruction] - else: - batches[structured_output].append(instruction) - - # Flatten the instructions in prepared_data - flat_instructions = [ - instruction for _, group in batches.items() for instruction in group - ] - # Generate the list of indices based on the original order - sorted_indices = [ - instruction_order[instruction] for instruction in flat_instructions - ] - return [ - (batch, json.loads(schema)) for schema, batch in batches.items() - ], sorted_indices + return self._model @validate_call - def generate( # type: ignore + def generate( self, - inputs: List[FormattedInput], + inputs: List[ChatType], num_generations: int = 1, - max_new_tokens: int = 128, - presence_penalty: float = 0.0, - frequency_penalty: float = 0.0, - repetition_penalty: float = 1.0, - temperature: float = 1.0, + temperature: float = 0.7, + max_tokens: int = 512, top_p: float = 1.0, - top_k: int = -1, - min_p: float = 0.0, - stop: Optional[List[str]] = None, - stop_token_ids: Optional[List[int]] = None, - logits_processors: Optional[LogitsProcessors] = None, - extra_sampling_params: Optional[Dict[str, Any]] = None, + **kwargs: Any, ) -> List[GenerateOutput]: - """Generates `num_generations` responses for each input. + """Generate completions for the input prompts. Args: - inputs: a list of inputs in chat format to generate responses for. - num_generations: the number of generations to create per input. Defaults to - `1`. - max_new_tokens: the maximum number of new tokens that the model will generate. - Defaults to `128`. - presence_penalty: the presence penalty to use for the generation. Defaults to - `0.0`. - frequency_penalty: the repetition penalty to use for the generation. Defaults - to `0.0`. - repetition_penalty: the repetition penalty to use for the generation Defaults to - `1.0`. - temperature: the temperature to use for the generation. Defaults to `0.1`. - top_p: the top-p value to use for the generation. Defaults to `1.0`. - top_k: the top-k value to use for the generation. Defaults to `0`. - min_p: the minimum probability to use for the generation. Defaults to `0.0`. - stop: a list of strings that will be used to stop the generation when found. - Defaults to `None`. - stop_token_ids: a list of token ids that will be used to stop the generation - when found. Defaults to `None`. - include_stop_str_in_output: whether to include the stop string in the output. - Defaults to `False`. - logits_processors: a list of functions to process the logits before sampling. - Defaults to `None`. - extra_sampling_params: dictionary with additional arguments to be passed to - the `SamplingParams` class from `vllm`. + inputs: List of chat messages + num_generations: Number of generations per prompt + temperature: Sampling temperature + max_tokens: Maximum number of tokens to generate + top_p: Nucleus sampling threshold + **kwargs: Additional sampling parameters Returns: - A list of lists of strings containing the generated responses for each input. 
+ List of generation outputs """ - from sglang.srt.sampling.sampling_params import SamplingParams - - if not logits_processors: - logits_processors = [] - - if extra_sampling_params is None: - extra_sampling_params = {} - - structured_output = None - - if isinstance(inputs[0], tuple): - prepared_batches, sorted_indices = self._prepare_batches(inputs) - else: - # Simulate a batch without the structured output content - prepared_batches = [([self.prepare_input(input) for input in inputs], None)] - sorted_indices = None - - # Case in which we have a single structured output for the dataset - # Case in which we have a single structured output for the dataset - if self._structured_output_logits_processor: - logits_processors.append(self._structured_output_logits_processor) - - batched_outputs = [] - - for prepared_inputs, structured_output in prepared_batches: - if structured_output: - logits_processors.append( - self._prepare_structured_output(structured_output) - ) - - sampling_params = SamplingParams( # type: ignore - n=num_generations, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - max_new_tokens=max_new_tokens, - stop=stop, - stop_token_ids=[] if stop_token_ids is None else stop_token_ids, - **extra_sampling_params, + # Convert chat messages to prompt string + prompts = [] + for messages in inputs: + prompt = "" + for msg in messages: + role = msg["role"] + content = msg["content"] + if role == "system": + prompt += f"System: {content}\n" + elif role == "user": + prompt += f"User: {content}\n" + elif role == "assistant": + prompt += f"Assistant: {content}\n" + prompts.append(prompt.strip()) + + # Set up sampling parameters + sampling_params = { + "temperature": temperature, + "max_new_tokens": max_tokens, + "top_p": top_p, + **kwargs, + } + + # Generate completions + outputs = [] + for _ in range(num_generations): + response = self._runtime.generate( + prompt=prompts, sampling_params=sampling_params ) + parsed = json.loads(response) + + for completion in parsed: + outputs.append( + { + "text": completion["text"], + "tokens": completion.get("token_ids", []), + "logprobs": completion.get("logprobs", None), + } + ) - batch_outputs = self._model.generate( - prepared_inputs, - sampling_params.to_srt_kwargs(), - ) - - batched_outputs += [output["text"] for output in json.loads(batch_outputs)] - - # If logits_processor is set, we need to sort the outputs back to the original order - # (would be needed only if we have multiple structured outputs in the dataset) - if sorted_indices is not None: - batched_outputs = _sort_batches( - batched_outputs, sorted_indices, num_generations=num_generations - ) - return batched_outputs + return outputs - def _prepare_structured_output( - self, structured_output: Optional[OutlinesStructuredOutputType] = None - ) -> Union[Callable, None]: - """Creates the appropriate function to filter tokens to generate structured outputs. + def get_last_hidden_state(self, inputs: List[ChatType]) -> List[HiddenState]: + """Get hidden states for input prompts. Args: - structured_output: the configuration dict to prepare the structured output. + inputs: List of chat messages Returns: - The callable that will be used to guide the generation of the model. 
+ List of hidden states """ - from distilabel.steps.tasks.structured_outputs.outlines import ( - prepare_guided_output, - ) - - result = prepare_guided_output(structured_output, "vllm", self._model) - if (schema := result.get("schema")) and self.structured_output: - self.structured_output["schema"] = schema - return result["processor"] - - -def _sort_batches( - batches: List[List[FormattedInput]], indices: List[int], num_generations: int = 1 -) -> List[str]: - """Helper function to sort back the mini-batches generated by the model. - - It must take into account the number of `num_generations` to repeat the indices - accordingly. - - Args: - batches: The mini-batches generated by the model. - indices: The indices that would sort the mini-batches back to the original order. - num_generations: The number of generations requested to vLLM. Defaults to 1. - - Returns: - Sorted batched_outputs. - """ - batch_sizes = [len(batch) for batch in batches] - flattened_batches = np.array([b for batch in batches for b in batch]) - sorted_batches = np.take_along_axis( - flattened_batches, - np.argsort(np.repeat(indices, num_generations)), - axis=0, - ).tolist() - sorted_batches = _batchify(sorted_batches, batch_sizes) - return sorted_batches - - -def _batchify(sorted_batches: List[str], batch_sizes: List[int]) -> List[List[str]]: - """Helper function to regenerate the sorted batches from the flattened sorted ones. - - Args: - sorted_batches: Output obtained from the `_sort_batches` function. - batch_sizes: The batch sizes to be used to split the sorted batches. - - Returns: - Batched sorted batches in the original shape. - """ - batches = [] - idx = 0 - for bs in batch_sizes: - batches.append(sorted_batches[idx : idx + bs]) - idx += bs - return batches - + raise NotImplementedError("Hidden state extraction not supported for SGLang") -# You can add a ClientSGLang class here if needed, similar to ClientvLLM + def __del__(self): + """Cleanup runtime when object is deleted.""" + if hasattr(self, "runtime"): + self.runtime.shutdown() From a4d429279711f3312016feacf86cc098f025a214 Mon Sep 17 00:00:00 2001 From: bikash Date: Sat, 26 Oct 2024 15:35:24 +0000 Subject: [PATCH 09/10] Breaking changes --- src/distilabel/models/llms/sglang.py | 515 ++++++++++++++++++++++----- 1 file changed, 428 insertions(+), 87 deletions(-) diff --git a/src/distilabel/models/llms/sglang.py b/src/distilabel/models/llms/sglang.py index 3d8a21dc5..a872b51d6 100644 --- a/src/distilabel/models/llms/sglang.py +++ b/src/distilabel/models/llms/sglang.py @@ -13,128 +13,469 @@ # limitations under the License. 
import json -from typing import TYPE_CHECKING, Any, List +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + Union, +) -from pydantic import PrivateAttr, validate_call +import numpy as np +from pydantic import Field, PrivateAttr, validate_call +from distilabel.mixins.runtime_parameters import RuntimeParameter from distilabel.models.llms import LLM -from distilabel.models.llms.typing import GenerateOutput, HiddenState -from distilabel.steps.tasks.typing import ChatType +from distilabel.models.llms.typing import GenerateOutput +from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.models.mixins.magpie import MagpieChatTemplateMixin +from distilabel.steps.tasks.typing import ( + FormattedInput, + OutlinesStructuredOutputType, +) if TYPE_CHECKING: from sglang.srt.server import Runtime + from transformers import PreTrainedTokenizer + from distilabel.steps.tasks.typing import StandardInput -class SGLangLLM(LLM): - _runtime: "Runtime" = PrivateAttr(None) +LogitsProcessorFn = Union[ + Callable[[List[int], Any], Any], + Callable[[List[int], List[int], Any], Any], +] - def __init__( - self, - model: str, - log_level: str = "error", - tensor_parallel_size: int = 1, - **kwargs, - ): - """Initialize SGLang Runtime LLM. +LogitsProcessors = List[LogitsProcessorFn] - Args: - model: Model path or name - log_level: Logging level (default: "error") - tensor_parallel_size: Number of GPUs for tensor parallelism (default: 1) - **kwargs: Additional arguments passed to SGLang Runtime + +class SGLangLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): + """SGLang library LLM implementation. + + Attributes: + model: the model Hugging Face Hub repo id or a path to a directory containing the + model weights and configuration files. + dtype: the data type to use for the model. Defaults to `auto`. + trust_remote_code: whether to trust the remote code when loading the model. Defaults + to `False`. + quantization: the quantization mode to use for the model. Defaults to `None`. + revision: the revision of the model to load. Defaults to `None`. + tokenizer: the tokenizer Hugging Face Hub repo id or a path to a directory containing + the tokenizer files. If not provided, the tokenizer will be loaded from the + model directory. Defaults to `None`. + tokenizer_mode: the mode to use for the tokenizer. Defaults to `auto`. + tokenizer_revision: the revision of the tokenizer to load. Defaults to `None`. + skip_tokenizer_init: whether to skip the initialization of the tokenizer. Defaults + to `False`. + chat_template: a chat template that will be used to build the prompts before + sending them to the model. If not provided, the chat template defined in the + tokenizer config will be used. If not provided and the tokenizer doesn't have + a chat template, then ChatML template will be used. Defaults to `None`. + structured_output: a dictionary containing the structured output configuration or if more + fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None. + seed: the seed to use for the random number generator. Defaults to `0`. + extra_kwargs: additional dictionary of keyword arguments that will be passed to the + `LLM` class of `vllm` library. Defaults to `{}`. + _model: the `vLLM` model instance. This attribute is meant to be used internally + and should not be accessed directly. It will be set in the `load` method. 
+ _tokenizer: the tokenizer instance used to format the prompt before passing it to + the `LLM`. This attribute is meant to be used internally and should not be + accessed directly. It will be set in the `load` method. + use_magpie_template: a flag used to enable/disable applying the Magpie pre-query + template. Defaults to `False`. + magpie_pre_query_template: the pre-query template to be applied to the prompt or + sent to the LLM to generate an instruction or a follow up user message. Valid + values are "llama3", "qwen2" or another pre-query template provided. Defaults + to `None`. + + References: + https://github.com/sgl-project/sglang + + Runtime parameters: + - `extra_kwargs`: additional dictionary of keyword arguments that will be passed to + the `Runtime` class of `sglang` library. + + Examples: + Generate text: + + ```python + from distilabel.models.llms import SGLangLLM + + llm = llm = SGLangLLM( + model="mistralai/Mistral-7B-Instruct-v0.2", + tensor_parallel_size=1, # Using single GPU + log_level="info" # Set to "info" to see SGLang's logs + ) + llm.load() + + # Call the model + test_inputs: List[List[dict]] = [ + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"} + ] + ] + outputs = llm.generate( + inputs=test_inputs, + num_generations=1, + temperature=0.7, + max_tokens=100 + ) + ``` + + Generate structured data: TODO + + ```python + from distilabel.models.llms import SGLangLLM + + llm = llm = SGLangLLM( + model="mistralai/Mistral-7B-Instruct-v0.2", + tensor_parallel_size=1, # Using single GPU + log_level="info" # Set to "info" to see SGLang's logs + ) + llm.load() + + # Call the model + test_inputs: List[List[dict]] = [ + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the capital of France?"} + ] + ] + outputs = llm.generate( + inputs=test_inputs, + num_generations=1, + temperature=0.7, + max_tokens=100 + ) + ``` + """ + + model: str + dtype: str = "auto" + trust_remote_code: bool = False + quantization: Optional[str] = None + revision: Optional[str] = None + + tokenizer: Optional[str] = None + tokenizer_mode: Literal["auto", "slow"] = "auto" + tokenizer_revision: Optional[str] = None + skip_tokenizer_init: bool = False + chat_template: Optional[str] = None + + seed: int = 0 + + extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field( + default_factory=dict, + description="Additional dictionary of keyword arguments that will be passed to the" + " `vLLM` class of `vllm` library. See all the supported arguments at: " + "https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py", + ) + structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field( + default=None, + description="The structured output format to use across all the generations.", + ) + + _runtime: "Runtime" = PrivateAttr(None) + _tokenizer: "PreTrainedTokenizer" = PrivateAttr(None) + _structured_output_logits_processor: Optional[Callable] = PrivateAttr(default=None) + + def load(self) -> None: + """Loads the `sglang` model using either the path or the Hugging Face Hub repository id. + Additionally, this method also sets the `chat_template` for the tokenizer, so as to properly + parse the list of OpenAI formatted inputs using the expected format by the model, otherwise, the + default value is ChatML format, unless explicitly provided. 
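As context for the chat-template handling described here, a standalone sketch of how an OpenAI-style conversation becomes a single prompt string with `transformers` (the model id is a placeholder; the class's own `prepare_input`, added further down in this patch, does the equivalent internally):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

conversation = [
    {"role": "user", "content": "What is the capital of France?"},
]

# tokenize=False returns the formatted string; add_generation_prompt appends the
# assistant turn marker so the model continues as the assistant.
prompt = tokenizer.apply_chat_template(
    conversation, tokenize=False, add_generation_prompt=True
)
print(prompt)
```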
""" - super().__init__() - from sglang.srt.server import Runtime + super().load() + CudaDevicePlacementMixin.load(self) + + try: + from sglang.srt.server import Runtime + except ImportError as ie: + raise ImportError( + 'sglang is not installed. Please install it using `pip install "sglang[all]"`.' + ) from ie self._runtime = Runtime( - model_path=model, - log_level=log_level, - tp_size=tensor_parallel_size, - **kwargs, + model_path=self.model, + ldtype=self.dtype, + trust_remote_code=self.trust_remote_code, + quantization=self.quantization, + revision=self.revision, + tokenizer=self.tokenizer, + tokenizer_mode=self.tokenizer_mode, + tokenizer_revision=self.tokenizer_revision, + skip_tokenizer_init=self.skip_tokenizer_init, + seed=self.seed, + **self.extra_kwargs, ) - self._model = model + self._tokenizer = self._model.get_tokenizer() # type: ignore + if self.chat_template is not None: + self._tokenizer.chat_template = self.chat_template # type: ignore + + if self.structured_output: + self._structured_output_logits_processor = self._prepare_structured_output( + self.structured_output + ) + + def unload(self) -> None: + """Unloads the `sglang` model.""" + self._runtime.shutdown() + self._runtime = None # type: ignore + self._tokenizer = None # type: ignore + CudaDevicePlacementMixin.unload(self) + super().unload() @property def model_name(self) -> str: + """Returns the model name used for the LLM.""" return self._model + def prepare_input(self, input: "StandardInput") -> str: + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. + """ + if self._tokenizer.chat_template is None: + return input[0]["content"] + + prompt: str = ( + self._tokenizer.apply_chat_template( + input, # type: ignore + tokenize=False, + add_generation_prompt=True, # type: ignore + ) + if input + else "" + ) + return super().apply_magpie_pre_query_template(prompt, input) + + def _prepare_batches( + self, inputs: List[FormattedInput] + ) -> Tuple[List[List[FormattedInput]], List[int]]: + """Prepares the inputs by grouping them by the structured output. + + When we generate structured outputs with schemas obtained from a dataset, we need to + prepare the data to try to send batches of inputs instead of single inputs to the model + to take advante of the engine. So we group the inputs by the structured output to be + passed in the `generate` method. + + Args: + inputs: The batch of inputs passed to the generate method. As we expect to be generating + structured outputs, each element will be a tuple containing the instruction and the + structured output. + + Returns: + The prepared batches (sub-batches let's say) to be passed to the `generate` method. 
+            Each tuple will contain a list of instructions instead of a single instruction.
+        """
+        instruction_order = {}
+        batches = {}
+        for i, (instruction, structured_output) in enumerate(inputs):
+            instruction = self.prepare_input(instruction)
+            instruction_order[instruction] = i
+            structured_output = json.dumps(structured_output)
+            if structured_output not in batches:
+                batches[structured_output] = [instruction]
+            else:
+                batches[structured_output].append(instruction)
+
+        # Flatten the instructions in prepared_data
+        flat_instructions = [
+            instruction for _, group in batches.items() for instruction in group
+        ]
+        # Generate the list of indices based on the original order
+        sorted_indices = [
+            instruction_order[instruction] for instruction in flat_instructions
+        ]
+        return [
+            (batch, json.loads(schema)) for schema, batch in batches.items()
+        ], sorted_indices
+
     @validate_call
-    def generate(
+    def generate(  # type: ignore
         self,
-        inputs: List[ChatType],
+        inputs: List[FormattedInput],
         num_generations: int = 1,
-        temperature: float = 0.7,
-        max_tokens: int = 512,
+        max_new_tokens: int = 128,
+        presence_penalty: float = 0.0,
+        frequency_penalty: float = 0.0,
+        repetition_penalty: float = 1.0,
+        temperature: float = 1.0,
         top_p: float = 1.0,
-        **kwargs: Any,
+        top_k: int = -1,
+        min_p: float = 0.0,
+        stop: Optional[List[str]] = None,
+        stop_token_ids: Optional[List[int]] = None,
+        include_stop_str_in_output: bool = False,
+        logits_processors: Optional[LogitsProcessors] = None,
+        extra_sampling_params: Optional[Dict[str, Any]] = None,
     ) -> List[GenerateOutput]:
-        """Generate completions for the input prompts.
+        """Generates `num_generations` responses for each input.
 
         Args:
-            inputs: List of chat messages
-            num_generations: Number of generations per prompt
-            temperature: Sampling temperature
-            max_tokens: Maximum number of tokens to generate
-            top_p: Nucleus sampling threshold
-            **kwargs: Additional sampling parameters
+            inputs: a list of inputs in chat format to generate responses for.
+            num_generations: the number of generations to create per input. Defaults to
+                `1`.
+            max_new_tokens: the maximum number of new tokens that the model will generate.
+                Defaults to `128`.
+            presence_penalty: the presence penalty to use for the generation. Defaults to
+                `0.0`.
+            frequency_penalty: the frequency penalty to use for the generation. Defaults
+                to `0.0`.
+            repetition_penalty: the repetition penalty to use for the generation. Defaults to
+                `1.0`.
+            temperature: the temperature to use for the generation. Defaults to `1.0`.
+            top_p: the top-p value to use for the generation. Defaults to `1.0`.
+            top_k: the top-k value to use for the generation. Defaults to `-1`.
+            min_p: the minimum probability to use for the generation. Defaults to `0.0`.
+            stop: a list of strings that will be used to stop the generation when found.
+                Defaults to `None`.
+            stop_token_ids: a list of token ids that will be used to stop the generation
+                when found. Defaults to `None`.
+            include_stop_str_in_output: whether to include the stop string in the output.
+                Defaults to `False`.
+            logits_processors: a list of functions to process the logits before sampling.
+                Defaults to `None`.
+            extra_sampling_params: dictionary with additional arguments to be passed to
+                the `SamplingParams` class of the `sglang` engine.
 
         Returns:
-            List of generation outputs
+            A list of lists of strings containing the generated responses for each input.
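For reference, a minimal sketch of calling `generate` with the arguments documented above once the class is loaded (the model id and sampling values are placeholders):

```python
from distilabel.models.llms import SGLangLLM

llm = SGLangLLM(model="mistralai/Mistral-7B-Instruct-v0.2")
llm.load()

outputs = llm.generate(
    inputs=[
        [{"role": "user", "content": "Name three rivers in Europe."}],
        [{"role": "user", "content": "Name three rivers in Asia."}],
    ],
    num_generations=2,   # two completions per input
    max_new_tokens=64,
    temperature=0.7,
    top_p=0.9,
    stop=["\n\n"],
)

llm.unload()
```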
""" - # Convert chat messages to prompt string - prompts = [] - for messages in inputs: - prompt = "" - for msg in messages: - role = msg["role"] - content = msg["content"] - if role == "system": - prompt += f"System: {content}\n" - elif role == "user": - prompt += f"User: {content}\n" - elif role == "assistant": - prompt += f"Assistant: {content}\n" - prompts.append(prompt.strip()) - - # Set up sampling parameters - sampling_params = { - "temperature": temperature, - "max_new_tokens": max_tokens, - "top_p": top_p, - **kwargs, - } - - # Generate completions - outputs = [] - for _ in range(num_generations): - response = self._runtime.generate( - prompt=prompts, sampling_params=sampling_params - ) - parsed = json.loads(response) - - for completion in parsed: - outputs.append( - { - "text": completion["text"], - "tokens": completion.get("token_ids", []), - "logprobs": completion.get("logprobs", None), - } + from sglang.srt.server import SamplingParams + + if not logits_processors: + logits_processors = [] + + if extra_sampling_params is None: + extra_sampling_params = {} + + structured_output = None + + if isinstance(inputs[0], tuple): + prepared_batches, sorted_indices = self._prepare_batches(inputs) + else: + # Simulate a batch without the structured output content + prepared_batches = [([self.prepare_input(input) for input in inputs], None)] + sorted_indices = None + + # Case in which we have a single structured output for the dataset + if self._structured_output_logits_processor: + logits_processors.append(self._structured_output_logits_processor) + + batched_outputs = [] + + for prepared_inputs, structured_output in prepared_batches: + if structured_output: + logits_processors.append( + self._prepare_structured_output(structured_output) ) - return outputs + sampling_params = SamplingParams( # type: ignore + n=num_generations, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + repetition_penalty=repetition_penalty, + temperature=temperature, + top_p=top_p, + top_k=top_k, + min_p=min_p, + max_tokens=max_new_tokens, + stop=stop, + stop_token_ids=stop_token_ids, + include_stop_str_in_output=include_stop_str_in_output, + logits_processors=logits_processors, + **extra_sampling_params, + ) - def get_last_hidden_state(self, inputs: List[ChatType]) -> List[HiddenState]: - """Get hidden states for input prompts. + batch_outputs = self._model.generate( + prepared_inputs, + sampling_params, + use_tqdm=False, # type: ignore + ) + + batched_outputs += [ + [output.text for output in outputs.outputs] for outputs in batch_outputs + ] + + # If logits_processor is set, we need to sort the outputs back to the original order + # (would be needed only if we have multiple structured outputs in the dataset) + if sorted_indices is not None: + batched_outputs = _sort_batches( + batched_outputs, sorted_indices, num_generations=num_generations + ) + return batched_outputs + + def _prepare_structured_output( + self, structured_output: Optional[OutlinesStructuredOutputType] = None + ) -> Union[Callable, None]: + """Creates the appropriate function to filter tokens to generate structured outputs. Args: - inputs: List of chat messages + structured_output: the configuration dict to prepare the structured output. Returns: - List of hidden states + The callable that will be used to guide the generation of the model. 
""" - raise NotImplementedError("Hidden state extraction not supported for SGLang") + from distilabel.steps.tasks.structured_outputs.outlines import ( + prepare_guided_output, + ) + + result = prepare_guided_output(structured_output, "vllm", self._model) + if (schema := result.get("schema")) and self.structured_output: + self.structured_output["schema"] = schema + return result["processor"] + + +def _sort_batches( + batches: List[List[FormattedInput]], indices: List[int], num_generations: int = 1 +) -> List[str]: + """Helper function to sort back the mini-batches generated by the model. + + It must take into account the number of `num_generations` to repeat the indices + accordingly. + + Args: + batches: The mini-batches generated by the model. + indices: The indices that would sort the mini-batches back to the original order. + num_generations: The number of generations requested to vLLM. Defaults to 1. + + Returns: + Sorted batched_outputs. + """ + batch_sizes = [len(batch) for batch in batches] + flattened_batches = np.array([b for batch in batches for b in batch]) + sorted_batches = np.take_along_axis( + flattened_batches, + np.argsort(np.repeat(indices, num_generations)), + axis=0, + ).tolist() + sorted_batches = _batchify(sorted_batches, batch_sizes) + return sorted_batches + + +def _batchify(sorted_batches: List[str], batch_sizes: List[int]) -> List[List[str]]: + """Helper function to regenerate the sorted batches from the flattened sorted ones. + + Args: + sorted_batches: Output obtained from the `_sort_batches` function. + batch_sizes: The batch sizes to be used to split the sorted batches. - def __del__(self): - """Cleanup runtime when object is deleted.""" - if hasattr(self, "runtime"): - self.runtime.shutdown() + Returns: + Batched sorted batches in the original shape. + """ + batches = [] + idx = 0 + for bs in batch_sizes: + batches.append(sorted_batches[idx : idx + bs]) + idx += bs + return batches From fa3fb35b6022a3bc61a60015ee9ac781b992d4f0 Mon Sep 17 00:00:00 2001 From: bikash Date: Sat, 26 Oct 2024 16:44:14 +0000 Subject: [PATCH 10/10] cleaned and working SGLang impl --- src/distilabel/models/llms/sglang.py | 71 ++++++++-------------------- 1 file changed, 20 insertions(+), 51 deletions(-) diff --git a/src/distilabel/models/llms/sglang.py b/src/distilabel/models/llms/sglang.py index a872b51d6..8b8b7f2d6 100644 --- a/src/distilabel/models/llms/sglang.py +++ b/src/distilabel/models/llms/sglang.py @@ -67,7 +67,6 @@ class SGLangLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): the tokenizer files. If not provided, the tokenizer will be loaded from the model directory. Defaults to `None`. tokenizer_mode: the mode to use for the tokenizer. Defaults to `auto`. - tokenizer_revision: the revision of the tokenizer to load. Defaults to `None`. skip_tokenizer_init: whether to skip the initialization of the tokenizer. Defaults to `False`. 
chat_template: a chat template that will be used to build the prompts before @@ -106,8 +105,6 @@ class SGLangLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): llm = llm = SGLangLLM( model="mistralai/Mistral-7B-Instruct-v0.2", - tensor_parallel_size=1, # Using single GPU - log_level="info" # Set to "info" to see SGLang's logs ) llm.load() @@ -133,8 +130,6 @@ class SGLangLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): llm = llm = SGLangLLM( model="mistralai/Mistral-7B-Instruct-v0.2", - tensor_parallel_size=1, # Using single GPU - log_level="info" # Set to "info" to see SGLang's logs ) llm.load() @@ -159,10 +154,8 @@ class SGLangLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): trust_remote_code: bool = False quantization: Optional[str] = None revision: Optional[str] = None - tokenizer: Optional[str] = None tokenizer_mode: Literal["auto", "slow"] = "auto" - tokenizer_revision: Optional[str] = None skip_tokenizer_init: bool = False chat_template: Optional[str] = None @@ -179,7 +172,7 @@ class SGLangLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): description="The structured output format to use across all the generations.", ) - _runtime: "Runtime" = PrivateAttr(None) + _model: "Runtime" = PrivateAttr(None) _tokenizer: "PreTrainedTokenizer" = PrivateAttr(None) _structured_output_logits_processor: Optional[Callable] = PrivateAttr(default=None) @@ -199,17 +192,14 @@ def load(self) -> None: 'sglang is not installed. Please install it using `pip install "sglang[all]"`.' ) from ie - self._runtime = Runtime( + self._model = Runtime( model_path=self.model, - ldtype=self.dtype, + dtype=self.dtype, trust_remote_code=self.trust_remote_code, quantization=self.quantization, - revision=self.revision, - tokenizer=self.tokenizer, + tokenizer_path=self.tokenizer, tokenizer_mode=self.tokenizer_mode, - tokenizer_revision=self.tokenizer_revision, skip_tokenizer_init=self.skip_tokenizer_init, - seed=self.seed, **self.extra_kwargs, ) self._tokenizer = self._model.get_tokenizer() # type: ignore @@ -223,8 +213,8 @@ def load(self) -> None: def unload(self) -> None: """Unloads the `sglang` model.""" - self._runtime.shutdown() - self._runtime = None # type: ignore + self._model.shutdown() + self._model = None # type: ignore self._tokenizer = None # type: ignore CudaDevicePlacementMixin.unload(self) super().unload() @@ -315,7 +305,6 @@ def generate( # type: ignore min_p: float = 0.0, stop: Optional[List[str]] = None, stop_token_ids: Optional[List[int]] = None, - include_stop_str_in_output: bool = False, logits_processors: Optional[LogitsProcessors] = None, extra_sampling_params: Optional[Dict[str, Any]] = None, ) -> List[GenerateOutput]: @@ -341,8 +330,6 @@ def generate( # type: ignore Defaults to `None`. stop_token_ids: a list of token ids that will be used to stop the generation when found. Defaults to `None`. - include_stop_str_in_output: whether to include the stop string in the output. - Defaults to `False`. logits_processors: a list of functions to process the logits before sampling. Defaults to `None`. extra_sampling_params: dictionary with additional arguments to be passed to @@ -351,7 +338,6 @@ def generate( # type: ignore Returns: A list of lists of strings containing the generated responses for each input. 
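The examples above no longer pass `tensor_parallel_size` or `log_level` directly, presumably because they are not fields on this class; engine-level options would instead go through the `extra_kwargs` runtime parameter. A sketch, assuming `tp_size` and `log_level` are accepted by sglang's `Runtime` (check against the installed version):

```python
from distilabel.models.llms import SGLangLLM

llm = SGLangLLM(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    # Forwarded verbatim to `Runtime`; the argument names here are assumptions.
    extra_kwargs={"tp_size": 1, "log_level": "info"},
)
llm.load()
```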
""" - from sglang.srt.server import SamplingParams if not logits_processors: logits_processors = [] @@ -366,54 +352,37 @@ def generate( # type: ignore else: # Simulate a batch without the structured output content prepared_batches = [([self.prepare_input(input) for input in inputs], None)] - sorted_indices = None # Case in which we have a single structured output for the dataset if self._structured_output_logits_processor: logits_processors.append(self._structured_output_logits_processor) - batched_outputs = [] - for prepared_inputs, structured_output in prepared_batches: if structured_output: logits_processors.append( self._prepare_structured_output(structured_output) ) - sampling_params = SamplingParams( # type: ignore - n=num_generations, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, - min_p=min_p, - max_tokens=max_new_tokens, - stop=stop, - stop_token_ids=stop_token_ids, - include_stop_str_in_output=include_stop_str_in_output, - logits_processors=logits_processors, + sampling_params = { + "n": num_generations, + "presence_penalty": presence_penalty, + "frequency_penalty": frequency_penalty, + "repetition_penalty": repetition_penalty, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": min_p, + "max_new_tokens": max_new_tokens, + "stop": stop, + "stop_token_ids": stop_token_ids, **extra_sampling_params, - ) + } batch_outputs = self._model.generate( prepared_inputs, sampling_params, - use_tqdm=False, # type: ignore - ) - - batched_outputs += [ - [output.text for output in outputs.outputs] for outputs in batch_outputs - ] - - # If logits_processor is set, we need to sort the outputs back to the original order - # (would be needed only if we have multiple structured outputs in the dataset) - if sorted_indices is not None: - batched_outputs = _sort_batches( - batched_outputs, sorted_indices, num_generations=num_generations ) - return batched_outputs + return batch_outputs def _prepare_structured_output( self, structured_output: Optional[OutlinesStructuredOutputType] = None