Add Magpie and MagpieGenerator tasks #778

Merged · 32 commits · Jul 15, 2024

Commits
e1b4c8b
Move `CudaDevicePlacementMixin` to new module
gabrielmbmb Jul 9, 2024
d557542
Initial work for implementing Magpie
gabrielmbmb Jul 9, 2024
755f7ec
Simplify magpie implementation
gabrielmbmb Jul 9, 2024
1719d15
Remove `use_open_ai` and add `MagpieChatTemplateMixin` to
gabrielmbmb Jul 10, 2024
775ca4e
Add `MagpieChatTemplateMixin` to `vLLM`
gabrielmbmb Jul 10, 2024
9ff6eeb
Add `MagpieGenerator` task
gabrielmbmb Jul 10, 2024
7b9bde5
Move `CudaDevicePlacementMixins` to new subpackage
gabrielmbmb Jul 10, 2024
dadac54
Fix unit tests
gabrielmbmb Jul 10, 2024
844ec57
Fix docstrings
gabrielmbmb Jul 10, 2024
04ecc3a
Mock `HF_TOKEN` environment variable
gabrielmbmb Jul 10, 2024
a15752a
Fix list index out of range
gabrielmbmb Jul 11, 2024
b46cd33
Fix `MagpieGenerator` last batch
gabrielmbmb Jul 11, 2024
75bd827
Add `only_instruction` attribute
gabrielmbmb Jul 11, 2024
a86f640
Update categories
gabrielmbmb Jul 11, 2024
463f622
testing
gabrielmbmb Jul 11, 2024
953b933
Worth trying
gabrielmbmb Jul 11, 2024
53ff036
Add examples
gabrielmbmb Jul 11, 2024
ba85907
Add magpie unit tests
gabrielmbmb Jul 11, 2024
b2e8805
Fix docstring
gabrielmbmb Jul 11, 2024
e52ae3f
Update docstrings
gabrielmbmb Jul 11, 2024
5736a25
Apply suggestions from code review
gabrielmbmb Jul 11, 2024
32b1725
Update to `huggingface_hub >= 0.22.0`
gabrielmbmb Jul 11, 2024
e899133
Add generation with `chat_completion`
gabrielmbmb Jul 12, 2024
91d05e2
Merge branch 'magpie' of https://github.com/argilla-io/rlxf into magpie
gabrielmbmb Jul 12, 2024
87371ec
Update `agenerate` arguments
gabrielmbmb Jul 15, 2024
bf350e3
Update unit tests
gabrielmbmb Jul 15, 2024
360433f
Fix `tools` were not being used
gabrielmbmb Jul 15, 2024
ef68210
Update unit tests
gabrielmbmb Jul 15, 2024
9ee7096
Fix list of tuples instead of list of list
gabrielmbmb Jul 15, 2024
3863eb3
Add missing docstring
gabrielmbmb Jul 15, 2024
87c11cc
Add `chat_completion` unit tests
gabrielmbmb Jul 15, 2024
cd3cc5d
Fix `GroqLLM.generate` unit test after updating `_agenerate`
gabrielmbmb Jul 15, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -74,7 +74,7 @@ anthropic = ["anthropic >= 0.20.0"]
argilla = ["argilla >= 1.29.0"]
cohere = ["cohere >= 5.2.0"]
groq = ["groq >= 0.4.1"]
hf-inference-endpoints = ["huggingface_hub >= 0.19.0"]
hf-inference-endpoints = ["huggingface_hub >= 0.22.0"]
hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"]
instructor = ["instructor >= 1.2.3"]
litellm = ["litellm >= 1.30.0"]
2 changes: 1 addition & 1 deletion src/distilabel/llms/__init__.py
@@ -22,7 +22,7 @@
from distilabel.llms.litellm import LiteLLM
from distilabel.llms.llamacpp import LlamaCppLLM
from distilabel.llms.mistral import MistralLLM
-from distilabel.llms.mixins import CudaDevicePlacementMixin
+from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.llms.moa import MixtureOfAgentsLLM
from distilabel.llms.ollama import OllamaLLM
from distilabel.llms.openai import OpenAILLM
2 changes: 1 addition & 1 deletion src/distilabel/llms/azure.py
@@ -45,7 +45,7 @@ class AzureOpenAILLM(OpenAILLM):
`None` if not set.

Icon:
-`:simple-microsoftazure:`
+`:material-microsoft-azure:`

Examples:

5 changes: 4 additions & 1 deletion src/distilabel/llms/base.py
@@ -329,7 +329,10 @@ async def _agenerate(
for _ in range(num_generations)
]
outputs = [outputs[0] for outputs in await asyncio.gather(*tasks)]
-return list(grouper(outputs, n=num_generations, incomplete="ignore"))
+return [
+list(group)
+for group in grouper(outputs, n=num_generations, incomplete="ignore")
+]

def generate(
self,
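The hunk above converts the tuples yielded by `grouper` into lists, so `_agenerate` returns a list of lists per input instead of a list of tuples. A minimal, self-contained sketch of the difference, using a local stand-in for the grouper helper (assumption: it behaves like the classic itertools "grouper" recipe and yields tuples):

```python
from itertools import zip_longest
from typing import Iterable, Iterator, Tuple

_FILL = object()

def grouper(iterable: Iterable[str], n: int, incomplete: str = "ignore") -> Iterator[Tuple[str, ...]]:
    """Yield chunks of size `n`; with incomplete="ignore", drop a short tail."""
    chunks = zip_longest(*([iter(iterable)] * n), fillvalue=_FILL)
    for chunk in chunks:
        if incomplete == "ignore" and _FILL in chunk:
            continue
        yield chunk

outputs = ["gen-0", "gen-1", "gen-2", "gen-3"]

# Old behaviour: a list of tuples.
assert list(grouper(outputs, n=2)) == [("gen-0", "gen-1"), ("gen-2", "gen-3")]

# New behaviour: each group converted to a list, matching what the new return statement builds.
assert [list(g) for g in grouper(outputs, n=2)] == [["gen-0", "gen-1"], ["gen-2", "gen-3"]]
```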
463 changes: 311 additions & 152 deletions src/distilabel/llms/huggingface/inference_endpoints.py

Large diffs are not rendered by default.

35 changes: 27 additions & 8 deletions src/distilabel/llms/huggingface/transformers.py
@@ -19,7 +19,8 @@

from distilabel.llms.base import LLM
from distilabel.llms.chat_templates import CHATML_TEMPLATE
-from distilabel.llms.mixins import CudaDevicePlacementMixin
+from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
+from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput
@@ -32,7 +33,7 @@
from distilabel.llms.typing import HiddenState


-class TransformersLLM(LLM, CudaDevicePlacementMixin):
+class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
"""Hugging Face `transformers` library LLM implementation using the text generation
pipeline.

@@ -64,6 +65,12 @@ class TransformersLLM(LLM, CudaDevicePlacementMixin):
local configuration will be used. Defaults to `None`.
structured_output: a dictionary containing the structured output configuration or if more
fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None.
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
magpie_pre_query_template: the pre-query template to be applied to the prompt or
sent to the LLM to generate an instruction or a follow-up user message. Valid
values are "llama3", "qwen2", or a custom pre-query template provided as a
string. Defaults to `None`.

Icon:
`:hugging:`
@@ -157,14 +164,25 @@ def model_name(self) -> str:
return self.model

def prepare_input(self, input: "StandardInput") -> str:
"""Prepares the input by applying the chat template to the input, which is formatted
as an OpenAI conversation, and adding the generation prompt.
"""Prepares the input (applying the chat template and tokenization) for the provided
input.

Args:
input: the input list containing chat items.

Returns:
The prompt to send to the LLM.
"""
-return self._pipeline.tokenizer.apply_chat_template( # type: ignore
-input, # type: ignore
-tokenize=False,
-add_generation_prompt=True,
+prompt: str = (
+self._pipeline.tokenizer.apply_chat_template( # type: ignore
+input, # type: ignore
+tokenize=False,
+add_generation_prompt=True,
+)
+if input
+else ""
+)
+return super().apply_magpie_pre_query_template(prompt, input)

@validate_call
def generate( # type: ignore
@@ -209,6 +227,7 @@ def generate( # type: ignore
do_sample=do_sample,
num_return_sequences=num_generations,
prefix_allowed_tokens_fn=self._prefix_allowed_tokens_fn,
pad_token_id=self._pipeline.tokenizer.eos_token_id, # type: ignore
)
return [
[generation["generated_text"] for generation in output]
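With `MagpieChatTemplateMixin` in its bases, `TransformersLLM` can be asked to produce user instructions instead of answers. A hedged usage sketch (the model name is illustrative, and only basic generation arguments are shown; loading the weights is required):

```python
from distilabel.llms import TransformersLLM

llm = TransformersLLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model
    use_magpie_template=True,
    magpie_pre_query_template="llama3",  # alias resolved to the Llama 3 user header
)
llm.load()

# With an empty conversation, `prepare_input` returns only the pre-query
# template, so the model continues it by writing a user instruction.
instructions = llm.generate(inputs=[[]], max_new_tokens=128)
```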
14 changes: 14 additions & 0 deletions src/distilabel/llms/mixins/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

89 changes: 89 additions & 0 deletions src/distilabel/llms/mixins/magpie.py
@@ -0,0 +1,89 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Dict, Literal, Union

from pydantic import BaseModel, field_validator, model_validator
from typing_extensions import Self

if TYPE_CHECKING:
from distilabel.steps.tasks.typing import StandardInput

MagpieAvailablePreQueryTemplates = Literal["llama3", "qwen2"]
"""The available predefined pre-query templates."""

MAGPIE_PRE_QUERY_TEMPLATES: Dict[MagpieAvailablePreQueryTemplates, str] = {
"llama3": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n",
"qwen2": "<|im_start|>user\n",
}


class MagpieChatTemplateMixin(BaseModel, validate_assignment=True):
"""A simple mixin that adds the required logic to apply the pre-query template that
allows to an instruct fine-tuned LLM to generate user instructions as described in
the paper 'Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing'.

This mixin is meant to be used in combination with the [Magpie][distilabel.steps.tasks.magpie.base.Magpie]
task.

Attributes:
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
magpie_pre_query_template: the pre-query template to be applied to the prompt or
sent to the LLM to generate an instruction or a follow-up user message. Valid
values are "llama3", "qwen2", or a custom pre-query template provided as a
string. Defaults to `None`.

References:
- [Magpie: Alignment Data Synthesis from Scratch by Prompting Aligned LLMs with Nothing](https://arxiv.org/abs/2406.08464)
"""

use_magpie_template: bool = False
magpie_pre_query_template: Union[MagpieAvailablePreQueryTemplates, str, None] = None

@field_validator("magpie_pre_query_template")
@classmethod
def magpie_pre_query_template_validator(cls, value: str) -> str:
"""Resolves the pre-query template alias if it exists, otherwise, returns the
value with no modification."""
if value in MAGPIE_PRE_QUERY_TEMPLATES:
return MAGPIE_PRE_QUERY_TEMPLATES[value]
return value

@model_validator(mode="after")
def use_magpie_template_validation(self) -> Self:
"""Checks that there is a pre-query template set if Magpie is going to be used."""
if self.use_magpie_template and self.magpie_pre_query_template is None:
raise ValueError(
f"Cannot set `use_magpie_template=True` if `magpie_pre_query_template` is"
f" `None`. To use Magpie with `{self.__class__.__name__}` you need to set"
f" the `magpie_pre_query_template` attribute."
)
return self

def apply_magpie_pre_query_template(
self, prompt: str, input: "StandardInput"
) -> str:
"""Applies the pre-query template to the prompt if Magpie is going to be used.

Args:
prompt: the prompt to which the pre-query template has to be applied.
input: the list with the chat items that were used to generate the prompt.

Returns:
The prompt with the pre-query template applied if needed.
"""
if not self.use_magpie_template or (input and input[-1]["role"] == "user"):
return prompt
return prompt + self.magpie_pre_query_template # type: ignore
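A minimal sketch of the mixin's behaviour in isolation; in practice it is mixed into an LLM class (as in `TransformersLLM` and `vLLM` in this PR) rather than instantiated directly:

```python
from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin

mixin = MagpieChatTemplateMixin(
    use_magpie_template=True,
    magpie_pre_query_template="llama3",  # alias resolved by the field validator
)

# Empty conversation: only the pre-query template is returned, which nudges
# the model into writing a *user* instruction.
assert mixin.apply_magpie_pre_query_template("", []) == (
    "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
)

# Conversation already ending in a user turn: the prompt is left untouched,
# so the model answers the instruction as usual.
prompt = "<rendered chat template>"
chat = [{"role": "user", "content": "What is the capital of France?"}]
assert mixin.apply_magpie_pre_query_template(prompt, chat) == prompt
```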
43 changes: 31 additions & 12 deletions src/distilabel/llms/vllm.py
@@ -30,7 +30,8 @@

from distilabel.llms.base import LLM
from distilabel.llms.chat_templates import CHATML_TEMPLATE
-from distilabel.llms.mixins import CudaDevicePlacementMixin
+from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
+from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType
@@ -39,11 +40,13 @@
from transformers import PreTrainedTokenizer
from vllm import LLM as _vLLM

from distilabel.steps.tasks.typing import StandardInput


SamplingParams = None


-class vLLM(LLM, CudaDevicePlacementMixin):
+class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
"""`vLLM` library LLM implementation.

Attributes:
@@ -75,6 +78,12 @@ class vLLM(LLM, CudaDevicePlacementMixin):
_tokenizer: the tokenizer instance used to format the prompt before passing it to
the `LLM`. This attribute is meant to be used internally and should not be
accessed directly. It will be set in the `load` method.
use_magpie_template: a flag used to enable/disable applying the Magpie pre-query
template. Defaults to `False`.
magpie_pre_query_template: the pre-query template to be applied to the prompt or
sent to the LLM to generate an instruction or a follow-up user message. Valid
values are "llama3", "qwen2", or a custom pre-query template provided as a
string. Defaults to `None`.

References:
- https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/llm.py
@@ -213,15 +222,26 @@ def model_name(self) -> str:
"""Returns the model name used for the LLM."""
return self.model

def prepare_input(self, input: "FormattedInput") -> str:
"""Prepares the input by applying the chat template to the input, which is formatted
as an OpenAI conversation, and adding the generation prompt.
def prepare_input(self, input: "StandardInput") -> str:
"""Prepares the input (applying the chat template and tokenization) for the provided
input.

Args:
input: the input list containing chat items.

Returns:
The prompt to send to the LLM.
"""
-return self._tokenizer.apply_chat_template( # type: ignore
-input, # type: ignore
-tokenize=False,
-add_generation_prompt=True, # type: ignore
+prompt: str = (
+self._tokenizer.apply_chat_template( # type: ignore
+input, # type: ignore
+tokenize=False,
+add_generation_prompt=True, # type: ignore
+)
+if input
+else ""
+)
+return super().apply_magpie_pre_query_template(prompt, input)

def _prepare_batches(
self, inputs: List[FormattedInput]
@@ -304,14 +324,13 @@ def generate( # type: ignore
if extra_sampling_params is None:
extra_sampling_params = {}
structured_output = None
-needs_sorting = False

if isinstance(inputs[0], tuple):
prepared_batches, sorted_indices = self._prepare_batches(inputs)
-needs_sorting = True
else:
# Simulate a batch without the structured output content
prepared_batches = [([self.prepare_input(input) for input in inputs], None)]
+sorted_indices = None

# In case we have a single structured output for the dataset, we can
logits_processors = None
@@ -348,7 +367,7 @@

# If logits_processor is set, we need to sort the outputs back to the original order
# (would be needed only if we have multiple structured outputs in the dataset)
-if needs_sorting:
+if sorted_indices is not None:
batched_outputs = _sort_batches(
batched_outputs, sorted_indices, num_generations=num_generations
)
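Besides wiring in the Magpie mixin, the hunks above drop the separate `needs_sorting` flag: `sorted_indices` now doubles as the flag, since it is only non-`None` when `_prepare_batches` re-ordered the inputs. A small sketch of the pattern, using a hypothetical `restore_order` helper rather than distilabel's `_sort_batches`:

```python
from typing import List, Optional

def restore_order(outputs: List[str], sorted_indices: Optional[List[int]]) -> List[str]:
    # `sorted_indices` is only produced when the inputs were re-ordered, so its
    # presence alone tells us whether the original order must be restored.
    if sorted_indices is not None:
        return [outputs[i] for i in sorted_indices]
    return outputs

assert restore_order(["b", "a"], [1, 0]) == ["a", "b"]
assert restore_order(["a", "b"], None) == ["a", "b"]
```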
2 changes: 1 addition & 1 deletion src/distilabel/pipeline/step_wrapper.py
@@ -16,7 +16,7 @@
from queue import Queue
from typing import Any, Dict, List, Optional, Union, cast

-from distilabel.llms.mixins import CudaDevicePlacementMixin
+from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.pipeline.batch import _Batch
from distilabel.pipeline.constants import LAST_BATCH_SENT_FLAG
from distilabel.pipeline.typing import StepLoadStatus
4 changes: 4 additions & 0 deletions src/distilabel/steps/tasks/__init__.py
@@ -35,6 +35,8 @@
from distilabel.steps.tasks.instruction_backtranslation import (
InstructionBacktranslation,
)
from distilabel.steps.tasks.magpie.base import Magpie
from distilabel.steps.tasks.magpie.generator import MagpieGenerator
from distilabel.steps.tasks.pair_rm import PairRM
from distilabel.steps.tasks.prometheus_eval import PrometheusEval
from distilabel.steps.tasks.quality_scorer import QualityScorer
@@ -64,6 +66,8 @@
"GenerateTextRetrievalData",
"MonolingualTripletGenerator",
"InstructionBacktranslation",
"Magpie",
"MagpieGenerator",
"PairRM",
"PrometheusEval",
"QualityScorer",
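With `Magpie` and `MagpieGenerator` exported, the typical wiring pairs one of them with an LLM that has the pre-query template enabled. A hedged sketch (the pipeline layout and the `num_rows` argument are illustrative assumptions, not taken from this diff; `only_instruction` is the attribute added in this PR):

```python
from distilabel.llms import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import MagpieGenerator

with Pipeline(name="magpie-demo") as pipeline:
    MagpieGenerator(
        llm=vLLM(
            model="meta-llama/Meta-Llama-3-8B-Instruct",  # illustrative model
            use_magpie_template=True,
            magpie_pre_query_template="llama3",
        ),
        only_instruction=True,  # stop after generating the instruction
        num_rows=10,            # assumed runtime parameter for the number of rows
    )

if __name__ == "__main__":
    distiset = pipeline.run()
```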
1 change: 1 addition & 0 deletions src/distilabel/steps/tasks/evol_instruct/generator.py
@@ -280,6 +280,7 @@ def process(self, offset: int = 0) -> "GeneratorStepOutput": # type: ignore
instructions = []
mutation_no = 0

# TODO: update to take into account `offset`
iter_no = 0
while len(instructions) < self.num_instructions:
prompts = self._apply_random_mutation(iter_no=iter_no)
14 changes: 14 additions & 0 deletions src/distilabel/steps/tasks/magpie/__init__.py
@@ -0,0 +1,14 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
