Fix vllm sorting mechanism and add mocked generate method to the test suite
plaguss committed Oct 18, 2024
1 parent 6d19de7 commit 9746d75
Showing 2 changed files with 144 additions and 74 deletions.
78 changes: 69 additions & 9 deletions src/distilabel/llms/vllm.py
@@ -235,7 +235,7 @@ def prepare_input(self, input: "StandardInput") -> str:
The prompt to send to the LLM.
"""
if self._tokenizer.chat_template is None:
return input[0]["content"]
return [item["content"] for item in input if item["role"] == "user"][0]

prompt: str = (
self._tokenizer.apply_chat_template(
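With no chat template set, the updated `prepare_input` now returns the content of the first user message rather than the first message overall, which may be a system prompt. A minimal sketch of the same selection logic, using a made-up conversation rather than data from the repository:

```python
# Fallback when no chat template is set: pick the first *user* turn.
# The conversation below is a made-up example, not taken from the test suite.
conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What's the weather like today?"},
]

old_prompt = conversation[0]["content"]  # old behaviour: may return the system prompt
new_prompt = [m["content"] for m in conversation if m["role"] == "user"][0]

print(old_prompt)  # "You are a helpful assistant."
print(new_prompt)  # "What's the weather like today?"
```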
@@ -271,7 +271,14 @@ def _prepare_batches(
batches = {}
for i, (instruction, structured_output) in enumerate(inputs):
instruction = self.prepare_input(instruction)
instruction_order[instruction] = i

# We need to convert the instruction to a string to make it hashable
str_instruction = instruction
if not isinstance(instruction, str):
str_instruction = json.dumps(instruction)

instruction_order[str_instruction] = i

structured_output = json.dumps(structured_output)
if structured_output not in batches:
batches[structured_output] = [instruction]
@@ -284,7 +291,7 @@
]
# Generate the list of indices based on the original order
sorted_indices = [
instruction_order[instruction] for instruction in flat_instructions
instruction_order[str_instruction] for instruction in flat_instructions
]
return [
(batch, json.loads(schema)) for schema, batch in batches.items()
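The core of the sorting fix: chat-formatted instructions are lists of dicts, which are unhashable and cannot be used as dictionary keys, so serialising them with `json.dumps` yields a stable, hashable key for recording each instruction's original position. A small self-contained illustration of the idea (not the library code itself):

```python
import json

# Chat-formatted instructions are lists of dicts and therefore unhashable.
instructions = [
    [{"role": "user", "content": "Generate a character from a RPG game."}],
    [{"role": "user", "content": "Generate an animal from a zoo."}],
]

instruction_order = {}
for i, instruction in enumerate(instructions):
    # instruction_order[instruction] = i  # would raise TypeError: unhashable type: 'list'
    instruction_order[json.dumps(instruction)] = i  # the serialised form is a stable key

# The serialised key can later be used to recover the original position.
assert instruction_order[json.dumps(instructions[1])] == 1
```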
@@ -357,7 +364,6 @@ def generate( # type: ignore
# Simulate a batch without the structured output content
prepared_batches = [([self.prepare_input(input) for input in inputs], None)]
sorted_indices = None

# Case in which we have a single structured output for the dataset
if self._structured_output_logits_processor:
logits_processors.append(self._structured_output_logits_processor)
@@ -388,12 +394,15 @@
**extra_sampling_params,
)

batch_outputs: "RequestOutputs" = self._model.generate(
batch_outputs: List["RequestOutputs"] = self._model.generate(
prepared_inputs,
sampling_params,
use_tqdm=False, # type: ignore
)

# TODO: This is repeated in prepare_output, but for simplicity we extract
# the batched_outputs the same way we did before statistics were tracked, and just
# return the str generations
batched_outputs += [
[output.text for output in outputs.outputs] for outputs in batch_outputs
]
)
)

# TODO: This must be updated with the statistics
# If logits_processor is set, we need to sort the outputs back to the original order
# (would be needed only if we have multiple structured outputs in the dataset)
if sorted_indices is not None:
batched_outputs = _sort_batches(
batched_outputs, sorted_indices, num_generations=num_generations
# Sort the batched outputs together with the statistics
generations = self._prepare_sorted_results(
batched_outputs,
sorted_indices,
generations,
num_generations=num_generations,
)
# return batched_outputs
return generations

def _prepare_structured_output(
@@ -445,6 +456,55 @@ def _get_llm_statistics(
"output_tokens": output_tokens,
}

@staticmethod
def _prepare_sorted_results(
batched_outputs: List[List[FormattedInput]],
sorted_indices: List[int],
generations: List[GenerateOutput],
num_generations: int = 1,
) -> List[GenerateOutput]:
"""Helper method to sort the results in case of multiple structured outputs in the dataset.
Args:
batched_outputs: The mini-batches generated by the model.
sorted_indices: The indices that would sort the mini-batches back to the original order.
generations: The prepared outputs that would be returned in the general case,
from which the statistics will be extracted and sorted.
num_generations: The number of generations requested to vLLM. Defaults to 1.
Returns:
The list of GenerateOutput sorted back to the original order.
"""

# This sort back was the only one needed when only the generations were returned
batched_outputs = _sort_batches(
batched_outputs, sorted_indices, num_generations=num_generations
)
# Prepare the statistics to be sorted
# Loop over all the variables in the statistics
# Get the keys from the LLMStatistics
statistic_fields = list(generations[0]["statistics"].keys())
statistics = {}
for field in statistic_fields:
batched_field = _sort_batches(
[g["statistics"][field] for g in generations],
sorted_indices,
num_generations=num_generations,
)
statistics[field] = batched_field

# Regenerate the outputs as they are returned by `prepare_output`
sorted_results = []
for i, batched_output in enumerate(batched_outputs):
generation = {"generations": batched_output}
# Use a new name so the accumulated `statistics` dict isn't overwritten inside the loop
generation_stats = {
field: batched_field[i] for field, batched_field in statistics.items()
}
generation.update({"statistics": generation_stats})
sorted_results.append(generation)

return sorted_results
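For readers following the new helper, here is a self-contained toy showing what "sorting generations and their per-field statistics back to the original order" amounts to. It is a simplified stand-in that assumes `sorted_indices[i]` holds the original position of the i-th output, matching how the indices are built in `_prepare_batches`; it does not use distilabel's `_sort_batches`.

```python
from typing import Any, Dict, List


def toy_sort_back(
    outputs: List[Dict[str, Any]], sorted_indices: List[int]
) -> List[Dict[str, Any]]:
    """Reorder outputs (with their statistics) back to the original input order.

    Simplified stand-in: assumes `sorted_indices[i]` is the original position
    of the i-th output.
    """
    restored: List[Dict[str, Any]] = [{} for _ in outputs]
    for current_pos, original_pos in enumerate(sorted_indices):
        restored[original_pos] = outputs[current_pos]
    return restored


# Inputs 0 and 2 shared a schema, so they were generated together and came back out of order.
outs = [
    {"generations": ["a"], "statistics": {"input_tokens": [3], "output_tokens": [1]}},
    {"generations": ["c"], "statistics": {"input_tokens": [5], "output_tokens": [2]}},
    {"generations": ["b"], "statistics": {"input_tokens": [4], "output_tokens": [7]}},
]
print(toy_sort_back(outs, sorted_indices=[0, 2, 1]))  # order restored to "a", "b", "c"
```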


class ClientvLLM(OpenAILLM, MagpieChatTemplateMixin):
"""A client for the `vLLM` server implementing the OpenAI API specification.
140 changes: 75 additions & 65 deletions tests/unit/llms/test_vllm.py
@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List
from typing import Any, Dict, List
from unittest import mock

import numpy as np
import pytest
from openai.pagination import SyncPage
from openai.types import Model
@@ -25,7 +24,7 @@
from pydantic import BaseModel

from distilabel.llms import vLLM
from distilabel.llms.vllm import ClientvLLM, _sort_batches
from distilabel.llms.vllm import ClientvLLM


class Character(BaseModel):
@@ -104,91 +103,102 @@ class Animal(BaseModel):

# Just a mock to avoid loading the model
class DummyTokenizer:
chat_template = None
# chat_template = None
chat_template = "template"

def __init__(self) -> None:
pass

def apply_chat_template(self, input, **kwargs):
return input

def encode(self, text: str):
return [1, 2, 3, 4, 5]


class TestvLLM:
@pytest.mark.parametrize("multi_structured_output", (False, True))
@pytest.mark.parametrize(
"num_generations, expected_sorted_batches",
"num_generations, expected_result",
[
(
1,
[
"Generate a character from a RPG game.",
"Generate an animal from a zoo.",
"Repeated character",
"What's the weather like today in Seattle in Celsius degrees?",
"Other character",
"repeated regex",
{
"generations": ["I'm fine thank you"],
"statistics": {"input_tokens": [5], "output_tokens": [6]},
}
],
),
(
3,
np.repeat(
[
"Generate a character from a RPG game.",
"Generate an animal from a zoo.",
"Repeated character",
"What's the weather like today in Seattle in Celsius degrees?",
"Other character",
"repeated regex",
],
3,
).tolist(),
2,
[
{
"generations": ["I'm fine thank you"] * 2,
"statistics": {"input_tokens": [5, 5], "output_tokens": [6, 6]},
}
],
),
],
)
def test_prepare_batches_and_sort_back(
self, num_generations: int, expected_sorted_batches: List[str]
):
formatted_inputs = [
(item["instruction"], item["structured_output"])
for row in SAMPLE_DATA
for item in row
]
def test_generate(
self,
multi_structured_output: bool,
num_generations: int,
expected_result: List[Dict[str, Any]],
) -> None:
llm = vLLM(model="dummy")
llm._tokenizer = DummyTokenizer()
batches, indices = llm._prepare_batches(formatted_inputs)
# NOTE: We have to simulate calling self._model.generate(n=num_generations) and then sorting the results
num_generations_batches = []
for batch in batches:
num_generations_batches.append(
(np.repeat(batch[0], num_generations).tolist(), batch[1])
vllm_mock = mock.MagicMock()
# mock the import by hacking sys.modules
# https://stackoverflow.com/questions/60919705/how-to-mock-in-a-python-unittest-a-library-not-installed-locally
import sys

if "vllm" not in sys.modules:
sys.modules["vllm"] = vllm_mock
llm._model = vllm_mock

mocked_requests_output = [
mock.Mock( # RequestOutput
outputs=[
mock.Mock( # CompletionOutput
text="I'm fine thank you",
token_ids=[1, 2, 3, 4, 5, 7],
)
]
* num_generations,
)
batches = num_generations_batches
# Recreate as the output from batched_outputs += [[output.text for output in outputs.outputs] for outputs in batch_outputs]
batches = [batch for batch, _ in batches]
sorted_batches = _sort_batches(
batches, indices, num_generations=num_generations
)
]

assert sorted_batches == [
np.repeat(
[
"Generate a character from a RPG game.",
"Generate an animal from a zoo.",
"Repeated character",
],
num_generations,
).tolist(),
np.repeat(
["What's the weather like today in Seattle in Celsius degrees?"],
num_generations,
).tolist(),
np.repeat(
llm._model.generate = mock.MagicMock(return_value=mocked_requests_output)
if not multi_structured_output:
formatted_inputs = [
[
"Other character",
"repeated regex",
],
num_generations,
).tolist(),
]
{"role": "system", "content": "sysprompt"},
{
"role": "user",
"content": "I'm fine thank you",
},
]
]
else:
formatted_inputs = [
(
[
{"role": "system", "content": "sysprompt"},
{
"role": "user",
"content": "I'm fine thank you",
},
],
{
"format": "json",
"schema": Character.model_json_schema(),
},
)
]
result = llm.generate(inputs=formatted_inputs, num_generations=num_generations)
assert result == expected_result
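The test above avoids importing the real `vllm` package by placing a `MagicMock` in `sys.modules` before the code under test resolves the import. A minimal, standalone sketch of that pattern, using a made-up module name purely for illustration:

```python
import sys
from unittest import mock

# Pretend a heavy optional dependency is installed by injecting a mock module.
# "some_heavy_lib" is a made-up name used only for illustration.
fake_module = mock.MagicMock()
fake_module.generate.return_value = ["mocked output"]

if "some_heavy_lib" not in sys.modules:
    sys.modules["some_heavy_lib"] = fake_module

import some_heavy_lib  # resolves to the mock; no real installation needed

assert some_heavy_lib.generate() == ["mocked output"]
```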


@mock.patch("openai.OpenAI")
@@ -256,7 +266,7 @@ async def test_agenerate(
assert generations == {
"generations": ["I'm fine thank you", "I'm fine thank you sir"],
"statistics": {
"input_tokens": 10,
"output_tokens": 10,
"input_tokens": [10],
"output_tokens": [10],
},
}
