diff --git a/.github/workflows/docs-pr-close.yml b/.github/workflows/docs-pr-close.yml
index 71f4e5ff93..61008bcee1 100644
--- a/.github/workflows/docs-pr-close.yml
+++ b/.github/workflows/docs-pr-close.yml
@@ -8,6 +8,10 @@ concurrency:
group: distilabel-docs
cancel-in-progress: false
+permissions:
+ contents: write
+ pull-requests: write
+
jobs:
cleanup:
runs-on: ubuntu-latest
diff --git a/.github/workflows/docs-pr.yml b/.github/workflows/docs-pr.yml
index 48c7236a58..ec963ccf98 100644
--- a/.github/workflows/docs-pr.yml
+++ b/.github/workflows/docs-pr.yml
@@ -10,6 +10,10 @@ concurrency:
group: distilabel-docs
cancel-in-progress: false
+permissions:
+ contents: write
+ pull-requests: write
+
jobs:
publish:
runs-on: ubuntu-latest
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml
index dd59a5129d..93a17408e8 100644
--- a/.github/workflows/docs.yml
+++ b/.github/workflows/docs.yml
@@ -12,6 +12,10 @@ concurrency:
group: distilabel-docs
cancel-in-progress: false
+permissions:
+ contents: write
+ pull-requests: write
+
jobs:
publish:
runs-on: ubuntu-latest
diff --git a/README.md b/README.md
index 728d69c0b4..7a7dfc8d3d 100644
--- a/README.md
+++ b/README.md
@@ -118,7 +118,7 @@ pip install "distilabel[hf-inference-endpoints]" --upgrade
Then run:
```python
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
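For anyone updating existing code to the renamed module, a minimal smoke test of the new import path (the model id and the `generate_outputs` call mirror the LLM how-to guide; treat the exact arguments as assumptions for your installed version):

```python
# Minimal sketch: check that the renamed import path resolves and the LLM loads.
# Assumes `distilabel[hf-inference-endpoints]` is installed and HF_TOKEN is set.
from distilabel.models import InferenceEndpointsLLM

llm = InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct")
llm.load()

# `generate_outputs` expects a list of conversations in the OpenAI chat format.
outputs = llm.generate_outputs(
    inputs=[[{"role": "user", "content": "Say hello in one short sentence."}]],
)
print(outputs)
```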
diff --git a/docs/api/embedding/embedding_gallery.md b/docs/api/embedding/embedding_gallery.md
deleted file mode 100644
index 3eed3ab50e..0000000000
--- a/docs/api/embedding/embedding_gallery.md
+++ /dev/null
@@ -1,8 +0,0 @@
-# Embedding Gallery
-
-This section contains the existing [`Embeddings`][distilabel.embeddings] subclasses implemented in `distilabel`.
-
-::: distilabel.embeddings
- options:
- filters:
- - "!^Embeddings$"
\ No newline at end of file
diff --git a/docs/api/llm/index.md b/docs/api/llm/index.md
deleted file mode 100644
index fe58a65384..0000000000
--- a/docs/api/llm/index.md
+++ /dev/null
@@ -1,7 +0,0 @@
-# LLM
-
-This section contains the API reference for the `distilabel` LLMs, both for the [`LLM`][distilabel.llms.LLM] synchronous implementation, and for the [`AsyncLLM`][distilabel.llms.AsyncLLM] asynchronous one.
-
-For more information and examples on how to use existing LLMs or create custom ones, please refer to [Tutorial - LLM](../../sections/how_to_guides/basic/llm/index.md).
-
-::: distilabel.llms.base
diff --git a/docs/api/llm/llm_gallery.md b/docs/api/llm/llm_gallery.md
deleted file mode 100644
index ad0b1b75f0..0000000000
--- a/docs/api/llm/llm_gallery.md
+++ /dev/null
@@ -1,10 +0,0 @@
-# LLM Gallery
-
-This section contains the existing [`LLM`][distilabel.llms] subclasses implemented in `distilabel`.
-
-::: distilabel.llms
- options:
- filters:
- - "!^LLM$"
- - "!^AsyncLLM$"
- - "!typing"
\ No newline at end of file
diff --git a/docs/api/models/embedding/embedding_gallery.md b/docs/api/models/embedding/embedding_gallery.md
new file mode 100644
index 0000000000..3324caa304
--- /dev/null
+++ b/docs/api/models/embedding/embedding_gallery.md
@@ -0,0 +1,8 @@
+# Embedding Gallery
+
+This section contains the existing [`Embeddings`][distilabel.models.embeddings] subclasses implemented in `distilabel`.
+
+::: distilabel.models.embeddings
+ options:
+ filters:
+ - "!^Embeddings$"
\ No newline at end of file
diff --git a/docs/api/embedding/index.md b/docs/api/models/embedding/index.md
similarity index 83%
rename from docs/api/embedding/index.md
rename to docs/api/models/embedding/index.md
index 675593e183..fc1cfb0dc3 100644
--- a/docs/api/embedding/index.md
+++ b/docs/api/models/embedding/index.md
@@ -4,4 +4,4 @@ This section contains the API reference for the `distilabel` embeddings.
For more information on how the [`Embeddings`][distilabel.steps.tasks.Task] class works, as well as some examples, check the [Embedding Gallery](embedding_gallery.md).
-::: distilabel.embeddings.base
\ No newline at end of file
+::: distilabel.models.embeddings.base
\ No newline at end of file
diff --git a/docs/api/models/llm/index.md b/docs/api/models/llm/index.md
new file mode 100644
index 0000000000..903329c22d
--- /dev/null
+++ b/docs/api/models/llm/index.md
@@ -0,0 +1,7 @@
+# LLM
+
+This section contains the API reference for the `distilabel` LLMs, both for the [`LLM`][distilabel.models.llms.LLM] synchronous implementation, and for the [`AsyncLLM`][distilabel.models.llms.AsyncLLM] asynchronous one.
+
+For more information and examples on how to use existing LLMs or create custom ones, please refer to [Tutorial - LLM](../../../sections/how_to_guides/basic/llm/index.md).
+
+::: distilabel.models.llms.base
diff --git a/docs/api/models/llm/llm_gallery.md b/docs/api/models/llm/llm_gallery.md
new file mode 100644
index 0000000000..e571d3fe29
--- /dev/null
+++ b/docs/api/models/llm/llm_gallery.md
@@ -0,0 +1,10 @@
+# LLM Gallery
+
+This section contains the existing [`LLM`][distilabel.models.llms] subclasses implemented in `distilabel`.
+
+::: distilabel.models.llms
+ options:
+ filters:
+ - "!^LLM$"
+ - "!^AsyncLLM$"
+ - "!typing"
\ No newline at end of file
diff --git a/docs/sections/getting_started/faq.md b/docs/sections/getting_started/faq.md
index 7a78126c46..6e6462a620 100644
--- a/docs/sections/getting_started/faq.md
+++ b/docs/sections/getting_started/faq.md
@@ -44,13 +44,13 @@ hide:
You can serve the LLM using a solution like TGI or vLLM, and then connect to it using an `AsyncLLM` client like `InferenceEndpointsLLM` or `OpenAILLM`. Please refer to [Serving LLMs guide](../how_to_guides/advanced/serving_an_llm_for_reuse.md) for more information.
??? faq "Can `distilabel` be used with [OpenAI Batch API](https://platform.openai.com/docs/guides/batch)?"
- Yes, `distilabel` is integrated with OpenAI Batch API via [OpenAILLM][distilabel.llms.openai.OpenAILLM]. Check [LLMs - Offline Batch Generation](../how_to_guides/basic/llm/index.md#offline-batch-generation) for a small example on how to use it and [Advanced - Offline Batch Generation](../how_to_guides/advanced/offline_batch_generation.md) for a more detailed guide.
+ Yes, `distilabel` is integrated with OpenAI Batch API via [OpenAILLM][distilabel.models.llms.openai.OpenAILLM]. Check [LLMs - Offline Batch Generation](../how_to_guides/basic/llm/index.md#offline-batch-generation) for a small example on how to use it and [Advanced - Offline Batch Generation](../how_to_guides/advanced/offline_batch_generation.md) for a more detailed guide.
-??? faq "Prevent overloads on [Free Serverless Endpoints][distilabel.llms.huggingface.InferenceEndpointsLLM]"
- When running a task using the [InferenceEndpointsLLM][distilabel.llms.huggingface.InferenceEndpointsLLM] with Free Serverless Endpoints, you may be facing some errors such as `Model is overloaded` if you let the batch size to the default (set at 50). To fix the issue, lower the value or even better set `input_batch_size=1` in your task. It may take a longer time to finish, but please remember this is a free service.
+??? faq "Prevent overloads on [Free Serverless Endpoints][distilabel.models.llms.huggingface.InferenceEndpointsLLM]"
+    When running a task using the [InferenceEndpointsLLM][distilabel.models.llms.huggingface.InferenceEndpointsLLM] with Free Serverless Endpoints, you may run into errors such as `Model is overloaded` if you leave the batch size at the default (50). To fix the issue, lower the value or, even better, set `input_batch_size=1` in your task. It may take longer to finish, but remember this is a free service.
```python
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
    from distilabel.steps.tasks import TextGeneration
TextGeneration(
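As a rough sketch of the fix being described, lowering `input_batch_size` on the task looks something like the following (the model id and the remaining arguments are assumptions):

```python
from distilabel.models import InferenceEndpointsLLM
from distilabel.steps.tasks import TextGeneration

# Sketch: a Free Serverless Endpoint copes better when it receives one row at a time.
text_generation = TextGeneration(
    llm=InferenceEndpointsLLM(
        model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",  # assumed model id
    ),
    input_batch_size=1,  # avoids "Model is overloaded" errors on the free tier
)
```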
diff --git a/docs/sections/getting_started/installation.md b/docs/sections/getting_started/installation.md
index 54e130b7fa..c11392e3f1 100644
--- a/docs/sections/getting_started/installation.md
+++ b/docs/sections/getting_started/installation.md
@@ -75,7 +75,7 @@ Additionally, as part of `distilabel` some extra dependencies are available, mai
## Recommendations / Notes
-The [`mistralai`](https://github.com/mistralai/client-python) dependency requires Python 3.9 or higher, so if you're willing to use the `distilabel.llms.MistralLLM` implementation, you will need to have Python 3.9 or higher.
+The [`mistralai`](https://github.com/mistralai/client-python) dependency requires Python 3.9 or higher, so if you want to use the `distilabel.models.llms.MistralLLM` implementation, you will need to have Python 3.9 or higher.
In some cases like [`transformers`](https://github.com/huggingface/transformers) and [`vllm`](https://github.com/vllm-project/vllm), the installation of [`flash-attn`](https://github.com/Dao-AILab/flash-attention) is recommended if you are using a GPU accelerator since it will speed up the inference process, but the installation needs to be done separately, as it's not included in the `distilabel` dependencies.
diff --git a/docs/sections/getting_started/quickstart.md b/docs/sections/getting_started/quickstart.md
index 7af9bca8f0..5a6a919ec1 100644
--- a/docs/sections/getting_started/quickstart.md
+++ b/docs/sections/getting_started/quickstart.md
@@ -30,12 +30,12 @@ pip install distilabel[hf-inference-endpoints] --upgrade
## Define a pipeline
-In this guide we will walk you through the process of creating a simple pipeline that uses the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class to generate text. The [`Pipeline`][distilabel.pipeline.Pipeline] will load a dataset that contains a column named `prompt` from the Hugging Face Hub via the step [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] and then use the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class to generate text based on the dataset using the [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task.
+In this guide we will walk you through the process of creating a simple pipeline that uses the [`InferenceEndpointsLLM`][distilabel.models.llms.InferenceEndpointsLLM] class to generate text. The [`Pipeline`][distilabel.pipeline.Pipeline] will load a dataset that contains a column named `prompt` from the Hugging Face Hub via the step [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] and then use the [`InferenceEndpointsLLM`][distilabel.models.llms.InferenceEndpointsLLM] class to generate text based on the dataset using the [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task.
> You can check the available models in the [Hugging Face Model Hub](https://huggingface.co/models?pipeline_tag=text-generation&sort=trending) and filter by `Inference status`.
```python
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
@@ -85,9 +85,9 @@ if __name__ == "__main__":
3. We define a [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub] step named `load_dataset` that will load a dataset from the Hugging Face Hub, as provided via runtime parameters in the `pipeline.run` method below, but it can also be defined within the class instance via the arg `repo_id=...`. This step will produce output batches with the rows from the dataset, and the column `prompt` will be mapped to the `instruction` field.
-4. We define a [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task named `text_generation` that will generate text based on the `instruction` field from the dataset. This task will use the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct`.
+4. We define a [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task named `text_generation` that will generate text based on the `instruction` field from the dataset. This task will use the [`InferenceEndpointsLLM`][distilabel.models.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct`.
-5. We define the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct` that will be used by the [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task. In this case, since the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] is used, we assume that the `HF_TOKEN` environment variable is set.
+5. We define the [`InferenceEndpointsLLM`][distilabel.models.llms.InferenceEndpointsLLM] class with the model `Meta-Llama-3.1-8B-Instruct` that will be used by the [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) task. In this case, since the [`InferenceEndpointsLLM`][distilabel.models.llms.InferenceEndpointsLLM] is used, we assume that the `HF_TOKEN` environment variable is set.
6. Both `system_prompt` and `template` are optional fields. The `template` must be informed as a string following the [Jinja2](https://jinja.palletsprojects.com/en/3.1.x/templates/#synopsis) template format, and the fields that appear there ("instruction" in this case, which corresponds to the default) must be informed in the `columns` attribute. The component gallery for [`TextGeneration`](https://distilabel.argilla.io/dev/components-gallery/tasks/textgeneration/) has examples to get you started.
diff --git a/docs/sections/how_to_guides/advanced/argilla.md b/docs/sections/how_to_guides/advanced/argilla.md
index 2d8c047960..5e7c9e6d50 100644
--- a/docs/sections/how_to_guides/advanced/argilla.md
+++ b/docs/sections/how_to_guides/advanced/argilla.md
@@ -23,7 +23,7 @@ The dataset will be pushed with the following configuration:
The [`TextGenerationToArgilla`][distilabel.steps.TextGenerationToArgilla] step will only work as is if the [`Pipeline`][distilabel.pipeline.Pipeline] contains one or multiple [`TextGeneration`][distilabel.steps.tasks.TextGeneration] steps, or if the columns `instruction` and `generation` are available within the batch data. Otherwise, the variable `input_mappings` will need to be set so that either both or one of `instruction` and `generation` are mapped to one of the existing columns in the batch data.
```python
-from distilabel.llms import OpenAILLM
+from distilabel.models import OpenAILLM
from distilabel.steps import LoadDataFromDicts, TextGenerationToArgilla
from distilabel.steps.tasks import TextGeneration
@@ -74,7 +74,7 @@ The dataset will be pushed with the following configuration:
Additionally, if the [`Pipeline`][distilabel.pipeline.Pipeline] contains an [`UltraFeedback`][distilabel.steps.tasks.UltraFeedback] step, the `ratings` and `rationales` will also be available and be automatically injected as suggestions to the existing dataset.
```python
-from distilabel.llms import OpenAILLM
+from distilabel.models import OpenAILLM
from distilabel.steps import LoadDataFromDicts, PreferenceToArgilla
from distilabel.steps.tasks import TextGeneration
diff --git a/docs/sections/how_to_guides/advanced/assigning_resources_to_step.md b/docs/sections/how_to_guides/advanced/assigning_resources_to_step.md
index 9a2e02dc82..60e7bcae7d 100644
--- a/docs/sections/how_to_guides/advanced/assigning_resources_to_step.md
+++ b/docs/sections/how_to_guides/advanced/assigning_resources_to_step.md
@@ -4,7 +4,7 @@ When dealing with complex pipelines that get executed in a distributed environme
```python
from distilabel.pipeline import Pipeline
-from distilabel.llms import vLLM
+from distilabel.models import vLLM
from distilabel.steps import StepResources
from distilabel.steps.tasks import PrometheusEval
diff --git a/docs/sections/how_to_guides/advanced/offline_batch_generation.md b/docs/sections/how_to_guides/advanced/offline_batch_generation.md
index b45ad1d716..ddccd288ea 100644
--- a/docs/sections/how_to_guides/advanced/offline_batch_generation.md
+++ b/docs/sections/how_to_guides/advanced/offline_batch_generation.md
@@ -14,7 +14,7 @@ The [offline batch generation](../basic/llm/index.md#offline-batch-generation) i
## Example pipeline using `OpenAILLM` with offline batch generation
```python
-from distilabel.llms import OpenAILLM
+from distilabel.models import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
diff --git a/docs/sections/how_to_guides/advanced/scaling_with_ray.md b/docs/sections/how_to_guides/advanced/scaling_with_ray.md
index be959c8b72..fa7ba9553a 100644
--- a/docs/sections/how_to_guides/advanced/scaling_with_ray.md
+++ b/docs/sections/how_to_guides/advanced/scaling_with_ray.md
@@ -41,7 +41,7 @@ pip install distilabel[ray]
For the purpose of explaining how to execute a pipeline with Ray, we'll use the following pipeline throughout the examples:
```python
-from distilabel.llms import vLLM
+from distilabel.models import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
diff --git a/docs/sections/how_to_guides/advanced/serving_an_llm_for_reuse.md b/docs/sections/how_to_guides/advanced/serving_an_llm_for_reuse.md
index c015bd7a7e..f07ba1ebd3 100644
--- a/docs/sections/how_to_guides/advanced/serving_an_llm_for_reuse.md
+++ b/docs/sections/how_to_guides/advanced/serving_an_llm_for_reuse.md
@@ -21,7 +21,7 @@ docker run --gpus all --shm-size 1g -p 8080:80 -v $volume:/data \
And then we can use `InferenceEndpointsLLM` with `base_url=http://localhost:8080` (pointing to our `TGI` local deployment):
```python
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration, UltraFeedback
@@ -66,7 +66,7 @@ docker run --gpus all \
And then we can use `OpenAILLM` with `base_url=http://localhost:8000` (pointing to our `vLLM` local deployment):
```python
-from distilabel.llms import OpenAILLM
+from distilabel.models import OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration, UltraFeedback
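A hedged sketch of the step described above, pointing `OpenAILLM` at the local vLLM deployment (the served model name and the placeholder API key are assumptions):

```python
from distilabel.models import OpenAILLM

# Sketch: vLLM exposes an OpenAI-compatible API, so OpenAILLM only needs the
# base_url of the local deployment; the api_key is just a placeholder here.
llm = OpenAILLM(
    base_url="http://localhost:8000",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",  # assumed served model
    api_key="placeholder",
)
```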
diff --git a/docs/sections/how_to_guides/advanced/structured_generation.md b/docs/sections/how_to_guides/advanced/structured_generation.md
index 6f907951c1..6d6ed034eb 100644
--- a/docs/sections/how_to_guides/advanced/structured_generation.md
+++ b/docs/sections/how_to_guides/advanced/structured_generation.md
@@ -1,12 +1,12 @@
# Structured data generation
-`Distilabel` has integrations with relevant libraries to generate structured text i.e. to guide the [`LLM`][distilabel.llms.LLM] towards the generation of structured outputs following a JSON schema, a regex, etc.
+`Distilabel` has integrations with relevant libraries to generate structured text i.e. to guide the [`LLM`][distilabel.models.llms.LLM] towards the generation of structured outputs following a JSON schema, a regex, etc.
## Outlines
-`Distilabel` integrates [`outlines`](https://outlines-dev.github.io/outlines/welcome/) within some [`LLM`][distilabel.llms.LLM] subclasses. At the moment, the following LLMs integrated with `outlines` are supported in `distilabel`: [`TransformersLLM`][distilabel.llms.TransformersLLM], [`vLLM`][distilabel.llms.vLLM] or [`LlamaCppLLM`][distilabel.llms.LlamaCppLLM], so that anyone can generate structured outputs in the form of *JSON* or a parseable *regex*.
+`Distilabel` integrates [`outlines`](https://outlines-dev.github.io/outlines/welcome/) within some [`LLM`][distilabel.models.llms.LLM] subclasses. At the moment, the following LLMs integrated with `outlines` are supported in `distilabel`: [`TransformersLLM`][distilabel.models.llms.TransformersLLM], [`vLLM`][distilabel.models.llms.vLLM] or [`LlamaCppLLM`][distilabel.models.llms.LlamaCppLLM], so that anyone can generate structured outputs in the form of *JSON* or a parseable *regex*.
-The [`LLM`][distilabel.llms.LLM] has an argument named `structured_output`[^1] that determines how we can generate structured outputs with it, let's see an example using [`LlamaCppLLM`][distilabel.llms.LlamaCppLLM].
+The [`LLM`][distilabel.models.llms.LLM] has an argument named `structured_output`[^1] that determines how we can generate structured outputs with it, let's see an example using [`LlamaCppLLM`][distilabel.models.llms.LlamaCppLLM].
!!! Note
@@ -36,7 +36,7 @@ class User(BaseModel):
And then we provide that schema to the `structured_output` argument of the LLM.
```python
-from distilabel.llms import LlamaCppLLM
+from distilabel.models import LlamaCppLLM
llm = LlamaCppLLM(
model_path="./openhermes-2.5-mistral-7b.Q4_K_M.gguf" # (1)
@@ -129,7 +129,7 @@ These were some simple examples, but one can see the options this opens.
## Instructor
-For other LLM providers behind APIs, there's no direct way of accessing the internal logit processor like `outlines` does, but thanks to [`instructor`](https://python.useinstructor.com/) we can generate structured output from LLM providers based on `pydantic.BaseModel` objects. We have integrated `instructor` to deal with the [`AsyncLLM`][distilabel.llms.AsyncLLM].
+For other LLM providers behind APIs, there's no direct way of accessing the internal logit processor like `outlines` does, but thanks to [`instructor`](https://python.useinstructor.com/) we can generate structured output from LLM providers based on `pydantic.BaseModel` objects. We have integrated `instructor` to deal with the [`AsyncLLM`][distilabel.models.llms.AsyncLLM].
!!! Note
For `instructor` integration to work you may need to install the corresponding dependencies:
@@ -159,7 +159,7 @@ And then we provide that schema to the `structured_output` argument of the LLM:
In this example we are using *Meta Llama 3.1 8B Instruct*; keep in mind that not all models support structured outputs.
```python
-from distilabel.llms import MistralLLM
+from distilabel.models import InferenceEndpointsLLM
llm = InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-8B-Instruct",
@@ -204,7 +204,7 @@ Contrary to what we have via `outlines`, JSON mode will not guarantee the output
Other than the reference to generating JSON, to ensure the model generates parseable JSON we can pass the argument `response_format="json"`[^3]:
```python
-from distilabel.llms import OpenAILLM
+from distilabel.models import OpenAILLM
llm = OpenAILLM(model="gpt-4-turbo", api_key="api.key")
llm.generate(..., response_format="json")
```
diff --git a/docs/sections/how_to_guides/basic/llm/index.md b/docs/sections/how_to_guides/basic/llm/index.md
index f9dec754ae..d5d5a37368 100644
--- a/docs/sections/how_to_guides/basic/llm/index.md
+++ b/docs/sections/how_to_guides/basic/llm/index.md
@@ -5,7 +5,7 @@
LLM subclasses are designed to be used within a [Task][distilabel.steps.tasks.Task], but they can also be used standalone.
```python
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
llm = InferenceEndpointsLLM(model="meta-llama/Meta-Llama-3.1-70B-Instruct")
llm.load()
@@ -23,12 +23,12 @@ llm.generate_outputs(
### Offline Batch Generation
-By default, all `LLM`s will generate text in a synchronous manner i.e. send inputs using `generate_outputs` method that will get blocked until outputs are generated. There are some `LLM`s (such as [OpenAILLM][distilabel.llms.openai.OpenAILLM]) that implements what we denote as _offline batch generation_, which allows to send the inputs to the LLM-as-a-service which will generate the outputs asynchronously and give us a job id that we can use later to check the status and retrieve the generated outputs when they are ready. LLM-as-a-service platforms offers this feature as a way to save costs in exchange of waiting for the outputs to be generated.
+By default, all `LLM`s will generate text in a synchronous manner, i.e. sending inputs via the `generate_outputs` method and blocking until the outputs are generated. There are some `LLM`s (such as [OpenAILLM][distilabel.models.llms.openai.OpenAILLM]) that implement what we denote as _offline batch generation_, which allows sending the inputs to the LLM-as-a-service, which will generate the outputs asynchronously and give us a job id that we can use later to check the status and retrieve the generated outputs when they are ready. LLM-as-a-service platforms offer this feature as a way to save costs in exchange for waiting for the outputs to be generated.
To use this feature in `distilabel` the only thing we need to do is to set the `use_offline_batch_generation` attribute to `True` when creating the `LLM` instance:
```python
-from distilabel.llms import OpenAILLM
+from distilabel.models import OpenAILLM
llm = OpenAILLM(
model="gpt-4o",
@@ -67,7 +67,7 @@ llm.generate_outputs( # (4)
The `offline_batch_generation_block_until_done` attribute can be used to block the `generate_outputs` method until the outputs are ready, polling the platform every specified number of seconds.
```python
-from distilabel.llms import OpenAILLM
+from distilabel.models import OpenAILLM
llm = OpenAILLM(
model="gpt-4o",
@@ -89,7 +89,7 @@ llm.generate_outputs(
Pass the LLM as an argument to the [`Task`][distilabel.steps.tasks.Task], and the task will handle the rest.
```python
-from distilabel.llms import OpenAILLM
+from distilabel.models import OpenAILLM
from distilabel.steps.tasks import TextGeneration
llm = OpenAILLM(model="gpt-4")
@@ -110,7 +110,7 @@ LLMs can have runtime parameters, such as `generation_kwargs`, provided via the
```python
from distilabel.pipeline import Pipeline
-from distilabel.llms import OpenAILLM
+from distilabel.models import OpenAILLM
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration
@@ -137,7 +137,7 @@ if __name__ == "__main__":
## Creating custom LLMs
-To create custom LLMs, subclass either [`LLM`][distilabel.llms.LLM] for synchronous or [`AsyncLLM`][distilabel.llms.AsyncLLM] for asynchronous LLMs. Implement the following methods:
+To create custom LLMs, subclass either [`LLM`][distilabel.models.llms.LLM] for synchronous or [`AsyncLLM`][distilabel.models.llms.AsyncLLM] for asynchronous LLMs. Implement the following methods:
* `model_name`: A property containing the model's name.
@@ -155,9 +155,9 @@ To create custom LLMs, subclass either [`LLM`][distilabel.llms.LLM] for synchron
from pydantic import validate_call
- from distilabel.llms import LLM
- from distilabel.llms.typing import GenerateOutput, HiddenState
- from distilabel.steps.tasks.typing import ChatType
+ from distilabel.models import LLM
+ from distilabel.typing import GenerateOutput, HiddenState
+ from distilabel.typing import ChatType
class CustomLLM(LLM):
@property
@@ -180,9 +180,9 @@ To create custom LLMs, subclass either [`LLM`][distilabel.llms.LLM] for synchron
from pydantic import validate_call
- from distilabel.llms import AsyncLLM
- from distilabel.llms.typing import GenerateOutput, HiddenState
- from distilabel.steps.tasks.typing import ChatType
+ from distilabel.models import AsyncLLM
+ from distilabel.typing import GenerateOutput, HiddenState
+ from distilabel.typing import ChatType
class CustomAsyncLLM(AsyncLLM):
@property
diff --git a/docs/sections/how_to_guides/basic/pipeline/index.md b/docs/sections/how_to_guides/basic/pipeline/index.md
index f592082191..27be4dae9d 100644
--- a/docs/sections/how_to_guides/basic/pipeline/index.md
+++ b/docs/sections/how_to_guides/basic/pipeline/index.md
@@ -85,7 +85,7 @@ Next, we will use `prompt` column from the dataset obtained through `LoadDataFro
The order of the execution of the steps will be determined by the connections of the steps. In this case, the `TextGeneration` tasks will be executed after the `LoadDataFromHub` step.
```python
-from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
+from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
@@ -110,7 +110,7 @@ For each row of the dataset, the `TextGeneration` task will generate a text base
In this case, the `GroupColumns` tasks will be executed after all `TextGeneration` steps.
```python
-from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
+from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import GroupColumns, LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
@@ -143,7 +143,7 @@ Besides the `Step.connect` method: `step1.connect(step2)`, there's an alternativ
Each call to `step1.connect(step2)` has been exchanged by `step1 >> step2` within the loop.
```python
- from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
+ from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import GroupColumns, LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
@@ -171,7 +171,7 @@ Besides the `Step.connect` method: `step1.connect(step2)`, there's an alternativ
Each task is first appended to a list, and then all the calls to connections are done in a single call.
```python
- from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
+ from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import GroupColumns, LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
@@ -206,7 +206,7 @@ Let's update the example above to route the batches loaded by the `LoadDataFromH
```python
import random
-from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
+from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline, routing_batch_function
from distilabel.steps import GroupColumns, LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
@@ -338,7 +338,7 @@ Note that in most cases if you don't need the extra flexibility the [`GeneratorS
```python hl_lines="11-14 33 38"
import random
-from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
+from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline, routing_batch_function
from distilabel.steps import GroupColumns
from distilabel.steps.tasks import TextGeneration
@@ -403,7 +403,7 @@ if __name__ == "__main__":
Memory issues can arise when processing large datasets or when using large models. To avoid this, we can use the `input_batch_size` argument of individual tasks. `TextGeneration` task will receive 5 dictionaries, while the `LoadDataFromHub` step will send 10 dictionaries per batch:
```python
-from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
+from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import GroupColumns, LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
@@ -489,7 +489,7 @@ To sum up, here is the full code of the pipeline we have created in this section
??? Code
```python
- from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
+ from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import GroupColumns, LoadDataFromHub
from distilabel.steps.tasks import TextGeneration
diff --git a/docs/sections/how_to_guides/basic/step/index.md b/docs/sections/how_to_guides/basic/step/index.md
index 18388b8f4a..d03a6b2149 100644
--- a/docs/sections/how_to_guides/basic/step/index.md
+++ b/docs/sections/how_to_guides/basic/step/index.md
@@ -71,7 +71,7 @@ There are two special types of [`Step`][distilabel.steps.Step] in `distilabel`:
* [`GlobalStep`][distilabel.steps.GlobalStep]: is a step with the standard interface i.e. receives inputs and generates outputs, but it processes all the data at once, and often is the final step in the [`Pipeline`][distilabel.pipeline.Pipeline]. Note that a [`GlobalStep`][distilabel.steps.GlobalStep] requires the previous steps to finish before it can start. More information: [Components - Step - GlobalStep](global_step.md).
-* [`Task`][distilabel.steps.tasks.Task], is essentially the same as a default [`Step`][distilabel.steps.Step], but it relies on an [`LLM`][distilabel.llms.LLM] as an attribute, and the `process` method will be in charge of calling that LLM. More information: [Components - Task](../task/index.md).
+* [`Task`][distilabel.steps.tasks.Task]: is essentially the same as a default [`Step`][distilabel.steps.Step], but it relies on an [`LLM`][distilabel.models.llms.LLM] as an attribute, and the `process` method will be in charge of calling that LLM. More information: [Components - Task](../task/index.md).
## Defining custom Steps
diff --git a/docs/sections/how_to_guides/basic/task/generator_task.md b/docs/sections/how_to_guides/basic/task/generator_task.md
index 613d8deb17..6fbb3d742e 100644
--- a/docs/sections/how_to_guides/basic/task/generator_task.md
+++ b/docs/sections/how_to_guides/basic/task/generator_task.md
@@ -68,11 +68,11 @@ next(task.process())
We can define a custom generator task by creating a new subclass of the [`GeneratorTask`][distilabel.steps.tasks.Task] and defining the following:
-- `process`: is a method that generates the data based on the [`LLM`][distilabel.llms.LLM] and the `instruction` provided within the class instance, and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. Note that the `inputs` argument is not allowed in this function since this is a [`GeneratorTask`][distilabel.steps.tasks.GeneratorTask]. The signature only expects the `offset` argument, which is used to keep track of the current iteration in the generator.
+- `process`: is a method that generates the data based on the [`LLM`][distilabel.models.llms.LLM] and the `instruction` provided within the class instance, and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. Note that the `inputs` argument is not allowed in this function since this is a [`GeneratorTask`][distilabel.steps.tasks.GeneratorTask]. The signature only expects the `offset` argument, which is used to keep track of the current iteration in the generator.
- `outputs`: is a property that returns a list of strings with the names of the output fields, this property should always include `model_name` as one of the outputs since that's automatically injected from the LLM.
-- `format_output`: is a method that receives the output from the [`LLM`][distilabel.llms.LLM] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. Note that there's no need to include the `model_name` in the output.
+- `format_output`: is a method that receives the output from the [`LLM`][distilabel.models.llms.LLM] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. Note that there's no need to include the `model_name` in the output.
```python
from typing import Any, Dict, List, Union
diff --git a/docs/sections/how_to_guides/basic/task/index.md b/docs/sections/how_to_guides/basic/task/index.md
index 7aa2049f4b..7f1d8260e0 100644
--- a/docs/sections/how_to_guides/basic/task/index.md
+++ b/docs/sections/how_to_guides/basic/task/index.md
@@ -2,12 +2,12 @@
## Working with Tasks
-The [`Task`][distilabel.steps.tasks.Task] is a special kind of [`Step`][distilabel.steps.Step] that includes the [`LLM`][distilabel.llms.LLM] as a mandatory argument. As with a [`Step`][distilabel.steps.Step], it is normally used within a [`Pipeline`][distilabel.pipeline.Pipeline] but can also be used standalone.
+The [`Task`][distilabel.steps.tasks.Task] is a special kind of [`Step`][distilabel.steps.Step] that includes the [`LLM`][distilabel.models.llms.LLM] as a mandatory argument. As with a [`Step`][distilabel.steps.Step], it is normally used within a [`Pipeline`][distilabel.pipeline.Pipeline] but can also be used standalone.
For example, the most basic task is the [`TextGeneration`][distilabel.steps.tasks.TextGeneration] task, which generates text based on a given instruction.
```python
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
from distilabel.steps.tasks import TextGeneration
task = TextGeneration(
@@ -66,7 +66,7 @@ The `Tasks` include a handy method to show what the prompt formatted for an `LLM
```python
from distilabel.steps.tasks import UltraFeedback
-from distilabel.llms.huggingface import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
uf = UltraFeedback(
llm=InferenceEndpointsLLM(
@@ -95,8 +95,8 @@ uf.print(
In case you don't want to load an LLM to render the template, you can create a dummy one like the ones we could use for testing.
```python
- from distilabel.llms.base import LLM
- from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
+    from distilabel.models import AsyncLLM
+ from distilabel.models.mixins import MagpieChatTemplateMixin
class DummyLLM(AsyncLLM, MagpieChatTemplateMixin):
structured_output: Any = None
@@ -131,7 +131,7 @@ uf.print(
All the `Task`s have a `num_generations` attribute that allows defining the number of generations that we want to have per input. We can update the example above to generate 3 completions per input:
```python
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
from distilabel.steps.tasks import TextGeneration
task = TextGeneration(
@@ -170,7 +170,7 @@ next(task.process([{"instruction": "What's the capital of Spain?"}]))
In addition, we might want to group the generations in a single output row as maybe one downstream step expects a single row with multiple generations. We can achieve this by setting the `group_generations` attribute to `True`:
```python
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
from distilabel.steps.tasks import TextGeneration
task = TextGeneration(
@@ -209,37 +209,65 @@ We can define a custom step by creating a new subclass of the [`Task`][distilabe
- `outputs`: is a property that returns a list of strings with the names of the output fields or a dictionary in which the keys are the names of the columns and the values are boolean indicating whether the column is required or not. This property should always include `model_name` as one of the outputs since that's automatically injected from the LLM.
-- `format_output`: is a method that receives the output from the [`LLM`][distilabel.llms.LLM] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. Note that there's no need to include the `model_name` in the output.
+- `format_output`: is a method that receives the output from the [`LLM`][distilabel.models.llms.LLM] and optionally also the input data (which may be useful to build the output in some scenarios), and returns a dictionary with the output data formatted as needed i.e. with the values for the columns in `outputs`. Note that there's no need to include the `model_name` in the output.
-```python
-from typing import Any, Dict, List, Union, TYPE_CHECKING
+=== "Inherit from `Task`"
+
+    When using the `Task` class inheritance method for creating a custom task, we can also optionally override the `Task.process` method to define more complex processing logic involving an `LLM`, as the default implementation simply formats the input, calls `LLM.generate` once, and then formats the output. For example, the [EvolInstruct][distilabel.steps.tasks.EvolInstruct] task overrides this method to call `LLM.generate` multiple times (once for each evolution).
-from distilabel.steps.tasks.base import Task
+ ```python
+ from typing import Any, Dict, List, Union, TYPE_CHECKING
-if TYPE_CHECKING:
- from distilabel.steps.typing import StepColumns
- from distilabel.steps.tasks.typing import ChatType
+ from distilabel.steps.tasks import Task
+ if TYPE_CHECKING:
+ from distilabel.steps.typing import StepColumns
+ from distilabel.steps.tasks.typing import ChatType
-class MyCustomTask(Task):
- @property
- def inputs(self) -> "StepColumns":
- return ["input_field"]
- def format_input(self, input: Dict[str, Any]) -> "ChatType":
- return [
- {
- "role": "user",
- "content": input["input_field"],
- },
- ]
+ class MyCustomTask(Task):
+ @property
+ def inputs(self) -> "StepColumns":
+ return ["input_field"]
- @property
- def outputs(self) -> "StepColumns":
- return ["output_field", "model_name"]
+ def format_input(self, input: Dict[str, Any]) -> "ChatType":
+ return [
+ {
+ "role": "user",
+ "content": input["input_field"],
+ },
+ ]
- def format_output(
- self, output: Union[str, None], input: Dict[str, Any]
- ) -> Dict[str, Any]:
+ @property
+ def outputs(self) -> "StepColumns":
+ return ["output_field", "model_name"]
+
+ def format_output(
+ self, output: Union[str, None], input: Dict[str, Any]
+ ) -> Dict[str, Any]:
+ return {"output_field": output}
+ ```
+
+=== "Using the `@task` decorator"
+
+ If your task just needs a system prompt, a user message template and a way to format the output given by the `LLM`, then you can use the `@task` decorator to avoid writing too much boilerplate code.
+
+ ```python
+ from typing import Any, Dict, Union
+ from distilabel.steps.tasks import task
+
+
+ @task(inputs=["input_field"], outputs=["output_field"])
+ def MyCustomTask(output: Union[str, None], input: Union[Dict[str, Any], None] = None) -> Dict[str, Any]:
+ """
+ ---
+ system_prompt: |
+ My custom system prompt
+
+ user_message_template: |
+ My custom user message template: {input_field}
+ ---
+ """
+ # Format the `LLM` output here
return {"output_field": output}
-```
+ ```
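A hypothetical usage sketch for the `MyCustomTask` defined in either tab above; the `llm` argument and the `load()`/`process()` calls mirror the `TextGeneration` examples earlier in this guide, so treat the exact arguments and output shape as assumptions:

```python
from distilabel.models import OpenAILLM

# Hypothetical: a custom task is wired to an LLM like any built-in task.
task = MyCustomTask(llm=OpenAILLM(model="gpt-4o"))
task.load()

# `process` consumes batches of rows containing the declared input column and
# yields them enriched with `output_field` and `model_name`.
result = next(task.process([{"input_field": "What's the capital of Spain?"}]))
print(result)
```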
diff --git a/docs/sections/pipeline_samples/examples/fine_personas_social_network.md b/docs/sections/pipeline_samples/examples/fine_personas_social_network.md
index 52df495fc4..dd60208cc5 100644
--- a/docs/sections/pipeline_samples/examples/fine_personas_social_network.md
+++ b/docs/sections/pipeline_samples/examples/fine_personas_social_network.md
@@ -130,7 +130,7 @@ With our data in hand, we're ready to explore the capabilities of our SocialAI t
While this model has become something of a go-to choice recently, it's worth noting that experimenting with a variety of models could yield even more interesting results:
```python
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
llm = InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
diff --git a/docs/sections/pipeline_samples/examples/llama_cpp_with_outlines.md b/docs/sections/pipeline_samples/examples/llama_cpp_with_outlines.md
index 9ff0bdff8f..02ed31feed 100644
--- a/docs/sections/pipeline_samples/examples/llama_cpp_with_outlines.md
+++ b/docs/sections/pipeline_samples/examples/llama_cpp_with_outlines.md
@@ -5,11 +5,11 @@ hide: toc
Generate RPG characters following a `pydantic.BaseModel` with `outlines` in `distilabel`.
-This script makes use of [`LlamaCppLLM`][distilabel.llms.llamacpp.LlamaCppLLM] and the structured output capabilities thanks to [`outlines`](https://outlines-dev.github.io/outlines/welcome/) to generate RPG characters that adhere to a JSON schema.
+This script makes use of [`LlamaCppLLM`][distilabel.models.llms.llamacpp.LlamaCppLLM] and the structured output capabilities thanks to [`outlines`](https://outlines-dev.github.io/outlines/welcome/) to generate RPG characters that adhere to a JSON schema.
![Arena Hard](../../../assets/pipelines/knowledge_graphs.png)
-It makes use of a local model which can be downloaded using curl (explained in the script itself), and can be exchanged with other `LLMs` like [`vLLM`][distilabel.llms.vllm.vLLM].
+It makes use of a local model which can be downloaded using curl (explained in the script itself), and can be exchanged with other `LLMs` like [`vLLM`][distilabel.models.llms.vllm.vLLM].
??? Run
diff --git a/docs/sections/pipeline_samples/examples/mistralai_with_instructor.md b/docs/sections/pipeline_samples/examples/mistralai_with_instructor.md
index 7e081ab222..aab0cedf65 100644
--- a/docs/sections/pipeline_samples/examples/mistralai_with_instructor.md
+++ b/docs/sections/pipeline_samples/examples/mistralai_with_instructor.md
@@ -5,7 +5,7 @@ hide: toc
Answer instructions with knowledge graphs defined as `pydantic.BaseModel` objects using `instructor` in `distilabel`.
-This script makes use of [`MistralLLM`][distilabel.llms.mistral.MistralLLM] and the structured output capabilities thanks to [`instructor`](https://python.useinstructor.com/) to generate knowledge graphs from complex topics.
+This script makes use of [`MistralLLM`][distilabel.models.llms.mistral.MistralLLM] and the structured output capabilities thanks to [`instructor`](https://python.useinstructor.com/) to generate knowledge graphs from complex topics.
![Knowledge graph figure](../../../assets/pipelines/knowledge_graphs.png)
diff --git a/docs/sections/pipeline_samples/index.md b/docs/sections/pipeline_samples/index.md
index 6cf718faab..0cff031018 100644
--- a/docs/sections/pipeline_samples/index.md
+++ b/docs/sections/pipeline_samples/index.md
@@ -37,6 +37,14 @@ hide: toc
[:octicons-arrow-right-24: Tutorial](tutorials/GenerateSentencePair.ipynb)
+- __Generate text classification data__
+
+ ---
+
+ Learn about how synthetic data generation for text classification can help address data imbalance or scarcity.
+
+ [:octicons-arrow-right-24: Tutorial](tutorials/generate_textcat_dataset.ipynb)
+
## Paper Implementations
diff --git a/docs/sections/pipeline_samples/papers/clair.md b/docs/sections/pipeline_samples/papers/clair.md
index 8c0887460b..a246df12b8 100644
--- a/docs/sections/pipeline_samples/papers/clair.md
+++ b/docs/sections/pipeline_samples/papers/clair.md
@@ -43,7 +43,7 @@ from datasets import load_dataset
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import CLAIR
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
def transform_ultrafeedback(example: Dict[str, Any]) -> Dict[str, Any]:
diff --git a/docs/sections/pipeline_samples/papers/deita.md b/docs/sections/pipeline_samples/papers/deita.md
index b9d3e9eea6..46ab4fc18d 100644
--- a/docs/sections/pipeline_samples/papers/deita.md
+++ b/docs/sections/pipeline_samples/papers/deita.md
@@ -38,7 +38,7 @@ pip install pynvml huggingface_hub argilla
Import distilabel:
```python
-from distilabel.llms import TransformersLLM, OpenAILLM
+from distilabel.models import TransformersLLM, OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import ConversationTemplate, DeitaFiltering, ExpandColumns, LoadDataFromHub
from distilabel.steps.tasks import ComplexityScorer, EvolInstruct, EvolQuality, GenerateEmbeddings, QualityScorer
diff --git a/docs/sections/pipeline_samples/papers/instruction_backtranslation.md b/docs/sections/pipeline_samples/papers/instruction_backtranslation.md
index b3a6b20d68..11725d41fd 100644
--- a/docs/sections/pipeline_samples/papers/instruction_backtranslation.md
+++ b/docs/sections/pipeline_samples/papers/instruction_backtranslation.md
@@ -28,22 +28,22 @@ To replicate Self Alignment with Instruction Backtranslation one will need to in
pip install "distilabel[hf-inference-endpoints,openai]>=1.0.0"
```
-And since we will be using [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM] (installed via the extra `hf-inference-endpoints`) we will need deploy those in advance either locally or in the Hugging Face Hub (alternatively also the serverless endpoints can be used, but most of the times the inference times are slower, and there's a limited quota to use those as those are free) and set both the `HF_TOKEN` (to use the [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM]) and the `OPENAI_API_KEY` environment variable value (to use the [`OpenAILLM`][distilabel.llms.OpenAILLM]).
+And since we will be using [`InferenceEndpointsLLM`][distilabel.models.InferenceEndpointsLLM] (installed via the extra `hf-inference-endpoints`), we will need to deploy those in advance, either locally or in the Hugging Face Hub (alternatively, the serverless endpoints can be used, but most of the time the inference is slower and there's a limited quota since they are free), and set both the `HF_TOKEN` (to use the [`InferenceEndpointsLLM`][distilabel.models.InferenceEndpointsLLM]) and the `OPENAI_API_KEY` environment variable (to use the [`OpenAILLM`][distilabel.models.OpenAILLM]).
#### Building blocks
- [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub]: Generator Step to load a dataset from the Hugging Face Hub.
- [`TextGeneration`][distilabel.steps.tasks.TextGeneration]: Task to generate responses for a given instruction using an LLM.
- - [`InferenceEndpointsLLM`][distilabel.llms.InferenceEndpointsLLM]: LLM that runs a model from an Inference Endpoint in the Hugging Face Hub.
+ - [`InferenceEndpointsLLM`][distilabel.models.InferenceEndpointsLLM]: LLM that runs a model from an Inference Endpoint in the Hugging Face Hub.
- [`InstructionBacktranslation`][distilabel.steps.tasks.InstructionBacktranslation]: Task that generates a score and a reason for a response for a given instruction using the Self Alignment with Instruction Backtranslation prompt.
- - [`OpenAILLM`][distilabel.llms.OpenAILLM]: LLM that loads a model from OpenAI.
+ - [`OpenAILLM`][distilabel.models.OpenAILLM]: LLM that loads a model from OpenAI.
#### Code
As mentioned before, we will put the previously mentioned building blocks together to replicate Self Alignment with Instruction Backtranslation.
```python
-from distilabel.llms import InferenceEndpointsLLM, OpenAILLM
+from distilabel.models import InferenceEndpointsLLM, OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromHub, KeepColumns
from distilabel.steps.tasks import InstructionBacktranslation, TextGeneration
diff --git a/docs/sections/pipeline_samples/papers/prometheus.md b/docs/sections/pipeline_samples/papers/prometheus.md
index c8a3fb16c5..c9c0e6f76d 100644
--- a/docs/sections/pipeline_samples/papers/prometheus.md
+++ b/docs/sections/pipeline_samples/papers/prometheus.md
@@ -49,7 +49,7 @@ pip install flash-attn --no-build-isolation
- [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub]: [`GeneratorStep`][distilabel.steps.GeneratorStep] to load a dataset from the Hugging Face Hub.
- [`PrometheusEval`][distilabel.steps.tasks.PrometheusEval]: [`Task`][distilabel.steps.tasks.Task] that assesses the quality of a response for a given instruction using any of the Prometheus 2 models.
- - [`vLLM`][distilabel.llms.vLLM]: [`LLM`][distilabel.llms.LLM] that loads a model from the Hugging Face Hub via [vllm-project/vllm](https://github.com/vllm-project/vllm).
+ - [`vLLM`][distilabel.models.vLLM]: [`LLM`][distilabel.models.LLM] that loads a model from the Hugging Face Hub via [vllm-project/vllm](https://github.com/vllm-project/vllm).
!!! NOTE
Since the Prometheus 2 models use a slightly different chat template than [`mistralai/Mistral-7B-Instruct-v0.2`](https://hf.co/mistralai/Mistral-7B-Instruct-v0.2), we need to set the `chat_template` parameter to `[INST] {{ messages[0]['content'] }}\n{{ messages[1]['content'] }}[/INST]` so as to properly format the input for Prometheus 2.
@@ -61,7 +61,7 @@ pip install flash-attn --no-build-isolation
As mentioned before, we will put the previously mentioned building blocks together to see how Prometheus 2 can be used via `distilabel`.
```python
-from distilabel.llms import vLLM
+from distilabel.models import vLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import KeepColumns, LoadDataFromHub
from distilabel.steps.tasks import PrometheusEval
diff --git a/docs/sections/pipeline_samples/papers/ultrafeedback.md b/docs/sections/pipeline_samples/papers/ultrafeedback.md
index 83acc9f335..3e1d1822f3 100644
--- a/docs/sections/pipeline_samples/papers/ultrafeedback.md
+++ b/docs/sections/pipeline_samples/papers/ultrafeedback.md
@@ -29,10 +29,10 @@ And since we will be using `vllm` we will need to use a VM with at least 6 NVIDI
- [`LoadDataFromHub`][distilabel.steps.LoadDataFromHub]: Generator Step to load a dataset from the Hugging Face Hub.
- [`sample_n_steps`][distilabel.pipeline.sample_n_steps]: Function to create a `routing_batch_function` that samples `n` downstream steps for each batch generated by the upstream step. This is the key to replicate the LLM pooling mechanism described in the paper.
- [`TextGeneration`][distilabel.steps.tasks.TextGeneration]: Task to generate responses for a given instruction using an LLM.
- - [`vLLM`][distilabel.llms.vLLM]: LLM that loads a model from the Hugging Face Hub using `vllm`.
+ - [`vLLM`][distilabel.models.vLLM]: LLM that loads a model from the Hugging Face Hub using `vllm`.
- [`GroupColumns`][distilabel.steps.GroupColumns]: Task that combines multiple columns into a single one i.e. from string to list of strings. Useful when there are multiple parallel steps that are connected to the same node.
- [`UltraFeedback`][distilabel.steps.tasks.UltraFeedback]: Task that generates ratings for the responses of a given instruction using the UltraFeedback prompt.
- - [`OpenAILLM`][distilabel.llms.OpenAILLM]: LLM that loads a model from OpenAI.
+ - [`OpenAILLM`][distilabel.models.OpenAILLM]: LLM that loads a model from OpenAI.
- [`KeepColumns`][distilabel.steps.KeepColumns]: Task to keep the desired columns while removing the not needed ones, as well as defining the order for those.
- (optional) [`PreferenceToArgilla`][distilabel.steps.PreferenceToArgilla]: Task to optionally push the generated dataset to Argilla to do some further analysis and human annotation.
@@ -41,7 +41,7 @@ And since we will be using `vllm` we will need to use a VM with at least 6 NVIDI
As mentioned before, we will put the previously mentioned building blocks together to replicate UltraFeedback.
```python
-from distilabel.llms import OpenAILLM, vLLM
+from distilabel.models import OpenAILLM, vLLM
from distilabel.pipeline import Pipeline, sample_n_steps
from distilabel.steps import (
GroupColumns,
diff --git a/docs/sections/pipeline_samples/tutorials/GenerateSentencePair.ipynb b/docs/sections/pipeline_samples/tutorials/GenerateSentencePair.ipynb
index 0779a53eb9..3fad88f9ab 100644
--- a/docs/sections/pipeline_samples/tutorials/GenerateSentencePair.ipynb
+++ b/docs/sections/pipeline_samples/tutorials/GenerateSentencePair.ipynb
@@ -59,7 +59,7 @@
"metadata": {},
"outputs": [],
"source": [
- "from distilabel.llms.huggingface import InferenceEndpointsLLM\n",
+ "from distilabel.models import InferenceEndpointsLLM\n",
"from distilabel.pipeline import Pipeline\n",
"from distilabel.steps.tasks import GenerateSentencePair\n",
"from distilabel.steps import LoadDataFromHub\n",
diff --git a/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb b/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb
index de1e9fd264..7b75f7fcaa 100644
--- a/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb
+++ b/docs/sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb
@@ -69,7 +69,7 @@
"\n",
"from datasets import load_dataset\n",
"\n",
- "from distilabel.llms import InferenceEndpointsLLM\n",
+ "from distilabel.models import InferenceEndpointsLLM\n",
"from distilabel.pipeline import Pipeline\n",
"from distilabel.steps import (\n",
" KeepColumns,\n",
diff --git a/docs/sections/pipeline_samples/tutorials/generate_preference_dataset.ipynb b/docs/sections/pipeline_samples/tutorials/generate_preference_dataset.ipynb
index a81e8051ad..d350416895 100644
--- a/docs/sections/pipeline_samples/tutorials/generate_preference_dataset.ipynb
+++ b/docs/sections/pipeline_samples/tutorials/generate_preference_dataset.ipynb
@@ -65,7 +65,7 @@
"metadata": {},
"outputs": [],
"source": [
- "from distilabel.llms import InferenceEndpointsLLM\n",
+ "from distilabel.models import InferenceEndpointsLLM\n",
"from distilabel.pipeline import Pipeline\n",
"from distilabel.steps import (\n",
" LoadDataFromHub,\n",
diff --git a/docs/sections/pipeline_samples/tutorials/generate_textcat_dataset.ipynb b/docs/sections/pipeline_samples/tutorials/generate_textcat_dataset.ipynb
new file mode 100644
index 0000000000..fd66bca0dd
--- /dev/null
+++ b/docs/sections/pipeline_samples/tutorials/generate_textcat_dataset.ipynb
@@ -0,0 +1,981 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Generate synthetic text classification data"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- **Goal**: Generate synthetic text classification data to augment an imbalanced and limited dataset for training a topic classifier. In addition, generate new data for training a fact-based versus opinion-based classifier to add a new label.\n",
+ "- **Libraries**: [argilla](https://github.com/argilla-io/argilla), [hf-inference-endpoints](https://github.com/huggingface/huggingface_hub), [SetFit](https://github.com/huggingface/setfit)\n",
+ "- **Components**: [LoadDataFromDicts](https://distilabel.argilla.io/latest/components-gallery/steps/loaddatafromdicts/), [EmbeddingTaskGenerator](https://distilabel.argilla.io/latest/components-gallery/tasks/embeddingtaskgenerator/), [GenerateTextClassificationData](https://distilabel.argilla.io/latest/components-gallery/tasks/generatetextclassificationdata/)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Getting started\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Install the dependencies\n",
+ "\n",
+ "To complete this tutorial, you need to install the distilabel SDK and a few third-party libraries via pip. We will be using **the free but rate-limited Hugging Face serverless Inference API** for this tutorial, so we need to install this as an extra distilabel dependency. You can install them by running the following command:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"distilabel[hf-inference-endpoints]\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"transformers~=4.40\" \"torch~=2.0\" \"setfit~=1.0\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's make the required imports:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import random\n",
+ "from collections import Counter\n",
+ "\n",
+ "from datasets import load_dataset, Dataset\n",
+ "from distilabel.models import InferenceEndpointsLLM\n",
+ "from distilabel.pipeline import Pipeline\n",
+ "from distilabel.steps import LoadDataFromDicts\n",
+ "from distilabel.steps.tasks import (\n",
+ " GenerateTextClassificationData,\n",
+ ")\n",
+ "from setfit import SetFitModel, Trainer, sample_dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You'll need an `HF_TOKEN` to use the HF Inference Endpoints. Log in to use it directly within this notebook.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from huggingface_hub import login\n",
+ "\n",
+ "login(token=os.getenv(\"HF_TOKEN\"), add_to_git_credential=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### (optional) Deploy Argilla\n",
+ "\n",
+ "You can skip this step or replace it with any other data evaluation tool, but the quality of your model will suffer from a lack of data quality, so we do recommend looking at your data. If you already deployed Argilla, you can skip this step. Otherwise, you can quickly deploy Argilla following [this guide](https://docs.argilla.io/latest/getting_started/quickstart/).\n",
+ "\n",
+ "Along with that, you will need to install Argilla as a distilabel extra.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install \"distilabel[argilla, hf-inference-endpoints]\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## The dataset\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will use the [`fancyzhx/ag_news`](https://huggingface.co/datasets/fancyzhx/ag_news) dataset from the Hugging Face Hub as our original data source. To simulate a real-world scenario with imbalanced and limited data, we will load only 20 samples from this dataset.\n",
+ "\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hf_dataset = load_dataset(\"fancyzhx/ag_news\", split=\"train[-20:]\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can retrieve the available labels in the dataset and examine the current data distribution."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{0: 'World', 1: 'Sports', 2: 'Business', 3: 'Sci/Tech'}\n",
+ "Counter({0: 12, 1: 6, 2: 2})\n"
+ ]
+ }
+ ],
+ "source": [
+ "labels_topic = hf_dataset.features[\"label\"].names\n",
+ "id2str = {i: labels_topic[i] for i in range(len(labels_topic))}\n",
+ "print(id2str)\n",
+ "print(Counter(hf_dataset[\"label\"]))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As observed, the dataset is imbalanced, with most samples falling under the `World` category, while the `Sci/Tech` category is entirely missing. Moreover, there are insufficient samples to effectively train a topic classification model.\n"
+ ]
+ },
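+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If you prefer to see the distribution by label name rather than by id, a small sanity check like the following (not required for the rest of the tutorial) makes the imbalance explicit:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Map the numeric labels to their names to inspect the imbalance by category\n",
+    "Counter(id2str[label] for label in hf_dataset[\"label\"])"
+   ]
+  },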
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will also define the labels for the new classification task."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "labels_fact_opinion = [\"Fact-based\", \"Opinion-based\"]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Define the text classification task\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To generate the data we will use the `GenerateTextClassificationData` task. This task will use as input classification tasks and we can define the language, difficulty and clarity required for the generated data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[{'role': 'user', 'content': 'You have been assigned a text classification task: Classify the news article as fact-based or opinion-based\\n\\nYour mission is to write one text classification example for this task in JSON format. The JSON object must contain the following keys:\\n - \"input_text\": a string, the input text specified by the classification task.\\n - \"label\": a string, the correct label of the input text.\\n - \"misleading_label\": a string, an incorrect label that is related to the task.\\n\\nPlease adhere to the following guidelines:\\n - The \"input_text\" should be diverse in expression.\\n - The \"misleading_label\" must be a valid label for the given task, but not as appropriate as the \"label\" for the \"input_text\".\\n - The values for all fields should be in English.\\n - Avoid including the values of the \"label\" and \"misleading_label\" fields in the \"input_text\", that would make the task too easy.\\n - The \"input_text\" is clear and requires college level education to comprehend.\\n\\nYour output must always be a JSON object only, do not explain yourself or output anything else. Be creative!'}]\n"
+ ]
+ }
+ ],
+ "source": [
+ "task = GenerateTextClassificationData(\n",
+ " language=\"English\",\n",
+ " difficulty=\"college\",\n",
+ " clarity=\"clear\",\n",
+ " num_generations=1,\n",
+ " llm=InferenceEndpointsLLM(\n",
+ " model_id=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
+ " tokenizer_id=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
+ " generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.4},\n",
+ " ),\n",
+ " input_batch_size=5,\n",
+ ")\n",
+ "task.load()\n",
+ "result = next(\n",
+ " task.process([{\"task\": \"Classify the news article as fact-based or opinion-based\"}])\n",
+ ")\n",
+ "print(result[0][\"distilabel_metadata\"][\"raw_input_generate_text_classification_data_0\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For our use case, we only need to generate data for two tasks: a topic classification task and a fact versus opinion classification task. Therefore, we will define the tasks accordingly. As we will be using an smaller model for generation, we will select 2 random labels for each topic classification task and change the order for the fact versus opinion classification task ensuring more diversity in the generated data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "task_templates = [\n",
+ " \"Determine the news article as {}\",\n",
+ " \"Classify news article as {}\",\n",
+ " \"Identify the news article as {}\",\n",
+ " \"Categorize the news article as {}\",\n",
+ " \"Label the news article using {}\",\n",
+ " \"Annotate the news article based on {}\",\n",
+ " \"Determine the theme of a news article from {}\",\n",
+ " \"Recognize the topic of the news article as {}\",\n",
+ "]\n",
+ "\n",
+ "classification_tasks = [\n",
+ " {\"task\": action.format(\" or \".join(random.sample(labels_topic, 2)))}\n",
+ " for action in task_templates for _ in range(4)\n",
+ "] + [\n",
+ " {\"task\": action.format(\" or \".join(random.sample(labels_fact_opinion, 2)))}\n",
+ " for action in task_templates\n",
+ "]"
+ ]
+ },
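+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Before running the full pipeline, it can be useful to peek at a few of the generated task descriptions to confirm they read as expected:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Quick sanity check of the sampled task descriptions\n",
+    "classification_tasks[:3]"
+   ]
+  },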
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Run the pipeline\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, it's time to define and run the pipeline. As mentioned, we will load the written tasks and feed them into the `GenerateTextClassificationData` task. For our use case, we will be using `Meta-Llama-3.1-8B-Instruct` via the `InferenceEndpointsLLM`, with different degrees of difficulty and clarity.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "difficulties = [\"college\", \"high school\", \"PhD\"]\n",
+ "clarity = [\"clear\", \"understandable with some effort\", \"ambiguous\"]\n",
+ "\n",
+ "with Pipeline(\"texcat-generation-pipeline\") as pipeline:\n",
+ "\n",
+ " tasks_generator = LoadDataFromDicts(data=classification_tasks)\n",
+ "\n",
+ " generate_data = []\n",
+ " for difficulty in difficulties:\n",
+ " for clarity_level in clarity:\n",
+ " task = GenerateTextClassificationData(\n",
+ " language=\"English\",\n",
+ " difficulty=difficulty,\n",
+ " clarity=clarity_level,\n",
+ " num_generations=2,\n",
+ " llm=InferenceEndpointsLLM(\n",
+ " model_id=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
+ " tokenizer_id=\"meta-llama/Meta-Llama-3.1-8B-Instruct\",\n",
+ " generation_kwargs={\"max_new_tokens\": 512, \"temperature\": 0.7},\n",
+ " ),\n",
+ " input_batch_size=5,\n",
+ " )\n",
+ " generate_data.append(task)\n",
+ "\n",
+ " for task in generate_data:\n",
+ " tasks_generator.connect(task)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's now run the pipeline and generate the synthetic data.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "distiset = pipeline.run()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'task': 'Determine the news article as Business or World',\n",
+ " 'input_text': \"The recent decision by the European Central Bank to raise interest rates will likely have a significant impact on the eurozone's economic growth, with some analysts predicting a 0.5% contraction in GDP due to the increased borrowing costs. The move is seen as a measure to combat inflation, which has been rising steadily over the past year.\",\n",
+ " 'label': 'Business',\n",
+ " 'misleading_label': 'World',\n",
+ " 'distilabel_metadata': {'raw_output_generate_text_classification_data_0': '{\\n \"input_text\": \"The recent decision by the European Central Bank to raise interest rates will likely have a significant impact on the eurozone\\'s economic growth, with some analysts predicting a 0.5% contraction in GDP due to the increased borrowing costs. The move is seen as a measure to combat inflation, which has been rising steadily over the past year.\",\\n \"label\": \"Business\",\\n \"misleading_label\": \"World\"\\n}'},\n",
+ " 'model_name': 'meta-llama/Meta-Llama-3.1-8B-Instruct'}"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "distiset[\"generate_text_classification_data_0\"][\"train\"][0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "You can push the dataset to the Hub for sharing with the community and [embed it to explore the data](https://huggingface.co/docs/hub/datasets-viewer-embed).\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "distiset.push_to_hub(\"[your-owner-name]/example-texcat-generation-dataset\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "By examining the distiset distribution, we can confirm that it includes at least the 8 required samples for each label to train our classification models with SetFit."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Counter({'Sci/Tech': 275,\n",
+ " 'Business': 130,\n",
+ " 'World': 86,\n",
+ " 'Fact-based': 86,\n",
+ " 'Sports': 64,\n",
+ " 'Opinion-based': 54,\n",
+ " None: 20,\n",
+ " 'Opinion Based': 1,\n",
+ " 'News/Opinion': 1,\n",
+ " 'Science': 1,\n",
+ " 'Environment': 1,\n",
+ " 'Opinion': 1})"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "all_labels = [\n",
+ " entry[\"label\"]\n",
+ " for dataset_name in distiset\n",
+ " for entry in distiset[dataset_name][\"train\"]\n",
+ "]\n",
+ "\n",
+ "Counter(all_labels)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will create two datasets with the required labels and data for our use cases."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def extract_rows(distiset, labels):\n",
+ " return [\n",
+ " {\n",
+ " \"text\": entry[\"input_text\"],\n",
+ " \"label\": entry[\"label\"],\n",
+ " \"id\": i\n",
+ " }\n",
+ " for dataset_name in distiset\n",
+ " for i, entry in enumerate(distiset[dataset_name][\"train\"])\n",
+ " if entry[\"label\"] in labels\n",
+ " ]\n",
+ "\n",
+ "data_topic = extract_rows(distiset, labels_topic)\n",
+ "data_fact_opinion = extract_rows(distiset, labels_fact_opinion)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## (Optional) Evaluate with Argilla\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "!!! note \"Get started in Argilla\"\n",
+ " If you are not familiar with Argilla, we recommend taking a look at the [Argilla quickstart docs](https://docs.argilla.io/latest/getting_started/quickstart/). Alternatively, you can use your Hugging Face account to login to the [Argilla demo Space](https://argilla-argilla-template-space.hf.space).\n",
+ "\n",
+ "To get the most out of our data, we will use Argilla. First, we need to connect to the Argilla instance.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import argilla as rg\n",
+ "\n",
+ "# Replace api_url with your url if using Docker\n",
+ "# Replace api_key with your API key under \"My Settings\" in the UI\n",
+ "# Uncomment the last line and set your HF_TOKEN if your space is private\n",
+ "client = rg.Argilla(\n",
+ " api_url=\"https://[your-owner-name]-[your_space_name].hf.space\",\n",
+ " api_key=\"[your-api-key]\",\n",
+ " # headers={\"Authorization\": f\"Bearer {HF_TOKEN}\"}\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We will create a `Dataset` for each task, with an input `TextField` for the text classification text and a `LabelQuestion` to ensure the generated labels are correct.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def create_texcat_dataset(dataset_name, labels):\n",
+ " settings = rg.Settings(\n",
+ " fields=[rg.TextField(\"text\")],\n",
+ " questions=[\n",
+ " rg.LabelQuestion(\n",
+ " name=\"label\",\n",
+ " title=\"Classify the texts according to the following labels\",\n",
+ " labels=labels,\n",
+ " ),\n",
+ " ],\n",
+ " )\n",
+ " return rg.Dataset(name=dataset_name, settings=settings).create()\n",
+ "\n",
+ "\n",
+ "rg_dataset_topic = create_texcat_dataset(\"topic-classification\", labels_topic)\n",
+ "rg_dataset_fact_opinion = create_texcat_dataset(\n",
+ " \"fact-opinion-classification\", labels_fact_opinion\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can upload the generated data to Argilla and evaluate it. We will use the generated labels as suggestions.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rg_dataset_topic.records.log(data_topic)\n",
+ "rg_dataset_fact_opinion.records.log(data_fact_opinion)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, we can start the annotation process. Just open the dataset in the Argilla UI and start annotating the records. If the suggestions are correct, you can just click on `Submit`. Otherwise, you can select the correct label.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "!!! note\n",
+ " Check this [how-to guide](https://docs.argilla.io/latest/how_to_guides/annotate/) to know more about annotating in the UI.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Once, you get the annotations, let's continue by retrieving the data from Argilla and format it as a dataset with the required data.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rg_dataset_topic = client.datasets(\"topic-classification\")\n",
+ "rg_dataset_fact_opinion = client.datasets(\"fact-opinion-classification\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "status_filter = rg.Query(filter=rg.Filter((\"response.status\", \"==\", \"submitted\")))\n",
+ "\n",
+ "submitted_topic = rg_dataset_topic.records(status_filter).to_list(flatten=True)\n",
+ "submitted_fact_opinion = rg_dataset_fact_opinion.records(status_filter).to_list(\n",
+ " flatten=True\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def format_submitted(submitted):\n",
+ " return [\n",
+ " {\n",
+ " \"text\": r[\"text\"],\n",
+ " \"label\": r[\"label.responses\"][0],\n",
+ " \"id\": i,\n",
+ " }\n",
+ " for i, r in enumerate(submitted)\n",
+ " ]\n",
+ "\n",
+ "data_topic = format_submitted(submitted_topic)\n",
+ "data_fact_opinion = format_submitted(submitted_fact_opinion)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Train your models\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In our case, we will fine-tune using SetFit. However, you can select the one that best fits your requirements.\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Formatting the data\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The next step will be to format the data to be compatible with SetFit. In the case of the topic classification, we will need to combine the synthetic data with the original data.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hf_topic = hf_dataset.to_list()\n",
+ "num = len(data_topic)\n",
+ "\n",
+ "data_topic.extend(\n",
+ " [\n",
+ " {\n",
+ " \"text\": r[\"text\"],\n",
+ " \"label\": id2str[r[\"label\"]],\n",
+ " \"id\": num + i,\n",
+ " }\n",
+ " for i, r in enumerate(hf_topic)\n",
+ " ]\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "If we check the data distribution now, we can see that we have enough samples for each label to train our models.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Counter({'Sci/Tech': 275, 'Business': 132, 'World': 98, 'Sports': 70})"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "labels = [record[\"label\"] for record in data_topic]\n",
+ "Counter(labels)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Counter({'Fact-based': 86, 'Opinion-based': 54})"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "labels = [record[\"label\"] for record in data_fact_opinion]\n",
+ "Counter(labels)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Now, let's create our training and validation datasets. The training dataset will gather 8 samples by label. In this case, the validation datasets will contain the remaining samples not included in the training datasets.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def sample_and_split(dataset, label_column, num_samples):\n",
+ " train_dataset = sample_dataset(\n",
+ " dataset, label_column=label_column, num_samples=num_samples\n",
+ " )\n",
+ " eval_dataset = dataset.filter(lambda x: x[\"id\"] not in set(train_dataset[\"id\"]))\n",
+ " return train_dataset, eval_dataset\n",
+ "\n",
+ "\n",
+ "dataset_topic_full = Dataset.from_list(data_topic)\n",
+ "dataset_fact_opinion_full = Dataset.from_list(data_fact_opinion)\n",
+ "\n",
+ "train_dataset_topic, eval_dataset_topic = sample_and_split(\n",
+ " dataset_topic_full, \"label\", 8\n",
+ ")\n",
+ "train_dataset_fact_opinion, eval_dataset_fact_opinion = sample_and_split(\n",
+ " dataset_fact_opinion_full, \"label\", 8\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### The actual training\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's train our models for each task! We will use [TaylorAI/bge-micro-v2](https://huggingface.co/TaylorAI/bge-micro-v2), available in the Hugging Face Hub. You can check the [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard) to select the best model for your use case."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 126,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def train_model(model_name, dataset, eval_dataset):\n",
+ " model = SetFitModel.from_pretrained(model_name)\n",
+ "\n",
+ " trainer = Trainer(\n",
+ " model=model,\n",
+ " train_dataset=dataset,\n",
+ " )\n",
+ " trainer.train()\n",
+ " metrics = trainer.evaluate(eval_dataset)\n",
+ " print(metrics)\n",
+ "\n",
+ " return model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 125,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "***** Running training *****\n",
+ " Num unique pairs = 768\n",
+ " Batch size = 16\n",
+ " Num epochs = 1\n",
+ " Total optimization steps = 48\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'embedding_loss': 0.1873, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.02}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "***** Running evaluation *****\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'train_runtime': 4.9767, 'train_samples_per_second': 154.318, 'train_steps_per_second': 9.645, 'epoch': 1.0}\n",
+ "{'accuracy': 0.8333333333333334}\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_topic = train_model(\n",
+ " model_name=\"TaylorAI/bge-micro-v2\",\n",
+ " dataset=train_dataset_topic,\n",
+ " eval_dataset=eval_dataset_topic,\n",
+ ")\n",
+ "model_topic.save_pretrained(\"topic_classification_model\")\n",
+ "model_topic = SetFitModel.from_pretrained(\"topic_classification_model\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 128,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "***** Running training *****\n",
+ " Num unique pairs = 144\n",
+ " Batch size = 16\n",
+ " Num epochs = 1\n",
+ " Total optimization steps = 9\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'embedding_loss': 0.2985, 'learning_rate': 2e-05, 'epoch': 0.11}\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "***** Running evaluation *****\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'train_runtime': 0.8327, 'train_samples_per_second': 172.931, 'train_steps_per_second': 10.808, 'epoch': 1.0}\n",
+ "{'accuracy': 0.9090909090909091}\n"
+ ]
+ }
+ ],
+ "source": [
+ "model_fact_opinion = train_model(\n",
+ " model_name=\"TaylorAI/bge-micro-v2\",\n",
+ " dataset=train_dataset_fact_opinion,\n",
+ " eval_dataset=eval_dataset_fact_opinion,\n",
+ ")\n",
+ "model_fact_opinion.save_pretrained(\"fact_opinion_classification_model\")\n",
+ "model_fact_opinion = SetFitModel.from_pretrained(\"fact_opinion_classification_model\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Voilà ! The models are now trained and ready to be used. You can start making predictions to check the model's performance and add the new label. Optionally, you can continue using distilabel to generate additional data or Argilla to verify the quality of the predictions."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 129,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def predict(model, input, labels):\n",
+ " model.labels = labels\n",
+ " prediction = model.predict([input])\n",
+ " return prediction[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 130,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'Sci/Tech'"
+ ]
+ },
+ "execution_count": 130,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "predict(\n",
+ " model_topic, \"The new iPhone is expected to be released next month.\", labels_topic\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 131,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'Opinion-based'"
+ ]
+ },
+ "execution_count": 131,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "predict(\n",
+ " model_fact_opinion,\n",
+ " \"The new iPhone is expected to be released next month.\",\n",
+ " labels_fact_opinion,\n",
+ ")"
+ ]
+ },
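+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Optionally, you can also share the fine-tuned classifiers on the Hugging Face Hub. The repository names below are placeholders, so replace them with your own:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Push the trained SetFit models to the Hub (placeholder repository names)\n",
+    "model_topic.push_to_hub(\"[your-owner-name]/topic_classification_model\")\n",
+    "model_fact_opinion.push_to_hub(\"[your-owner-name]/fact_opinion_classification_model\")"
+   ]
+  },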
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Conclusions\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this tutorial, we showcased the detailed steps to build a pipeline for generating text classification data using distilabel. You can customize this pipeline for your own use cases and share your datasets with the community through the Hugging Face Hub.\n",
+ "\n",
+ "We defined two text classification tasks—a topic classification task and a fact versus opinion classification task—and generated new data using various models via the serverless Hugging Face Inference API. Then, we curated the generated data with Argilla. Finally, we trained the models with SetFit using both the original and synthetic data."
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "distilabel-tutorials",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/examples/arena_hard.py b/examples/arena_hard.py
index b193bc2347..f8a8571e02 100644
--- a/examples/arena_hard.py
+++ b/examples/arena_hard.py
@@ -331,7 +331,7 @@ def process(self, inputs: StepInput) -> StepOutput: # type: ignore
if __name__ == "__main__":
import json
- from distilabel.llms import InferenceEndpointsLLM, OpenAILLM
+ from distilabel.models import InferenceEndpointsLLM, OpenAILLM
from distilabel.pipeline import Pipeline
from distilabel.steps import (
GroupColumns,
diff --git a/examples/deepseek_prover.py b/examples/deepseek_prover.py
index 07b0509646..08d32ba1bf 100644
--- a/examples/deepseek_prover.py
+++ b/examples/deepseek_prover.py
@@ -21,7 +21,7 @@
from pydantic import PrivateAttr
from typing_extensions import override
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromHub
from distilabel.steps.tasks.base import Task
@@ -68,7 +68,7 @@ class DeepSeekProverAutoFormalization(Task):
```python
from distilabel.steps.tasks import DeepSeekProverAutoFormalization
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
prover_autoformal = DeepSeekProverAutoFormalization(
@@ -104,7 +104,7 @@ class DeepSeekProverAutoFormalization(Task):
```python
from distilabel.steps.tasks import DeepSeekProverAutoFormalization
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# You can gain inspiration from the following examples to create your own few-shot examples:
# https://github.com/yangky11/miniF2F-lean4/blob/main/MiniF2F/Valid.lean
@@ -246,7 +246,7 @@ class DeepSeekProverScorer(Task):
```python
from distilabel.steps.tasks import DeepSeekProverScorer
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
prover_scorer = DeepSeekProverAutoFormalization(
diff --git a/examples/finepersonas_social_ai.py b/examples/finepersonas_social_ai.py
index 8c4f9afc73..8a6e743eb5 100644
--- a/examples/finepersonas_social_ai.py
+++ b/examples/finepersonas_social_ai.py
@@ -16,7 +16,7 @@
from datasets import load_dataset
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import FormatTextGenerationSFT, LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration
diff --git a/examples/pipeline_apigen.py b/examples/pipeline_apigen.py
index e63e16e39e..21da0784b7 100644
--- a/examples/pipeline_apigen.py
+++ b/examples/pipeline_apigen.py
@@ -16,7 +16,7 @@
from datasets import load_dataset
-from distilabel.llms import InferenceEndpointsLLM
+from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineOutputs, DataSampler, LoadDataFromDicts
from distilabel.steps.tasks import (
diff --git a/examples/structured_generation_with_instructor.py b/examples/structured_generation_with_instructor.py
index 0808e56cac..c71170ff7d 100644
--- a/examples/structured_generation_with_instructor.py
+++ b/examples/structured_generation_with_instructor.py
@@ -16,7 +16,7 @@
from pydantic import BaseModel, Field
-from distilabel.llms import MistralLLM
+from distilabel.models import MistralLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration
diff --git a/examples/structured_generation_with_outlines.py b/examples/structured_generation_with_outlines.py
index b92cb6082f..a0834ad3e9 100644
--- a/examples/structured_generation_with_outlines.py
+++ b/examples/structured_generation_with_outlines.py
@@ -18,7 +18,7 @@
from pydantic import BaseModel, StringConstraints, conint
from typing_extensions import Annotated
-from distilabel.llms import LlamaCppLLM
+from distilabel.models import LlamaCppLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration
diff --git a/mkdocs.yml b/mkdocs.yml
index 69aaeed275..19b8e8a63e 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -203,6 +203,7 @@ nav:
- Generate a preference dataset: "sections/pipeline_samples/tutorials/generate_preference_dataset.ipynb"
- Clean an existing preference dataset: "sections/pipeline_samples/tutorials/clean_existing_dataset.ipynb"
- Synthetic data generation for fine-tuning custom retrieval and reranking models: "sections/pipeline_samples/tutorials/GenerateSentencePair.ipynb"
+ - Generate synthetic text classification data: "sections/pipeline_samples/tutorials/generate_textcat_dataset.ipynb"
- Papers:
- DeepSeek Prover: "sections/pipeline_samples/papers/deepseek_prover.md"
- DEITA: "sections/pipeline_samples/papers/deita.md"
@@ -235,11 +236,11 @@ nav:
- Task Gallery: "api/task/task_gallery.md"
- Typing: "api/task/typing.md"
- LLM:
- - "api/llm/index.md"
- - LLM Gallery: "api/llm/llm_gallery.md"
+ - "api/models/llm/index.md"
+ - LLM Gallery: "api/models/llm/llm_gallery.md"
- Embedding:
- - "api/embedding/index.md"
- - Embedding Gallery: "api/embedding/embedding_gallery.md"
+ - "api/models/embedding/index.md"
+ - Embedding Gallery: "api/models/embedding/embedding_gallery.md"
- Pipeline:
- "api/pipeline/index.md"
- Routing Batch Function: "api/pipeline/routing_batch_function.md"
diff --git a/src/distilabel/__init__.py b/src/distilabel/__init__.py
index f6ca72cd10..47628af331 100644
--- a/src/distilabel/__init__.py
+++ b/src/distilabel/__init__.py
@@ -14,6 +14,6 @@
from rich import traceback as rich_traceback
-__version__ = "1.4.1"
+__version__ = "1.5.0"
rich_traceback.install(show_locals=True)
diff --git a/src/distilabel/embeddings.py b/src/distilabel/embeddings.py
new file mode 100644
index 0000000000..aa470e5b4d
--- /dev/null
+++ b/src/distilabel/embeddings.py
@@ -0,0 +1,36 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ruff: noqa: E402
+
+import warnings
+
+deprecation_message = (
+ "Importing from 'distilabel.embeddings' is deprecated and will be removed in a version 1.7.0. "
+ "Import from 'distilabel.models' instead."
+)
+
+warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
+
+from distilabel.models.embeddings.base import Embeddings
+from distilabel.models.embeddings.sentence_transformers import (
+ SentenceTransformerEmbeddings,
+)
+from distilabel.models.embeddings.vllm import vLLMEmbeddings
+
+__all__ = [
+ "Embeddings",
+ "SentenceTransformerEmbeddings",
+ "vLLMEmbeddings",
+]
diff --git a/src/distilabel/llms.py b/src/distilabel/llms.py
new file mode 100644
index 0000000000..e4970992ce
--- /dev/null
+++ b/src/distilabel/llms.py
@@ -0,0 +1,68 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# ruff: noqa: E402
+
+import warnings
+
+deprecation_message = (
+ "Importing from 'distilabel.llms' is deprecated and will be removed in a version 1.7.0. "
+ "Import from 'distilabel.models' instead."
+)
+
+warnings.warn(deprecation_message, DeprecationWarning, stacklevel=2)
+
+from distilabel.models.llms.anthropic import AnthropicLLM
+from distilabel.models.llms.anyscale import AnyscaleLLM
+from distilabel.models.llms.azure import AzureOpenAILLM
+from distilabel.models.llms.base import LLM, AsyncLLM
+from distilabel.models.llms.cohere import CohereLLM
+from distilabel.models.llms.groq import GroqLLM
+from distilabel.models.llms.huggingface import InferenceEndpointsLLM, TransformersLLM
+from distilabel.models.llms.litellm import LiteLLM
+from distilabel.models.llms.llamacpp import LlamaCppLLM
+from distilabel.models.llms.mistral import MistralLLM
+from distilabel.models.llms.moa import MixtureOfAgentsLLM
+from distilabel.models.llms.ollama import OllamaLLM
+from distilabel.models.llms.openai import OpenAILLM
+from distilabel.models.llms.together import TogetherLLM
+from distilabel.models.llms.typing import GenerateOutput, HiddenState
+from distilabel.models.llms.vertexai import VertexAILLM
+from distilabel.models.llms.vllm import ClientvLLM, vLLM
+from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
+
+__all__ = [
+ "AnthropicLLM",
+ "AnyscaleLLM",
+ "AzureOpenAILLM",
+ "LLM",
+ "AsyncLLM",
+ "CohereLLM",
+ "GroqLLM",
+ "InferenceEndpointsLLM",
+ "LiteLLM",
+ "LlamaCppLLM",
+ "MistralLLM",
+ "CudaDevicePlacementMixin",
+ "MixtureOfAgentsLLM",
+ "OllamaLLM",
+ "OpenAILLM",
+ "TogetherLLM",
+ "TransformersLLM",
+ "GenerateOutput",
+ "HiddenState",
+ "VertexAILLM",
+ "ClientvLLM",
+ "vLLM",
+]
diff --git a/src/distilabel/llms/__init__.py b/src/distilabel/llms/__init__.py
deleted file mode 100644
index 526d6b1faf..0000000000
--- a/src/distilabel/llms/__init__.py
+++ /dev/null
@@ -1,57 +0,0 @@
-# Copyright 2023-present, Argilla, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from distilabel.llms.anthropic import AnthropicLLM
-from distilabel.llms.anyscale import AnyscaleLLM
-from distilabel.llms.azure import AzureOpenAILLM
-from distilabel.llms.base import LLM, AsyncLLM
-from distilabel.llms.cohere import CohereLLM
-from distilabel.llms.groq import GroqLLM
-from distilabel.llms.huggingface import InferenceEndpointsLLM, TransformersLLM
-from distilabel.llms.litellm import LiteLLM
-from distilabel.llms.llamacpp import LlamaCppLLM
-from distilabel.llms.mistral import MistralLLM
-from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
-from distilabel.llms.moa import MixtureOfAgentsLLM
-from distilabel.llms.ollama import OllamaLLM
-from distilabel.llms.openai import OpenAILLM
-from distilabel.llms.together import TogetherLLM
-from distilabel.llms.typing import GenerateOutput, HiddenState
-from distilabel.llms.vertexai import VertexAILLM
-from distilabel.llms.vllm import ClientvLLM, vLLM
-
-__all__ = [
- "AnthropicLLM",
- "AnyscaleLLM",
- "AzureOpenAILLM",
- "LLM",
- "AsyncLLM",
- "CohereLLM",
- "GroqLLM",
- "InferenceEndpointsLLM",
- "LiteLLM",
- "LlamaCppLLM",
- "MistralLLM",
- "CudaDevicePlacementMixin",
- "MixtureOfAgentsLLM",
- "OllamaLLM",
- "OpenAILLM",
- "TogetherLLM",
- "TransformersLLM",
- "GenerateOutput",
- "HiddenState",
- "VertexAILLM",
- "ClientvLLM",
- "vLLM",
-]
diff --git a/src/distilabel/models/__init__.py b/src/distilabel/models/__init__.py
new file mode 100644
index 0000000000..45807302f0
--- /dev/null
+++ b/src/distilabel/models/__init__.py
@@ -0,0 +1,66 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from distilabel.models.embeddings.base import Embeddings
+from distilabel.models.embeddings.sentence_transformers import (
+ SentenceTransformerEmbeddings,
+)
+from distilabel.models.embeddings.vllm import vLLMEmbeddings
+from distilabel.models.llms.anthropic import AnthropicLLM
+from distilabel.models.llms.anyscale import AnyscaleLLM
+from distilabel.models.llms.azure import AzureOpenAILLM
+from distilabel.models.llms.base import LLM, AsyncLLM
+from distilabel.models.llms.cohere import CohereLLM
+from distilabel.models.llms.groq import GroqLLM
+from distilabel.models.llms.huggingface import InferenceEndpointsLLM, TransformersLLM
+from distilabel.models.llms.litellm import LiteLLM
+from distilabel.models.llms.llamacpp import LlamaCppLLM
+from distilabel.models.llms.mistral import MistralLLM
+from distilabel.models.llms.moa import MixtureOfAgentsLLM
+from distilabel.models.llms.ollama import OllamaLLM
+from distilabel.models.llms.openai import OpenAILLM
+from distilabel.models.llms.together import TogetherLLM
+from distilabel.models.llms.typing import GenerateOutput, HiddenState
+from distilabel.models.llms.vertexai import VertexAILLM
+from distilabel.models.llms.vllm import ClientvLLM, vLLM
+from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
+
+__all__ = [
+ "AnthropicLLM",
+ "AnyscaleLLM",
+ "AzureOpenAILLM",
+ "LLM",
+ "AsyncLLM",
+ "CohereLLM",
+ "GroqLLM",
+ "InferenceEndpointsLLM",
+ "LiteLLM",
+ "LlamaCppLLM",
+ "MistralLLM",
+ "CudaDevicePlacementMixin",
+ "MixtureOfAgentsLLM",
+ "OllamaLLM",
+ "OpenAILLM",
+ "TogetherLLM",
+ "TransformersLLM",
+ "GenerateOutput",
+ "HiddenState",
+ "VertexAILLM",
+ "ClientvLLM",
+ "vLLM",
+ "Embeddings",
+ "SentenceTransformerEmbeddings",
+ "vLLMEmbeddings",
+]
diff --git a/src/distilabel/embeddings/__init__.py b/src/distilabel/models/embeddings/__init__.py
similarity index 75%
rename from src/distilabel/embeddings/__init__.py
rename to src/distilabel/models/embeddings/__init__.py
index 190ea70e50..9177298748 100644
--- a/src/distilabel/embeddings/__init__.py
+++ b/src/distilabel/models/embeddings/__init__.py
@@ -12,9 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from distilabel.embeddings.base import Embeddings
-from distilabel.embeddings.sentence_transformers import SentenceTransformerEmbeddings
-from distilabel.embeddings.vllm import vLLMEmbeddings
+from distilabel.models.embeddings.base import Embeddings
+from distilabel.models.embeddings.sentence_transformers import (
+ SentenceTransformerEmbeddings,
+)
+from distilabel.models.embeddings.vllm import vLLMEmbeddings
__all__ = [
"Embeddings",
diff --git a/src/distilabel/embeddings/base.py b/src/distilabel/models/embeddings/base.py
similarity index 100%
rename from src/distilabel/embeddings/base.py
rename to src/distilabel/models/embeddings/base.py
diff --git a/src/distilabel/embeddings/sentence_transformers.py b/src/distilabel/models/embeddings/sentence_transformers.py
similarity index 96%
rename from src/distilabel/embeddings/sentence_transformers.py
rename to src/distilabel/models/embeddings/sentence_transformers.py
index 85baea3de9..8c6e015027 100644
--- a/src/distilabel/embeddings/sentence_transformers.py
+++ b/src/distilabel/models/embeddings/sentence_transformers.py
@@ -16,9 +16,9 @@
from pydantic import Field, PrivateAttr
-from distilabel.embeddings.base import Embeddings
-from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.embeddings.base import Embeddings
+from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer
@@ -58,7 +58,7 @@ class SentenceTransformerEmbeddings(Embeddings, CudaDevicePlacementMixin):
Generating sentence embeddings:
```python
- from distilabel.embeddings import SentenceTransformerEmbeddings
+ from distilabel.models import SentenceTransformerEmbeddings
embeddings = SentenceTransformerEmbeddings(model="mixedbread-ai/mxbai-embed-large-v1")
diff --git a/src/distilabel/embeddings/vllm.py b/src/distilabel/models/embeddings/vllm.py
similarity index 95%
rename from src/distilabel/embeddings/vllm.py
rename to src/distilabel/models/embeddings/vllm.py
index cbbadd69af..8ddaccd7bb 100644
--- a/src/distilabel/embeddings/vllm.py
+++ b/src/distilabel/models/embeddings/vllm.py
@@ -16,9 +16,9 @@
from pydantic import Field, PrivateAttr
-from distilabel.embeddings.base import Embeddings
-from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.embeddings.base import Embeddings
+from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
if TYPE_CHECKING:
from vllm import LLM as _vLLM
@@ -49,7 +49,7 @@ class vLLMEmbeddings(Embeddings, CudaDevicePlacementMixin):
Generating sentence embeddings:
```python
- from distilabel.embeddings import vLLMEmbeddings
+ from distilabel.models import vLLMEmbeddings
embeddings = vLLMEmbeddings(model="intfloat/e5-mistral-7b-instruct")
diff --git a/src/distilabel/models/llms/__init__.py b/src/distilabel/models/llms/__init__.py
new file mode 100644
index 0000000000..2ae3119832
--- /dev/null
+++ b/src/distilabel/models/llms/__init__.py
@@ -0,0 +1,57 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distilabel.models.llms.anthropic import AnthropicLLM
+from distilabel.models.llms.anyscale import AnyscaleLLM
+from distilabel.models.llms.azure import AzureOpenAILLM
+from distilabel.models.llms.base import LLM, AsyncLLM
+from distilabel.models.llms.cohere import CohereLLM
+from distilabel.models.llms.groq import GroqLLM
+from distilabel.models.llms.huggingface import InferenceEndpointsLLM, TransformersLLM
+from distilabel.models.llms.litellm import LiteLLM
+from distilabel.models.llms.llamacpp import LlamaCppLLM
+from distilabel.models.llms.mistral import MistralLLM
+from distilabel.models.llms.moa import MixtureOfAgentsLLM
+from distilabel.models.llms.ollama import OllamaLLM
+from distilabel.models.llms.openai import OpenAILLM
+from distilabel.models.llms.together import TogetherLLM
+from distilabel.models.llms.typing import GenerateOutput, HiddenState
+from distilabel.models.llms.vertexai import VertexAILLM
+from distilabel.models.llms.vllm import ClientvLLM, vLLM
+from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
+
+__all__ = [
+ "AnthropicLLM",
+ "AnyscaleLLM",
+ "AzureOpenAILLM",
+ "LLM",
+ "AsyncLLM",
+ "CohereLLM",
+ "GroqLLM",
+ "InferenceEndpointsLLM",
+ "LiteLLM",
+ "LlamaCppLLM",
+ "MistralLLM",
+ "CudaDevicePlacementMixin",
+ "MixtureOfAgentsLLM",
+ "OllamaLLM",
+ "OpenAILLM",
+ "TogetherLLM",
+ "TransformersLLM",
+ "GenerateOutput",
+ "HiddenState",
+ "VertexAILLM",
+ "ClientvLLM",
+ "vLLM",
+]
diff --git a/src/distilabel/llms/_dummy.py b/src/distilabel/models/llms/_dummy.py
similarity index 91%
rename from src/distilabel/llms/_dummy.py
rename to src/distilabel/models/llms/_dummy.py
index 740f98cd46..de89356d0f 100644
--- a/src/distilabel/llms/_dummy.py
+++ b/src/distilabel/models/llms/_dummy.py
@@ -14,11 +14,11 @@
from typing import TYPE_CHECKING, Any, List
-from distilabel.llms.base import LLM, AsyncLLM
-from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
+from distilabel.models.llms.base import LLM, AsyncLLM
+from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
if TYPE_CHECKING:
- from distilabel.llms.typing import GenerateOutput
+ from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import FormattedInput
diff --git a/src/distilabel/llms/anthropic.py b/src/distilabel/models/llms/anthropic.py
similarity index 97%
rename from src/distilabel/llms/anthropic.py
rename to src/distilabel/models/llms/anthropic.py
index f938da58d2..7cd3cbcd3f 100644
--- a/src/distilabel/llms/anthropic.py
+++ b/src/distilabel/models/llms/anthropic.py
@@ -27,9 +27,9 @@
from httpx import AsyncClient
from pydantic import Field, PrivateAttr, SecretStr, validate_call
-from distilabel.llms.base import AsyncLLM
-from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.base import AsyncLLM
+from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import (
FormattedInput,
InstructorStructuredOutputType,
@@ -78,7 +78,7 @@ class AnthropicLLM(AsyncLLM):
Generate text:
```python
- from distilabel.llms import AnthropicLLM
+ from distilabel.models.llms import AnthropicLLM
llm = AnthropicLLM(model="claude-3-opus-20240229", api_key="api.key")
@@ -91,7 +91,7 @@ class AnthropicLLM(AsyncLLM):
```python
from pydantic import BaseModel
- from distilabel.llms import AnthropicLLM
+ from distilabel.models.llms import AnthropicLLM
class User(BaseModel):
name: str
diff --git a/src/distilabel/llms/anyscale.py b/src/distilabel/models/llms/anyscale.py
similarity index 96%
rename from src/distilabel/llms/anyscale.py
rename to src/distilabel/models/llms/anyscale.py
index 1d4114d383..0029615f2b 100644
--- a/src/distilabel/llms/anyscale.py
+++ b/src/distilabel/models/llms/anyscale.py
@@ -17,8 +17,8 @@
from pydantic import Field, PrivateAttr, SecretStr
-from distilabel.llms.openai import OpenAILLM
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.openai import OpenAILLM
_ANYSCALE_API_KEY_ENV_VAR_NAME = "ANYSCALE_API_KEY"
@@ -43,7 +43,7 @@ class AnyscaleLLM(OpenAILLM):
Generate text:
```python
- from distilabel.llms import AnyscaleLLM
+ from distilabel.models.llms import AnyscaleLLM
llm = AnyscaleLLM(model="google/gemma-7b-it", api_key="api.key")
diff --git a/src/distilabel/llms/azure.py b/src/distilabel/models/llms/azure.py
similarity index 94%
rename from src/distilabel/llms/azure.py
rename to src/distilabel/models/llms/azure.py
index 58ed15010f..964612f372 100644
--- a/src/distilabel/llms/azure.py
+++ b/src/distilabel/models/llms/azure.py
@@ -19,8 +19,8 @@
from pydantic import Field, PrivateAttr, SecretStr
from typing_extensions import override
-from distilabel.llms.openai import OpenAILLM
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.openai import OpenAILLM
if TYPE_CHECKING:
from openai import AsyncAzureOpenAI
@@ -51,7 +51,7 @@ class AzureOpenAILLM(OpenAILLM):
Generate text:
```python
- from distilabel.llms import AzureOpenAILLM
+ from distilabel.models.llms import AzureOpenAILLM
llm = AzureOpenAILLM(model="gpt-4-turbo", api_key="api.key")
@@ -63,7 +63,7 @@ class AzureOpenAILLM(OpenAILLM):
Generate text from a custom endpoint following the OpenAI API:
```python
- from distilabel.llms import AzureOpenAILLM
+ from distilabel.models.llms import AzureOpenAILLM
llm = AzureOpenAILLM(
model="prometheus-eval/prometheus-7b-v2.0",
@@ -79,7 +79,7 @@ class AzureOpenAILLM(OpenAILLM):
```python
from pydantic import BaseModel
- from distilabel.llms import AzureOpenAILLM
+ from distilabel.models.llms import AzureOpenAILLM
class User(BaseModel):
name: str
@@ -122,7 +122,7 @@ def load(self) -> None:
# This is a workaround to avoid the `OpenAILLM` calling the _prepare_structured_output
# in the load method before we have the proper client.
with patch(
- "distilabel.llms.openai.OpenAILLM._prepare_structured_output", lambda x: x
+ "distilabel.models.openai.OpenAILLM._prepare_structured_output", lambda x: x
):
super().load()
diff --git a/src/distilabel/llms/base.py b/src/distilabel/models/llms/base.py
similarity index 99%
rename from src/distilabel/llms/base.py
rename to src/distilabel/models/llms/base.py
index ced6a8e041..58ca3b5f62 100644
--- a/src/distilabel/llms/base.py
+++ b/src/distilabel/models/llms/base.py
@@ -40,11 +40,11 @@
if TYPE_CHECKING:
from logging import Logger
- from distilabel.llms.typing import GenerateOutput, HiddenState
from distilabel.mixins.runtime_parameters import (
RuntimeParameterInfo,
RuntimeParametersNames,
)
+ from distilabel.models.llms.typing import GenerateOutput, HiddenState
from distilabel.steps.tasks.structured_outputs.outlines import StructuredOutputType
from distilabel.steps.tasks.typing import (
FormattedInput,
diff --git a/src/distilabel/llms/cohere.py b/src/distilabel/models/llms/cohere.py
similarity index 98%
rename from src/distilabel/llms/cohere.py
rename to src/distilabel/models/llms/cohere.py
index e9d0d0c0f2..80fbddf4f7 100644
--- a/src/distilabel/llms/cohere.py
+++ b/src/distilabel/models/llms/cohere.py
@@ -25,9 +25,9 @@
from pydantic import Field, PrivateAttr, SecretStr, validate_call
-from distilabel.llms.base import AsyncLLM
-from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.base import AsyncLLM
+from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import (
FormattedInput,
InstructorStructuredOutputType,
@@ -73,7 +73,7 @@ class CohereLLM(AsyncLLM):
Generate text:
```python
- from distilabel.llms import CohereLLM
+ from distilabel.models.llms import CohereLLM
llm = CohereLLM(model="CohereForAI/c4ai-command-r-plus")
@@ -86,7 +86,7 @@ class CohereLLM(AsyncLLM):
```python
from pydantic import BaseModel
- from distilabel.llms import CohereLLM
+ from distilabel.models.llms import CohereLLM
class User(BaseModel):
name: str
diff --git a/src/distilabel/llms/groq.py b/src/distilabel/models/llms/groq.py
similarity index 97%
rename from src/distilabel/llms/groq.py
rename to src/distilabel/models/llms/groq.py
index c4c2554329..92ff9b8b35 100644
--- a/src/distilabel/llms/groq.py
+++ b/src/distilabel/models/llms/groq.py
@@ -17,8 +17,8 @@
from pydantic import Field, PrivateAttr, SecretStr, validate_call
-from distilabel.llms.base import AsyncLLM
-from distilabel.llms.typing import GenerateOutput
+from distilabel.models.llms.base import AsyncLLM
+from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.base import RuntimeParameter
from distilabel.steps.tasks.typing import (
FormattedInput,
@@ -66,7 +66,7 @@ class GroqLLM(AsyncLLM):
Generate text:
```python
- from distilabel.llms import GroqLLM
+ from distilabel.models.llms import GroqLLM
llm = GroqLLM(model="llama3-70b-8192")
@@ -79,7 +79,7 @@ class GroqLLM(AsyncLLM):
```python
from pydantic import BaseModel
- from distilabel.llms import GroqLLM
+ from distilabel.models.llms import GroqLLM
class User(BaseModel):
name: str
diff --git a/src/distilabel/llms/huggingface/__init__.py b/src/distilabel/models/llms/huggingface/__init__.py
similarity index 79%
rename from src/distilabel/llms/huggingface/__init__.py
rename to src/distilabel/models/llms/huggingface/__init__.py
index a88cf2ccfd..beca525bce 100644
--- a/src/distilabel/llms/huggingface/__init__.py
+++ b/src/distilabel/models/llms/huggingface/__init__.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM
-from distilabel.llms.huggingface.transformers import TransformersLLM
+from distilabel.models.llms.huggingface.inference_endpoints import InferenceEndpointsLLM
+from distilabel.models.llms.huggingface.transformers import TransformersLLM
__all__ = ["InferenceEndpointsLLM", "TransformersLLM"]
diff --git a/src/distilabel/llms/huggingface/inference_endpoints.py b/src/distilabel/models/llms/huggingface/inference_endpoints.py
similarity index 98%
rename from src/distilabel/llms/huggingface/inference_endpoints.py
rename to src/distilabel/models/llms/huggingface/inference_endpoints.py
index 3566228f56..3f4bc1856b 100644
--- a/src/distilabel/llms/huggingface/inference_endpoints.py
+++ b/src/distilabel/models/llms/huggingface/inference_endpoints.py
@@ -29,10 +29,10 @@
from pydantic._internal._model_construction import ModelMetaclass
from typing_extensions import Annotated, override
-from distilabel.llms.base import AsyncLLM
-from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
-from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.base import AsyncLLM
+from distilabel.models.llms.typing import GenerateOutput
+from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.typing import (
FormattedInput,
StandardInput,
@@ -78,7 +78,7 @@ class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin):
Free serverless Inference API, set the input_batch_size of the Task that uses this to avoid Model is overloaded:
```python
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models.llms.huggingface import InferenceEndpointsLLM
llm = InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
@@ -92,7 +92,7 @@ class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin):
Dedicated Inference Endpoints:
```python
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models.llms.huggingface import InferenceEndpointsLLM
llm = InferenceEndpointsLLM(
endpoint_name="",
@@ -108,7 +108,7 @@ class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin):
Dedicated Inference Endpoints or TGI:
```python
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models.llms.huggingface import InferenceEndpointsLLM
llm = InferenceEndpointsLLM(
api_key="",
@@ -124,7 +124,7 @@ class InferenceEndpointsLLM(AsyncLLM, MagpieChatTemplateMixin):
```python
from pydantic import BaseModel
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models.llms import InferenceEndpointsLLM
class User(BaseModel):
name: str
diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/models/llms/huggingface/transformers.py
similarity index 96%
rename from src/distilabel/llms/huggingface/transformers.py
rename to src/distilabel/models/llms/huggingface/transformers.py
index 27ab00e5b9..e34731a21b 100644
--- a/src/distilabel/llms/huggingface/transformers.py
+++ b/src/distilabel/models/llms/huggingface/transformers.py
@@ -17,11 +17,11 @@
from pydantic import Field, PrivateAttr, SecretStr, validate_call
-from distilabel.llms.base import LLM
-from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
-from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
-from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.base import LLM
+from distilabel.models.llms.typing import GenerateOutput
+from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
+from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.typing import OutlinesStructuredOutputType, StandardInput
from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR
@@ -30,7 +30,7 @@
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
- from distilabel.llms.typing import HiddenState
+ from distilabel.models.llms.typing import HiddenState
class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
@@ -79,7 +79,7 @@ class TransformersLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
Generate text:
```python
- from distilabel.llms import TransformersLLM
+ from distilabel.models.llms import TransformersLLM
llm = TransformersLLM(model="microsoft/Phi-3-mini-4k-instruct")
diff --git a/src/distilabel/llms/litellm.py b/src/distilabel/models/llms/litellm.py
similarity index 98%
rename from src/distilabel/llms/litellm.py
rename to src/distilabel/models/llms/litellm.py
index 48361ef706..1852d76775 100644
--- a/src/distilabel/llms/litellm.py
+++ b/src/distilabel/models/llms/litellm.py
@@ -17,9 +17,9 @@
from pydantic import Field, PrivateAttr, validate_call
-from distilabel.llms.base import AsyncLLM
-from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.base import AsyncLLM
+from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType
if TYPE_CHECKING:
@@ -44,7 +44,7 @@ class LiteLLM(AsyncLLM):
Generate text:
```python
- from distilabel.llms import LiteLLM
+ from distilabel.models.llms import LiteLLM
llm = LiteLLM(model="gpt-3.5-turbo")
@@ -57,7 +57,7 @@ class LiteLLM(AsyncLLM):
```python
from pydantic import BaseModel
- from distilabel.llms import LiteLLM
+ from distilabel.models.llms import LiteLLM
class User(BaseModel):
name: str
diff --git a/src/distilabel/llms/llamacpp.py b/src/distilabel/models/llms/llamacpp.py
similarity index 98%
rename from src/distilabel/llms/llamacpp.py
rename to src/distilabel/models/llms/llamacpp.py
index 9d158ea525..20b66f8cfe 100644
--- a/src/distilabel/llms/llamacpp.py
+++ b/src/distilabel/models/llms/llamacpp.py
@@ -16,9 +16,9 @@
from pydantic import Field, FilePath, PrivateAttr, validate_call
-from distilabel.llms.base import LLM
-from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.base import LLM
+from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType
if TYPE_CHECKING:
@@ -63,7 +63,7 @@ class LlamaCppLLM(LLM):
```python
from pathlib import Path
- from distilabel.llms import LlamaCppLLM
+ from distilabel.models.llms import LlamaCppLLM
# You can follow along this example downloading the following model running the following
# command in the terminal, that will download the model to the `Downloads` folder:
@@ -87,7 +87,7 @@ class LlamaCppLLM(LLM):
```python
from pathlib import Path
- from distilabel.llms import LlamaCppLLM
+ from distilabel.models.llms import LlamaCppLLM
model_path = "Downloads/openhermes-2.5-mistral-7b.Q4_K_M.gguf"
diff --git a/src/distilabel/llms/mistral.py b/src/distilabel/models/llms/mistral.py
similarity index 97%
rename from src/distilabel/llms/mistral.py
rename to src/distilabel/models/llms/mistral.py
index a913d6ad0a..5848402757 100644
--- a/src/distilabel/llms/mistral.py
+++ b/src/distilabel/models/llms/mistral.py
@@ -17,9 +17,9 @@
from pydantic import Field, PrivateAttr, SecretStr, validate_call
-from distilabel.llms.base import AsyncLLM
-from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.base import AsyncLLM
+from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import (
FormattedInput,
InstructorStructuredOutputType,
@@ -65,7 +65,7 @@ class MistralLLM(AsyncLLM):
Generate text:
```python
- from distilabel.llms import MistralLLM
+ from distilabel.models.llms import MistralLLM
llm = MistralLLM(model="open-mixtral-8x22b")
@@ -78,7 +78,7 @@ class MistralLLM(AsyncLLM):
```python
from pydantic import BaseModel
- from distilabel.llms import MistralLLM
+ from distilabel.models.llms import MistralLLM
class User(BaseModel):
name: str
diff --git a/src/distilabel/llms/moa.py b/src/distilabel/models/llms/moa.py
similarity index 98%
rename from src/distilabel/llms/moa.py
rename to src/distilabel/models/llms/moa.py
index a7dd5db19e..11af619ad4 100644
--- a/src/distilabel/llms/moa.py
+++ b/src/distilabel/models/llms/moa.py
@@ -18,12 +18,12 @@
from pydantic import Field
-from distilabel.llms.base import LLM, AsyncLLM
+from distilabel.models.llms.base import LLM, AsyncLLM
from distilabel.steps.tasks.typing import StandardInput
if TYPE_CHECKING:
- from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParametersNames
+ from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import FormattedInput
# Mixture-of-Agents system prompt from the paper with the addition instructing the LLM
@@ -64,7 +64,7 @@ class MixtureOfAgentsLLM(AsyncLLM):
Generate text:
```python
- from distilabel.llms import MixtureOfAgentsLLM, InferenceEndpointsLLM
+ from distilabel.models.llms import MixtureOfAgentsLLM, InferenceEndpointsLLM
llm = MixtureOfAgentsLLM(
aggregator_llm=InferenceEndpointsLLM(
diff --git a/src/distilabel/llms/ollama.py b/src/distilabel/models/llms/ollama.py
similarity index 97%
rename from src/distilabel/llms/ollama.py
rename to src/distilabel/models/llms/ollama.py
index fc3abd605b..009d336aed 100644
--- a/src/distilabel/llms/ollama.py
+++ b/src/distilabel/models/llms/ollama.py
@@ -17,9 +17,9 @@
from pydantic import Field, PrivateAttr, validate_call
from typing_extensions import TypedDict
-from distilabel.llms.base import AsyncLLM
-from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.base import AsyncLLM
+from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import InstructorStructuredOutputType, StandardInput
if TYPE_CHECKING:
@@ -84,7 +84,7 @@ class OllamaLLM(AsyncLLM):
Generate text:
```python
- from distilabel.llms import OllamaLLM
+ from distilabel.models.llms import OllamaLLM
llm = OllamaLLM(model="llama3")
diff --git a/src/distilabel/llms/openai.py b/src/distilabel/models/llms/openai.py
similarity index 98%
rename from src/distilabel/llms/openai.py
rename to src/distilabel/models/llms/openai.py
index 48cac8a50e..3bcca14cad 100644
--- a/src/distilabel/llms/openai.py
+++ b/src/distilabel/models/llms/openai.py
@@ -21,9 +21,9 @@
from distilabel import envs
from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
-from distilabel.llms.base import AsyncLLM
-from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.base import AsyncLLM
+from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import FormattedInput, InstructorStructuredOutputType
if TYPE_CHECKING:
@@ -74,7 +74,7 @@ class OpenAILLM(AsyncLLM):
Generate text:
```python
- from distilabel.llms import OpenAILLM
+ from distilabel.models.llms import OpenAILLM
llm = OpenAILLM(model="gpt-4-turbo", api_key="api.key")
@@ -86,7 +86,7 @@ class OpenAILLM(AsyncLLM):
Generate text from a custom endpoint following the OpenAI API:
```python
- from distilabel.llms import OpenAILLM
+ from distilabel.models.llms import OpenAILLM
llm = OpenAILLM(
model="prometheus-eval/prometheus-7b-v2.0",
@@ -102,7 +102,7 @@ class OpenAILLM(AsyncLLM):
```python
from pydantic import BaseModel
- from distilabel.llms import OpenAILLM
+ from distilabel.models.llms import OpenAILLM
class User(BaseModel):
name: str
@@ -123,7 +123,7 @@ class User(BaseModel):
Generate with Batch API (offline batch generation):
```python
- from distilabel.llms import OpenAILLM
+ from distilabel.models.llms import OpenAILLM
load = llm = OpenAILLM(
model="gpt-3.5-turbo",
diff --git a/src/distilabel/llms/together.py b/src/distilabel/models/llms/together.py
similarity index 96%
rename from src/distilabel/llms/together.py
rename to src/distilabel/models/llms/together.py
index 88e7fd7647..a80183b07f 100644
--- a/src/distilabel/llms/together.py
+++ b/src/distilabel/models/llms/together.py
@@ -17,8 +17,8 @@
from pydantic import Field, PrivateAttr, SecretStr
-from distilabel.llms.openai import OpenAILLM
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.openai import OpenAILLM
_TOGETHER_API_KEY_ENV_VAR_NAME = "TOGETHER_API_KEY"
@@ -42,7 +42,7 @@ class TogetherLLM(OpenAILLM):
Generate text:
```python
- from distilabel.llms import AnyscaleLLM
+ from distilabel.models.llms import TogetherLLM
llm = TogetherLLM(model="mistralai/Mixtral-8x7B-Instruct-v0.1", api_key="api.key")
diff --git a/src/distilabel/llms/typing.py b/src/distilabel/models/llms/typing.py
similarity index 100%
rename from src/distilabel/llms/typing.py
rename to src/distilabel/models/llms/typing.py
diff --git a/src/distilabel/llms/vertexai.py b/src/distilabel/models/llms/vertexai.py
similarity index 97%
rename from src/distilabel/llms/vertexai.py
rename to src/distilabel/models/llms/vertexai.py
index 0c49fa3931..357a3817e4 100644
--- a/src/distilabel/llms/vertexai.py
+++ b/src/distilabel/models/llms/vertexai.py
@@ -16,8 +16,8 @@
from pydantic import PrivateAttr, validate_call
-from distilabel.llms.base import AsyncLLM
-from distilabel.llms.typing import GenerateOutput
+from distilabel.models.llms.base import AsyncLLM
+from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import StandardInput
if TYPE_CHECKING:
@@ -48,7 +48,7 @@ class VertexAILLM(AsyncLLM):
Generate text:
```python
- from distilabel.llms import VertexAILLM
+ from distilabel.models.llms import VertexAILLM
llm = VertexAILLM(model="gemini-1.5-pro")
diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/models/llms/vllm.py
similarity index 98%
rename from src/distilabel/llms/vllm.py
rename to src/distilabel/models/llms/vllm.py
index 19212755d4..417aadabed 100644
--- a/src/distilabel/llms/vllm.py
+++ b/src/distilabel/models/llms/vllm.py
@@ -29,12 +29,12 @@
import numpy as np
from pydantic import Field, PrivateAttr, SecretStr, validate_call
-from distilabel.llms.base import LLM
-from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
-from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
-from distilabel.llms.openai import OpenAILLM
-from distilabel.llms.typing import GenerateOutput
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.base import LLM
+from distilabel.models.llms.openai import OpenAILLM
+from distilabel.models.llms.typing import GenerateOutput
+from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
+from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.typing import FormattedInput, OutlinesStructuredOutputType
if TYPE_CHECKING:
@@ -102,7 +102,7 @@ class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
Generate text:
```python
- from distilabel.llms import vLLM
+ from distilabel.models.llms import vLLM
# You can pass a custom chat_template to the model
llm = vLLM(
@@ -120,7 +120,7 @@ class vLLM(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin):
```python
from pathlib import Path
- from distilabel.llms import vLLM
+ from distilabel.models.llms import vLLM
class User(BaseModel):
name: str
@@ -453,7 +453,7 @@ class ClientvLLM(OpenAILLM, MagpieChatTemplateMixin):
Generate text:
```python
- from distilabel.llms import ClientvLLM
+ from distilabel.models.llms import ClientvLLM
llm = ClientvLLM(
base_url="http://localhost:8000/v1",
diff --git a/src/distilabel/llms/mixins/__init__.py b/src/distilabel/models/mixins/__init__.py
similarity index 100%
rename from src/distilabel/llms/mixins/__init__.py
rename to src/distilabel/models/mixins/__init__.py
diff --git a/src/distilabel/llms/mixins/cuda_device_placement.py b/src/distilabel/models/mixins/cuda_device_placement.py
similarity index 100%
rename from src/distilabel/llms/mixins/cuda_device_placement.py
rename to src/distilabel/models/mixins/cuda_device_placement.py
diff --git a/src/distilabel/llms/mixins/magpie.py b/src/distilabel/models/mixins/magpie.py
similarity index 100%
rename from src/distilabel/llms/mixins/magpie.py
rename to src/distilabel/models/mixins/magpie.py
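For downstream code, the renames above reduce to an import-path change. A minimal before/after sketch (the class names are just examples taken from the renamed files):

```python
# Before this change (old layout):
# from distilabel.llms import OpenAILLM, vLLM
# from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin

# After this change (new `models` package):
from distilabel.models.llms import OpenAILLM, vLLM
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
```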
diff --git a/src/distilabel/pipeline/base.py b/src/distilabel/pipeline/base.py
index c3397c392c..b9a3d4bdc2 100644
--- a/src/distilabel/pipeline/base.py
+++ b/src/distilabel/pipeline/base.py
@@ -281,6 +281,7 @@ def run(
storage_parameters: Optional[Dict[str, Any]] = None,
use_fs_to_pass_data: bool = False,
dataset: Optional["InputDataset"] = None,
+ dataset_batch_size: int = 50,
logging_handlers: Optional[List[logging.Handler]] = None,
) -> "Distiset": # type: ignore
"""Run the pipeline. It will set the runtime parameters for the steps and validate
@@ -308,6 +309,8 @@ def run(
dataset: If given, it will be used to create a `GeneratorStep` and put it as the
root step. Convenient method when you have already processed the dataset in
your script and just want to pass it already processed. Defaults to `None`.
+ dataset_batch_size: If `dataset` is given, this will be the size of the batches
+ yielded by the `GeneratorStep` created using the `dataset`. Defaults to `50`.
logging_handlers: A list of logging handlers that will be used to log the
output of the pipeline. This argument can be useful so the logging messages
can be extracted and used in a different context. Defaults to `None`.
@@ -326,7 +329,7 @@ def run(
self._refresh_pipeline_from_cache()
if dataset is not None:
- self._add_dataset_generator_step(dataset)
+ self._add_dataset_generator_step(dataset, dataset_batch_size)
setup_logging(
log_queue=self._log_queue,
@@ -427,13 +430,16 @@ def dry_run(
self._dry_run = False
return distiset
- def _add_dataset_generator_step(self, dataset: "InputDataset") -> None:
+ def _add_dataset_generator_step(
+ self, dataset: "InputDataset", batch_size: int = 50
+ ) -> None:
"""Create a root step to work as the `GeneratorStep` for the pipeline using a
dataset.
Args:
dataset: A dataset that will be used to create a `GeneratorStep` and
placed in the DAG as the root step.
+ batch_size: The size of the batches generated by the `GeneratorStep`.
Raises:
ValueError: If there's already a `GeneratorStep` in the pipeline.
@@ -447,7 +453,11 @@ def _add_dataset_generator_step(self, dataset: "InputDataset") -> None:
f" `GeneratorStep`: {step}",
page="sections/how_to_guides/basic/step/#types-of-steps",
)
- loader = make_generator_step(dataset, self)
+ loader = make_generator_step(
+ dataset=dataset,
+ pipeline=self,
+ batch_size=batch_size,
+ )
self.dag.add_root_step(loader)
def get_runtime_parameters_info(self) -> "PipelineRuntimeParametersInfo":
diff --git a/src/distilabel/pipeline/local.py b/src/distilabel/pipeline/local.py
index c01cce303f..a35100e156 100644
--- a/src/distilabel/pipeline/local.py
+++ b/src/distilabel/pipeline/local.py
@@ -152,6 +152,7 @@ def run(
storage_parameters: Optional[Dict[str, Any]] = None,
use_fs_to_pass_data: bool = False,
dataset: Optional["InputDataset"] = None,
+ dataset_batch_size: int = 50,
logging_handlers: Optional[List["logging.Handler"]] = None,
) -> "Distiset":
"""Runs the pipeline.
@@ -175,6 +176,8 @@ def run(
dataset: If given, it will be used to create a `GeneratorStep` and put it as the
root step. Convenient method when you have already processed the dataset in
your script and just want to pass it already processed. Defaults to `None`.
+ dataset_batch_size: If `dataset` is given, this will be the size of the batches
+ yielded by the `GeneratorStep` created using the `dataset`. Defaults to `50`.
logging_handlers: A list of logging handlers that will be used to log the
output of the pipeline. This argument can be useful so the logging messages
can be extracted and used in a different context. Defaults to `None`.
@@ -193,6 +196,7 @@ def run(
storage_parameters=storage_parameters,
use_fs_to_pass_data=use_fs_to_pass_data,
dataset=dataset,
+ dataset_batch_size=dataset_batch_size,
)
self._log_queue = cast("Queue[Any]", mp.Queue())
@@ -203,6 +207,7 @@ def run(
storage_parameters=storage_parameters,
use_fs_to_pass_data=use_fs_to_pass_data,
dataset=dataset,
+ dataset_batch_size=dataset_batch_size,
logging_handlers=logging_handlers,
):
return distiset
diff --git a/src/distilabel/pipeline/ray.py b/src/distilabel/pipeline/ray.py
index 70bf205ab3..4b8ff509e3 100644
--- a/src/distilabel/pipeline/ray.py
+++ b/src/distilabel/pipeline/ray.py
@@ -18,7 +18,7 @@
from distilabel.constants import INPUT_QUEUE_ATTR_NAME, STEP_ATTR_NAME
from distilabel.distiset import create_distiset
from distilabel.errors import DistilabelUserError
-from distilabel.llms.vllm import vLLM
+from distilabel.models.llms.vllm import vLLM
from distilabel.pipeline.base import BasePipeline, set_pipeline_running_env_variables
from distilabel.pipeline.step_wrapper import _StepWrapper
from distilabel.utils.logging import setup_logging, stop_logging
@@ -83,6 +83,7 @@ def run(
storage_parameters: Optional[Dict[str, Any]] = None,
use_fs_to_pass_data: bool = False,
dataset: Optional["InputDataset"] = None,
+ dataset_batch_size: int = 50,
logging_handlers: Optional[List["logging.Handler"]] = None,
) -> "Distiset":
"""Runs the pipeline in the Ray cluster.
@@ -106,6 +107,8 @@ def run(
dataset: If given, it will be used to create a `GeneratorStep` and put it as the
root step. Convenient method when you have already processed the dataset in
your script and just want to pass it already processed. Defaults to `None`.
+ dataset_batch_size: If `dataset` is given, this will be the size of the batches
+ yielded by the `GeneratorStep` created using the `dataset`. Defaults to `50`.
logging_handlers: A list of logging handlers that will be used to log the
output of the pipeline. This argument can be useful so the logging messages
can be extracted and used in a different context. Defaults to `None`.
@@ -130,6 +133,7 @@ def run(
storage_parameters=storage_parameters,
use_fs_to_pass_data=use_fs_to_pass_data,
dataset=dataset,
+ dataset_batch_size=dataset_batch_size,
logging_handlers=logging_handlers,
):
return distiset
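The new `dataset_batch_size` argument is threaded through `BasePipeline.run`, the local `Pipeline` and the Ray pipeline above. A minimal sketch of how it could be used (the dataset contents and the model id are illustrative assumptions, not part of this change):

```python
from distilabel.models import InferenceEndpointsLLM
from distilabel.pipeline import Pipeline
from distilabel.steps.tasks import TextGeneration

with Pipeline(name="dataset-batch-size-example") as pipeline:
    # No explicit loader step: passing `dataset` to `run()` creates the root
    # `GeneratorStep` via `make_generator_step` and connects it to this task.
    TextGeneration(
        llm=InferenceEndpointsLLM(model_id="meta-llama/Meta-Llama-3.1-8B-Instruct")  # assumed model
    )

if __name__ == "__main__":
    distiset = pipeline.run(
        dataset=[{"instruction": "What is synthetic data?"}] * 200,
        dataset_batch_size=25,  # size of the batches yielded by the generated root step
    )
```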
diff --git a/src/distilabel/pipeline/routing_batch_function.py b/src/distilabel/pipeline/routing_batch_function.py
index e29a520405..3f0aaf9ff4 100644
--- a/src/distilabel/pipeline/routing_batch_function.py
+++ b/src/distilabel/pipeline/routing_batch_function.py
@@ -252,7 +252,7 @@ def routing_batch_function(
Example:
```python
- from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
+ from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline, routing_batch_function
from distilabel.steps import LoadDataFromHub, GroupColumns
@@ -337,7 +337,7 @@ def sample_n_steps(n: int) -> RoutingBatchFunction:
Example:
```python
- from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
+ from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline, sample_n_steps
from distilabel.steps import LoadDataFromHub, GroupColumns
diff --git a/src/distilabel/pipeline/step_wrapper.py b/src/distilabel/pipeline/step_wrapper.py
index 844648f202..8b33da933d 100644
--- a/src/distilabel/pipeline/step_wrapper.py
+++ b/src/distilabel/pipeline/step_wrapper.py
@@ -19,7 +19,7 @@
from distilabel.constants import LAST_BATCH_SENT_FLAG
from distilabel.errors import DISTILABEL_DOCS_URL
from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
-from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
+from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.pipeline.batch import _Batch
from distilabel.pipeline.typing import StepLoadStatus
from distilabel.steps.base import GeneratorStep, Step, _Step
diff --git a/src/distilabel/steps/clustering/text_clustering.py b/src/distilabel/steps/clustering/text_clustering.py
index 7e640bf5c1..925ffab229 100644
--- a/src/distilabel/steps/clustering/text_clustering.py
+++ b/src/distilabel/steps/clustering/text_clustering.py
@@ -74,7 +74,7 @@ class TextClustering(TextClassification, GlobalTask):
Generate labels for a set of texts using clustering:
```python
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
from distilabel.steps import UMAP, DBSCAN, TextClustering
from distilabel.pipeline import Pipeline
diff --git a/src/distilabel/steps/decorator.py b/src/distilabel/steps/decorator.py
index 0816ca13eb..3e84df66f2 100644
--- a/src/distilabel/steps/decorator.py
+++ b/src/distilabel/steps/decorator.py
@@ -17,7 +17,6 @@
TYPE_CHECKING,
Any,
Callable,
- List,
Literal,
Type,
Union,
@@ -175,10 +174,10 @@ def decorator(func: ProcessingFunc) -> Type["_Step"]:
**runtime_parameters, # type: ignore
)
- def inputs_property(self) -> List[str]:
+ def inputs_property(self) -> "StepColumns":
return inputs
- def outputs_property(self) -> List[str]:
+ def outputs_property(self) -> "StepColumns":
return outputs
def process(
diff --git a/src/distilabel/steps/embeddings/embedding_generation.py b/src/distilabel/steps/embeddings/embedding_generation.py
index 8db3bee2ee..0aeed03102 100644
--- a/src/distilabel/steps/embeddings/embedding_generation.py
+++ b/src/distilabel/steps/embeddings/embedding_generation.py
@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING
-from distilabel.embeddings.base import Embeddings
+from distilabel.models.embeddings.base import Embeddings
from distilabel.steps.base import Step, StepInput
if TYPE_CHECKING:
@@ -43,7 +43,7 @@ class EmbeddingGeneration(Step):
Generate sentence embeddings with Sentence Transformers:
```python
- from distilabel.embeddings import SentenceTransformerEmbeddings
+ from distilabel.models import SentenceTransformerEmbeddings
from distilabel.steps import EmbeddingGeneration
embedding_generation = EmbeddingGeneration(
diff --git a/src/distilabel/steps/embeddings/nearest_neighbour.py b/src/distilabel/steps/embeddings/nearest_neighbour.py
index 98b646d9ee..df5f48f8fa 100644
--- a/src/distilabel/steps/embeddings/nearest_neighbour.py
+++ b/src/distilabel/steps/embeddings/nearest_neighbour.py
@@ -84,7 +84,7 @@ class FaissNearestNeighbour(GlobalStep):
Generating embeddings and getting the nearest neighbours:
```python
- from distilabel.embeddings.sentence_transformers import SentenceTransformerEmbeddings
+ from distilabel.models import SentenceTransformerEmbeddings
from distilabel.pipeline import Pipeline
from distilabel.steps import EmbeddingGeneration, FaissNearestNeighbour, LoadDataFromHub
diff --git a/src/distilabel/steps/reward_model.py b/src/distilabel/steps/reward_model.py
index 49ddc065df..fcb5b27371 100644
--- a/src/distilabel/steps/reward_model.py
+++ b/src/distilabel/steps/reward_model.py
@@ -17,7 +17,7 @@
from pydantic import Field, PrivateAttr, SecretStr
-from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
+from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.steps.base import Step, StepInput
from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR
diff --git a/src/distilabel/steps/tasks/__init__.py b/src/distilabel/steps/tasks/__init__.py
index 725fd065fd..98974b00db 100644
--- a/src/distilabel/steps/tasks/__init__.py
+++ b/src/distilabel/steps/tasks/__init__.py
@@ -19,6 +19,7 @@
from distilabel.steps.tasks.base import GeneratorTask, Task
from distilabel.steps.tasks.clair import CLAIR
from distilabel.steps.tasks.complexity_scorer import ComplexityScorer
+from distilabel.steps.tasks.decorator import task
from distilabel.steps.tasks.evol_instruct.base import EvolInstruct
from distilabel.steps.tasks.evol_instruct.evol_complexity.base import EvolComplexity
from distilabel.steps.tasks.evol_instruct.evol_complexity.generator import (
@@ -62,6 +63,7 @@
"APIGenGenerator",
"APIGenSemanticChecker",
"ComplexityScorer",
+ "task",
"EvolInstruct",
"EvolComplexity",
"EvolComplexityGenerator",
diff --git a/src/distilabel/steps/tasks/apigen/generator.py b/src/distilabel/steps/tasks/apigen/generator.py
index c1c691e378..39f202d065 100644
--- a/src/distilabel/steps/tasks/apigen/generator.py
+++ b/src/distilabel/steps/tasks/apigen/generator.py
@@ -88,7 +88,7 @@ class APIGenGenerator(Task):
```python
from distilabel.steps.tasks import ApiGenGenerator
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
@@ -138,7 +138,7 @@ class APIGenGenerator(Task):
```python
from distilabel.steps.tasks import ApiGenGenerator
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
diff --git a/src/distilabel/steps/tasks/apigen/semantic_checker.py b/src/distilabel/steps/tasks/apigen/semantic_checker.py
index 5ec7cdc57d..c5cf0b183b 100644
--- a/src/distilabel/steps/tasks/apigen/semantic_checker.py
+++ b/src/distilabel/steps/tasks/apigen/semantic_checker.py
@@ -80,7 +80,7 @@ class APIGenSemanticChecker(Task):
```python
from distilabel.steps.tasks import APIGenSemanticChecker
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
@@ -125,7 +125,7 @@ class APIGenSemanticChecker(Task):
```python
from distilabel.steps.tasks import APIGenSemanticChecker
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
diff --git a/src/distilabel/steps/tasks/argilla_labeller.py b/src/distilabel/steps/tasks/argilla_labeller.py
index d0874ed3de..1888087e8d 100644
--- a/src/distilabel/steps/tasks/argilla_labeller.py
+++ b/src/distilabel/steps/tasks/argilla_labeller.py
@@ -81,7 +81,7 @@ class ArgillaLabeller(Task):
import argilla as rg
from argilla import Suggestion
from distilabel.steps.tasks import ArgillaLabeller
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Get information from Argilla dataset definition
dataset = rg.Dataset("my_dataset")
@@ -138,7 +138,7 @@ class ArgillaLabeller(Task):
```python
import argilla as rg
from distilabel.steps.tasks import ArgillaLabeller
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Get information from Argilla dataset definition
dataset = rg.Dataset("my_dataset")
@@ -186,7 +186,7 @@ class ArgillaLabeller(Task):
```python
import argilla as rg
from distilabel.steps.tasks import ArgillaLabeller
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Overwrite default prompts and instructions
labeller = ArgillaLabeller(
@@ -208,17 +208,13 @@ class ArgillaLabeller(Task):
system_prompt: str = (
"You are an expert annotator and labelling assistant that understands complex domains and natural language processing. "
"You are given input fields and a question. "
- "You should create a valid JSON object as an answer to the question based on the input fields. "
- "1. Understand the input fields and optional guidelines. "
- "2. Understand the question type and the question settings. "
- "3. Reason through your response step-by-step. "
- "4. Provide a valid JSON object as an answer to the question."
+ "You should create a valid JSON object as an response to the question based on the input fields. "
)
question_to_label_instruction: Dict[str, str] = {
- "label_selection": "Select the appropriate label from the list of provided labels.",
- "multi_label_selection": "Select none, one or multiple labels from the list of provided labels.",
- "text": "Provide a text response to the question.",
- "rating": "Provide a rating for the question.",
+ "label_selection": "Select the appropriate label for the fields from the list of optional labels.",
+ "multi_label_selection": "Select none, one or multiple labels for the fields from the list of optional labels.",
+ "text": "Provide a response to the question based on the fields.",
+ "rating": "Provide a rating for the question based on the fields.",
}
example_records: Optional[
RuntimeParameter[Union[List[Union[Dict[str, Any], BaseModel]], None]]
@@ -290,12 +286,8 @@ def _format_record(
"""
output = []
for field in fields:
- if title := field.get("title"):
- output.append(f"title: {title}")
- if description := field.get("description"):
- output.append(f"description: {description}")
output.append(record.get("fields", {}).get(field.get("name", "")))
- return "\n".join(output)
+ return "fields: " + "\n".join(output)
def _get_label_instruction(self, question: Dict[str, Any]) -> str:
"""Get the label instruction for the question.
@@ -318,15 +310,11 @@ def _format_question(self, question: Dict[str, Any]) -> str:
Returns:
str: The formatted question.
"""
- output = [
- f"title: {question.get('title', '')}",
- f"description: {question.get('description', '')}",
- f"label_instruction: {self._get_label_instruction(question)}",
- ]
- settings = question.get("settings", {})
- if "options" in settings:
+ output = []
+ output.append(f"question: {self._get_label_instruction(question)}")
+ if "options" in question.get("settings", {}):
output.append(
- f"labels: {[option['value'] for option in settings.get('options', [])]}"
+ f"optional labels: {[option['value'] for option in question.get('settings', {}).get('options', [])]}"
)
return "\n".join(output)
@@ -355,7 +343,7 @@ def _format_example_records(
formatted_value = self._assign_value_to_question_value_model(
value, question
)
- base.append(f"Response: {formatted_value}")
+ base.append(f"response: {formatted_value}")
base.append("")
else:
warnings.warn(
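With the simplified `_format_question` above, a `label_selection` question now renders roughly as follows (the label values are made-up placeholders):

```
question: Select the appropriate label for the fields from the list of optional labels.
optional labels: ['positive', 'negative']
```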
diff --git a/src/distilabel/steps/tasks/base.py b/src/distilabel/steps/tasks/base.py
index 0524749e26..ee2dae790d 100644
--- a/src/distilabel/steps/tasks/base.py
+++ b/src/distilabel/steps/tasks/base.py
@@ -21,8 +21,8 @@
from distilabel.constants import DISTILABEL_METADATA_KEY
from distilabel.errors import DistilabelUserError
-from distilabel.llms.base import LLM
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.base import LLM
from distilabel.steps.base import (
GeneratorStep,
GlobalStep,
@@ -33,7 +33,7 @@
from distilabel.utils.dicts import group_dicts
if TYPE_CHECKING:
- from distilabel.llms.typing import GenerateOutput
+ from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import ChatType, FormattedInput
from distilabel.steps.typing import StepOutput
@@ -245,8 +245,8 @@ def _set_default_structured_output(self) -> None:
if self.use_default_structured_output and not self.llm.structured_output:
# In case the default structured output is required, we have to set it before
# the LLM is loaded
- from distilabel.llms import InferenceEndpointsLLM
- from distilabel.llms.base import AsyncLLM
+ from distilabel.models.llms import InferenceEndpointsLLM
+ from distilabel.models.llms.base import AsyncLLM
def check_dependency(module_name: str) -> None:
if not importlib.util.find_spec(module_name):
@@ -301,7 +301,7 @@ def print(self, sample_input: Optional["ChatType"] = None) -> None:
```python
from distilabel.steps.tasks import URIAL
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models.llms.huggingface import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
urial = URIAL(
diff --git a/src/distilabel/steps/tasks/clair.py b/src/distilabel/steps/tasks/clair.py
index cbf189ab72..524a1d76c9 100644
--- a/src/distilabel/steps/tasks/clair.py
+++ b/src/distilabel/steps/tasks/clair.py
@@ -58,7 +58,7 @@ class CLAIR(Task):
```python
from distilabel.steps.tasks import CLAIR
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
diff --git a/src/distilabel/steps/tasks/complexity_scorer.py b/src/distilabel/steps/tasks/complexity_scorer.py
index 401e3b760f..7578ecf187 100644
--- a/src/distilabel/steps/tasks/complexity_scorer.py
+++ b/src/distilabel/steps/tasks/complexity_scorer.py
@@ -67,7 +67,7 @@ class ComplexityScorer(Task):
```python
from distilabel.steps.tasks import ComplexityScorer
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
scorer = ComplexityScorer(
@@ -91,7 +91,7 @@ class ComplexityScorer(Task):
```python
from distilabel.steps.tasks import ComplexityScorer
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
scorer = ComplexityScorer(
diff --git a/src/distilabel/steps/tasks/decorator.py b/src/distilabel/steps/tasks/decorator.py
new file mode 100644
index 0000000000..8862734f8c
--- /dev/null
+++ b/src/distilabel/steps/tasks/decorator.py
@@ -0,0 +1,220 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import re
+from typing import TYPE_CHECKING, Any, Callable, Dict, Final, List, Tuple, Type, Union
+
+import yaml
+
+from distilabel.errors import DistilabelUserError
+from distilabel.steps.tasks.base import Task
+from distilabel.steps.tasks.typing import FormattedInput
+
+if TYPE_CHECKING:
+ from distilabel.steps.typing import StepColumns
+
+
+TaskFormattingOutputFunc = Callable[..., Dict[str, Any]]
+
+
+def task(
+ inputs: Union["StepColumns", None] = None,
+ outputs: Union["StepColumns", None] = None,
+) -> Callable[..., Type["Task"]]:
+ """Creates a `Task` from a formatting output function.
+
+ Args:
+ inputs: a list containing the names of the input columns/keys required by the
+ step, or a dictionary where the keys are the columns and the values are booleans
+ indicating whether each column is required or not. If not provided, the default
+ will be an empty list `[]` and it will be assumed that the step doesn't need
+ any specific columns. Defaults to `None`.
+ outputs: a list containing the names of the output columns/keys, or a dictionary
+ where the keys are the columns and the values are booleans indicating whether
+ each column will be generated or not. If not provided, the default will be an
+ empty list `[]` and it will be assumed that the step doesn't generate any
+ specific columns. Defaults to `None`.
+ """
+
+ inputs = inputs or []
+ outputs = outputs or []
+
+ def decorator(func: TaskFormattingOutputFunc) -> Type["Task"]:
+ doc = inspect.getdoc(func)
+ if doc is None:
+ raise DistilabelUserError(
+ "When using the `task` decorator, including a docstring in the formatting"
+ " function is mandatory. The docstring must follow the format described"
+ " in the documentation.",
+ page="",
+ )
+
+ system_prompt, user_message_template = _parse_docstring(doc)
+ _validate_templates(inputs, system_prompt, user_message_template)
+
+ def inputs_property(self) -> "StepColumns":
+ return inputs
+
+ def outputs_property(self) -> "StepColumns":
+ return outputs
+
+ def format_input(self, input: Dict[str, Any]) -> "FormattedInput":
+ return [
+ {"role": "system", "content": system_prompt.format(**input)},
+ {"role": "user", "content": user_message_template.format(**input)},
+ ]
+
+ def format_output(
+ self, output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ return func(output, input)
+
+ return type(
+ func.__name__,
+ (Task,),
+ {
+ "inputs": property(inputs_property),
+ "outputs": property(outputs_property),
+ "__module__": func.__module__,
+ "format_input": format_input,
+ "format_output": format_output,
+ },
+ )
+
+ return decorator
+
+
+_SYSTEM_PROMPT_YAML_KEY: Final[str] = "system_prompt"
+_USER_MESSAGE_TEMPLATE_YAML_KEY: Final[str] = "user_message_template"
+_DOCSTRING_FORMATTING_FUNCTION_ERROR: Final[str] = (
+ "Formatting function decorated with `task` doesn't follow the expected format. Please,"
+ " check the documentation and update the function to include a docstring with the expected"
+ " format."
+)
+
+
+def _parse_docstring(docstring: str) -> Tuple[str, str]:
+ """Parses the docstring of the formatting function that was built using the `task`
+ decorator.
+
+ Args:
+ docstring: the docstring of the formatting function.
+
+ Returns:
+ A tuple containing the system prompt and the user message template.
+
+ Raises:
+ DistilabelUserError: if the docstring doesn't follow the expected format or if
+ the expected keys are missing.
+ """
+ parts = docstring.split("---")
+
+ if len(parts) != 3:
+ raise DistilabelUserError(
+ _DOCSTRING_FORMATTING_FUNCTION_ERROR,
+ page="",
+ )
+
+ yaml_content = parts[1]
+
+ try:
+ parsed_yaml = yaml.safe_load(yaml_content)
+ if not isinstance(parsed_yaml, dict):
+ raise DistilabelUserError(
+ _DOCSTRING_FORMATTING_FUNCTION_ERROR,
+ page="",
+ )
+
+ system_prompt = parsed_yaml.get(_SYSTEM_PROMPT_YAML_KEY)
+ user_template = parsed_yaml.get(_USER_MESSAGE_TEMPLATE_YAML_KEY)
+ if system_prompt is None or user_template is None:
+ raise DistilabelUserError(
+ "The formatting function decorated with `task` must include both the `system_prompt`"
+ " and `user_message_template` keys in the docstring. Please, check the documentation"
+ " and update the docstring of the formatting function to include the expected"
+ " keys.",
+ page="",
+ )
+
+ return system_prompt.strip(), user_template.strip()
+
+ except yaml.YAMLError as e:
+ raise DistilabelUserError(_DOCSTRING_FORMATTING_FUNCTION_ERROR, page="") from e
+
+
+TEMPLATE_PLACEHOLDERS_REGEX = re.compile(r"\{(\w+)\}")
+
+
+def _validate_templates(
+ inputs: "StepColumns", system_prompt: str, user_message_template: str
+) -> None:
+ """Validates the system prompt and user message template to ensure that they only
+ contain the allowed placeholders, i.e. the columns/keys that are provided as inputs.
+
+ Args:
+ inputs: the list of inputs columns/keys.
+ system_prompt: the system prompt.
+ user_message_template: the user message template.
+
+ Raises:
+ DistilabelUserError: if the system prompt or the user message template contain
+ invalid placeholders.
+ """
+ list_inputs = list(inputs.keys()) if isinstance(inputs, dict) else inputs
+
+ valid_system_prompt, invalid_system_prompt_placeholders = _validate_template(
+ system_prompt, list_inputs
+ )
+ if not valid_system_prompt:
+ raise DistilabelUserError(
+ f"The formatting function decorated with `task` includes invalid placeholders"
+ f" in the extracted `system_prompt` from the function docstring. Valid placeholders"
+ f" are: {list_inputs}, but the following placeholders were found: {invalid_system_prompt_placeholders}."
+ f" Please, update the `system_prompt` to only include the valid placeholders.",
+ page="",
+ )
+
+ valid_user_message_template, invalid_user_message_template_placeholders = (
+ _validate_template(user_message_template, list_inputs)
+ )
+ if not valid_user_message_template:
+ raise DistilabelUserError(
+ f"The formatting function decorated with `task` includes invalid placeholders"
+ f" in the extracted `user_message_template` from the function docstring. Valid"
+ f" placeholders are: {list_inputs}, but the following placeholders were found:"
+ f" {invalid_user_message_template_placeholders}. Please, update the `system_prompt`"
+ " to only include the valid placeholders.",
+ page="",
+ )
+
+
+def _validate_template(
+ template: str, allowed_placeholders: List[str]
+) -> Tuple[bool, set[str]]:
+ """Validates that the template only contains the allowed placeholders.
+
+ Args:
+ template: the template to validate.
+ allowed_placeholders: the list of allowed placeholders.
+
+ Returns:
+ A tuple containing a boolean indicating if the template is valid and a set
+ with the invalid placeholders.
+ """
+ placeholders = set(TEMPLATE_PLACEHOLDERS_REGEX.findall(template))
+ allowed_placeholders_set = set(allowed_placeholders)
+ are_valid = placeholders.issubset(allowed_placeholders_set)
+ invalid_placeholders = placeholders - allowed_placeholders_set
+ return are_valid, invalid_placeholders
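Putting the new module above together, a short sketch of a formatting function that the `task` decorator turns into a `Task` subclass (the column names, prompt text and output handling are illustrative assumptions):

```python
from typing import Any, Dict, Union

from distilabel.steps.tasks import task


@task(inputs=["instruction"], outputs=["response"])
def SimpleTextGeneration(
    output: Union[str, None], input: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
    """A simple text generation task built with the new `task` decorator.

    ---
    system_prompt: |
        You are a helpful assistant.

    user_message_template: |
        {instruction}
    ---
    """
    # The decorated function becomes `format_output`: it receives the raw LLM
    # output (and optionally the input) and returns a dict of output columns.
    return {"response": output}
```

The docstring is parsed by `_parse_docstring`: the YAML block between the `---` markers provides `system_prompt` and `user_message_template`, whose placeholders are checked against `inputs` by `_validate_templates`. The resulting `SimpleTextGeneration` is a regular `Task` subclass, so it can be instantiated with an `llm` and used in a pipeline like any other task.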
diff --git a/src/distilabel/steps/tasks/evol_instruct/base.py b/src/distilabel/steps/tasks/evol_instruct/base.py
index 95f271a117..9bbf0de34b 100644
--- a/src/distilabel/steps/tasks/evol_instruct/base.py
+++ b/src/distilabel/steps/tasks/evol_instruct/base.py
@@ -75,7 +75,7 @@ class EvolInstruct(Task):
```python
from distilabel.steps.tasks import EvolInstruct
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_instruct = EvolInstruct(
@@ -96,7 +96,7 @@ class EvolInstruct(Task):
```python
from distilabel.steps.tasks import EvolInstruct
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_instruct = EvolInstruct(
@@ -124,7 +124,7 @@ class EvolInstruct(Task):
```python
from distilabel.steps.tasks import EvolInstruct
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_instruct = EvolInstruct(
diff --git a/src/distilabel/steps/tasks/evol_instruct/evol_complexity/base.py b/src/distilabel/steps/tasks/evol_instruct/evol_complexity/base.py
index a7e46b154b..ce9a404aa0 100644
--- a/src/distilabel/steps/tasks/evol_instruct/evol_complexity/base.py
+++ b/src/distilabel/steps/tasks/evol_instruct/evol_complexity/base.py
@@ -67,7 +67,7 @@ class EvolComplexity(EvolInstruct):
```python
from distilabel.steps.tasks import EvolComplexity
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_complexity = EvolComplexity(
diff --git a/src/distilabel/steps/tasks/evol_instruct/evol_complexity/generator.py b/src/distilabel/steps/tasks/evol_instruct/evol_complexity/generator.py
index f1965d9e83..a1b6c83f78 100644
--- a/src/distilabel/steps/tasks/evol_instruct/evol_complexity/generator.py
+++ b/src/distilabel/steps/tasks/evol_instruct/evol_complexity/generator.py
@@ -65,7 +65,7 @@ class EvolComplexityGenerator(EvolInstructGenerator):
```python
from distilabel.steps.tasks import EvolComplexityGenerator
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_complexity_generator = EvolComplexityGenerator(
diff --git a/src/distilabel/steps/tasks/evol_instruct/generator.py b/src/distilabel/steps/tasks/evol_instruct/generator.py
index 1f56c866a3..335e9844f0 100644
--- a/src/distilabel/steps/tasks/evol_instruct/generator.py
+++ b/src/distilabel/steps/tasks/evol_instruct/generator.py
@@ -81,7 +81,7 @@ class EvolInstructGenerator(GeneratorTask):
```python
from distilabel.steps.tasks import EvolInstructGenerator
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_instruct_generator = EvolInstructGenerator(
diff --git a/src/distilabel/steps/tasks/evol_quality/base.py b/src/distilabel/steps/tasks/evol_quality/base.py
index 5c899aa680..b7d2690c35 100644
--- a/src/distilabel/steps/tasks/evol_quality/base.py
+++ b/src/distilabel/steps/tasks/evol_quality/base.py
@@ -71,7 +71,7 @@ class EvolQuality(Task):
```python
from distilabel.steps.tasks import EvolQuality
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_quality = EvolQuality(
diff --git a/src/distilabel/steps/tasks/generate_embeddings.py b/src/distilabel/steps/tasks/generate_embeddings.py
index 85db623d94..f73ee1b2b3 100644
--- a/src/distilabel/steps/tasks/generate_embeddings.py
+++ b/src/distilabel/steps/tasks/generate_embeddings.py
@@ -15,7 +15,7 @@
from typing import TYPE_CHECKING, Any, Dict
from distilabel.errors import DistilabelUserError
-from distilabel.llms.base import LLM
+from distilabel.models.llms.base import LLM
from distilabel.steps.base import Step, StepInput
from distilabel.utils.chat import is_openai_format
@@ -54,7 +54,7 @@ class GenerateEmbeddings(Step):
```python
from distilabel.steps.tasks import GenerateEmbeddings
- from distilabel.llms.huggingface import TransformersLLM
+ from distilabel.models.llms.huggingface import TransformersLLM
# Consider this as a placeholder for your actual LLM.
embedder = GenerateEmbeddings(
diff --git a/src/distilabel/steps/tasks/genstruct.py b/src/distilabel/steps/tasks/genstruct.py
index 02a0657339..2b9c307d5b 100644
--- a/src/distilabel/steps/tasks/genstruct.py
+++ b/src/distilabel/steps/tasks/genstruct.py
@@ -73,7 +73,7 @@ class Genstruct(Task):
```python
from distilabel.steps.tasks import Genstruct
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
genstruct = Genstruct(
diff --git a/src/distilabel/steps/tasks/magpie/base.py b/src/distilabel/steps/tasks/magpie/base.py
index a137d931dd..5135e13ae0 100644
--- a/src/distilabel/steps/tasks/magpie/base.py
+++ b/src/distilabel/steps/tasks/magpie/base.py
@@ -19,12 +19,12 @@
from pydantic import Field, PositiveInt, field_validator
from distilabel.errors import DistilabelUserError
-from distilabel.llms.base import LLM
-from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
from distilabel.mixins.runtime_parameters import (
RuntimeParameter,
RuntimeParametersMixin,
)
+from distilabel.models.llms.base import LLM
+from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.base import StepInput
from distilabel.steps.tasks.base import Task
@@ -404,7 +404,7 @@ class Magpie(Task, MagpieBase):
Generating instructions with Llama 3 8B Instruct and TransformersLLM:
```python
- from distilabel.llms import TransformersLLM
+ from distilabel.models import TransformersLLM
from distilabel.steps.tasks import Magpie
magpie = Magpie(
@@ -443,7 +443,7 @@ class Magpie(Task, MagpieBase):
Generating conversations with Llama 3 8B Instruct and TransformersLLM:
```python
- from distilabel.llms import TransformersLLM
+ from distilabel.models import TransformersLLM
from distilabel.steps.tasks import Magpie
magpie = Magpie(
diff --git a/src/distilabel/steps/tasks/magpie/generator.py b/src/distilabel/steps/tasks/magpie/generator.py
index c1e413d32c..c9d18d9fca 100644
--- a/src/distilabel/steps/tasks/magpie/generator.py
+++ b/src/distilabel/steps/tasks/magpie/generator.py
@@ -18,8 +18,8 @@
from typing_extensions import override
from distilabel.errors import DistilabelUserError
-from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.base import GeneratorTask
from distilabel.steps.tasks.magpie.base import MagpieBase
@@ -98,7 +98,7 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
Generating instructions with Llama 3 8B Instruct and TransformersLLM:
```python
- from distilabel.llms import TransformersLLM
+ from distilabel.models import TransformersLLM
from distilabel.steps.tasks import MagpieGenerator
generator = MagpieGenerator(
@@ -130,7 +130,7 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
Generating a conversation with Llama 3 8B Instruct and TransformersLLM:
```python
- from distilabel.llms import TransformersLLM
+ from distilabel.models import TransformersLLM
from distilabel.steps.tasks import MagpieGenerator
generator = MagpieGenerator(
@@ -210,7 +210,7 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
Generating with system prompts with probabilities:
```python
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
from distilabel.steps.tasks import MagpieGenerator
magpie = MagpieGenerator(
diff --git a/src/distilabel/steps/tasks/prometheus_eval.py b/src/distilabel/steps/tasks/prometheus_eval.py
index 27cd9622ea..4c61c416be 100644
--- a/src/distilabel/steps/tasks/prometheus_eval.py
+++ b/src/distilabel/steps/tasks/prometheus_eval.py
@@ -138,7 +138,7 @@ class PrometheusEval(Task):
```python
from distilabel.steps.tasks import PrometheusEval
- from distilabel.llms import vLLM
+ from distilabel.models import vLLM
# Consider this as a placeholder for your actual LLM.
prometheus = PrometheusEval(
@@ -175,7 +175,7 @@ class PrometheusEval(Task):
```python
from distilabel.steps.tasks import PrometheusEval
- from distilabel.llms import vLLM
+ from distilabel.models import vLLM
# Consider this as a placeholder for your actual LLM.
prometheus = PrometheusEval(
@@ -212,7 +212,7 @@ class PrometheusEval(Task):
```python
from distilabel.steps.tasks import PrometheusEval
- from distilabel.llms import vLLM
+ from distilabel.models import vLLM
# Consider this as a placeholder for your actual LLM.
prometheus = PrometheusEval(
@@ -252,7 +252,7 @@ class PrometheusEval(Task):
```python
from distilabel.steps.tasks import PrometheusEval
- from distilabel.llms import vLLM
+ from distilabel.models import vLLM
# Consider this as a placeholder for your actual LLM.
prometheus = PrometheusEval(
diff --git a/src/distilabel/steps/tasks/quality_scorer.py b/src/distilabel/steps/tasks/quality_scorer.py
index 604f2a0276..efafda2b7a 100644
--- a/src/distilabel/steps/tasks/quality_scorer.py
+++ b/src/distilabel/steps/tasks/quality_scorer.py
@@ -67,7 +67,7 @@ class QualityScorer(Task):
```python
from distilabel.steps.tasks import QualityScorer
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
scorer = QualityScorer(
@@ -102,7 +102,7 @@ class QualityScorer(Task):
```python
from distilabel.steps.tasks import QualityScorer
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
scorer = QualityScorer(
llm=InferenceEndpointsLLM(
diff --git a/src/distilabel/steps/tasks/self_instruct.py b/src/distilabel/steps/tasks/self_instruct.py
index 28ac346c39..dcca46ee67 100644
--- a/src/distilabel/steps/tasks/self_instruct.py
+++ b/src/distilabel/steps/tasks/self_instruct.py
@@ -66,7 +66,7 @@ class SelfInstruct(Task):
```python
from distilabel.steps.tasks import SelfInstruct
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
self_instruct = SelfInstruct(
llm=InferenceEndpointsLLM(
diff --git a/src/distilabel/steps/tasks/sentence_transformers.py b/src/distilabel/steps/tasks/sentence_transformers.py
index f33a223c63..fa29bbe367 100644
--- a/src/distilabel/steps/tasks/sentence_transformers.py
+++ b/src/distilabel/steps/tasks/sentence_transformers.py
@@ -108,7 +108,7 @@ class GenerateSentencePair(Task):
```python
from distilabel.steps.tasks import GenerateSentencePair
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
generate_sentence_pair = GenerateSentencePair(
triplet=True, # `False` to generate only positive
@@ -128,7 +128,7 @@ class GenerateSentencePair(Task):
Generating semantically similar sentences:
```python
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
from distilabel.steps.tasks import GenerateSentencePair
generate_sentence_pair = GenerateSentencePair(
@@ -150,7 +150,7 @@ class GenerateSentencePair(Task):
```python
from distilabel.steps.tasks import GenerateSentencePair
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
generate_sentence_pair = GenerateSentencePair(
triplet=True, # `False` to generate only positive
@@ -171,7 +171,7 @@ class GenerateSentencePair(Task):
```python
from distilabel.steps.tasks import GenerateSentencePair
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
generate_sentence_pair = GenerateSentencePair(
triplet=True, # `False` to generate only positive
@@ -192,7 +192,7 @@ class GenerateSentencePair(Task):
```python
from distilabel.steps.tasks import GenerateSentencePair
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
generate_sentence_pair = GenerateSentencePair(
triplet=True, # `False` to generate only positive
@@ -214,7 +214,7 @@ class GenerateSentencePair(Task):
```python
from distilabel.steps.tasks import GenerateSentencePair
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
generate_sentence_pair = GenerateSentencePair(
triplet=True, # `False` to generate only positive
@@ -237,7 +237,7 @@ class GenerateSentencePair(Task):
```python
from distilabel.steps.tasks import GenerateSentencePair
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
generate_sentence_pair = GenerateSentencePair(
triplet=True, # `False` to generate only positive
diff --git a/src/distilabel/steps/tasks/structured_generation.py b/src/distilabel/steps/tasks/structured_generation.py
index 81ee74bd85..905a6672d0 100644
--- a/src/distilabel/steps/tasks/structured_generation.py
+++ b/src/distilabel/steps/tasks/structured_generation.py
@@ -52,7 +52,7 @@ class StructuredGeneration(Task):
```python
from distilabel.steps.tasks import StructuredGeneration
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
structured_gen = StructuredGeneration(
llm=InferenceEndpointsLLM(
@@ -109,7 +109,7 @@ class StructuredGeneration(Task):
```python
from distilabel.steps.tasks import StructuredGeneration
- from distilabel.llms import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
structured_gen = StructuredGeneration(
llm=InferenceEndpointsLLM(
diff --git a/src/distilabel/steps/tasks/templates/argillalabeller.jinja2 b/src/distilabel/steps/tasks/templates/argillalabeller.jinja2
index d5afa75d27..2b5839384b 100644
--- a/src/distilabel/steps/tasks/templates/argillalabeller.jinja2
+++ b/src/distilabel/steps/tasks/templates/argillalabeller.jinja2
@@ -1,13 +1,13 @@
Please provide an answer to the question based on the input fields{% if examples %} and examples{% endif %}.
{% if guidelines %}
# Guidelines
-{{ guidelines }}
-{% endif %}
-# Input Fields
-{{ fields }}
-# Question
-{{ question }}
+{{ guidelines }}{% endif %}
{% if examples %}
# Examples
-{{ examples }}
-{% endif %}
\ No newline at end of file
+{{ examples }}{% endif %}
+# Question
+{{ question }}
+
+# Fields
+{{ fields }}
+response:
\ No newline at end of file
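A hedged aside on the template rewrite: with Jinja2's default `trim_blocks=False`, a block tag on its own line keeps its trailing newline, so moving `{% endif %}` onto the same line as the variable avoids a stray blank line in the rendered prompt. A small standalone illustration (not part of the diff):

```python
from jinja2 import Template

own_line = Template("{% if g %}\n# Guidelines\n{{ g }}\n{% endif %}\nNext").render(g="Be concise")
inline = Template("{% if g %}\n# Guidelines\n{{ g }}{% endif %}\nNext").render(g="Be concise")

print(repr(own_line))  # '\n# Guidelines\nBe concise\n\nNext' -- extra blank line
print(repr(inline))    # '\n# Guidelines\nBe concise\nNext'   -- no extra blank line
```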
diff --git a/src/distilabel/steps/tasks/text_classification.py b/src/distilabel/steps/tasks/text_classification.py
index 5d04b3b2db..19df530fb6 100644
--- a/src/distilabel/steps/tasks/text_classification.py
+++ b/src/distilabel/steps/tasks/text_classification.py
@@ -90,7 +90,7 @@ class TextClassification(Task):
```python
from distilabel.steps.tasks import TextClassification
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
llm = InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
diff --git a/src/distilabel/steps/tasks/text_generation.py b/src/distilabel/steps/tasks/text_generation.py
index a8b2048e54..daabe5525b 100644
--- a/src/distilabel/steps/tasks/text_generation.py
+++ b/src/distilabel/steps/tasks/text_generation.py
@@ -69,7 +69,7 @@ class TextGeneration(Task):
```python
from distilabel.steps.tasks import TextGeneration
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
text_gen = TextGeneration(
@@ -99,7 +99,7 @@ class TextGeneration(Task):
```python
from distilabel.steps.tasks import TextGeneration
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
CUSTOM_TEMPLATE = '''Document:
{{ document }}
@@ -145,7 +145,7 @@ class TextGeneration(Task):
```python
from distilabel.steps.tasks import TextGeneration
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
CUSTOM_TEMPLATE = '''Generate a clear, single-sentence instruction based on the following examples:
@@ -325,7 +325,7 @@ class ChatGeneration(Task):
```python
from distilabel.steps.tasks import ChatGeneration
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
chat = ChatGeneration(
diff --git a/src/distilabel/steps/tasks/ultrafeedback.py b/src/distilabel/steps/tasks/ultrafeedback.py
index aeb57bda36..bac144f54d 100644
--- a/src/distilabel/steps/tasks/ultrafeedback.py
+++ b/src/distilabel/steps/tasks/ultrafeedback.py
@@ -63,7 +63,7 @@ class UltraFeedback(Task):
```python
from distilabel.steps.tasks import UltraFeedback
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
ultrafeedback = UltraFeedback(
@@ -101,7 +101,7 @@ class UltraFeedback(Task):
```python
from distilabel.steps.tasks import UltraFeedback
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
ultrafeedback = UltraFeedback(
@@ -137,7 +137,7 @@ class UltraFeedback(Task):
```python
from distilabel.steps.tasks import UltraFeedback
- from distilabel.llms.huggingface import InferenceEndpointsLLM
+ from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
ultrafeedback = UltraFeedback(
diff --git a/src/distilabel/steps/tasks/urial.py b/src/distilabel/steps/tasks/urial.py
index 705b9c4883..24b643ada6 100644
--- a/src/distilabel/steps/tasks/urial.py
+++ b/src/distilabel/steps/tasks/urial.py
@@ -50,7 +50,7 @@ class URIAL(Task):
Generate text from an instruction:
```python
- from distilabel.llms import vLLM
+ from distilabel.models import vLLM
from distilabel.steps.tasks import URIAL
step = URIAL(
diff --git a/src/distilabel/typing.py b/src/distilabel/typing.py
new file mode 100644
index 0000000000..28bfd57fc5
--- /dev/null
+++ b/src/distilabel/typing.py
@@ -0,0 +1,55 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from distilabel.models.llms.typing import GenerateOutput
+from distilabel.pipeline.typing import (
+ DownstreamConnectable,
+ DownstreamConnectableSteps,
+ InputDataset,
+ PipelineRuntimeParametersInfo,
+ StepLoadStatus,
+ UpstreamConnectableSteps,
+)
+from distilabel.steps.tasks.typing import (
+ ChatItem,
+ ChatType,
+ FormattedInput,
+ InstructorStructuredOutputType,
+ OutlinesStructuredOutputType,
+ StandardInput,
+ StructuredInput,
+ StructuredOutputType,
+)
+from distilabel.steps.typing import GeneratorStepOutput, StepColumns, StepOutput
+
+__all__ = [
+ "GenerateOutput",
+ "DownstreamConnectable",
+ "DownstreamConnectableSteps",
+ "InputDataset",
+ "PipelineRuntimeParametersInfo",
+ "StepLoadStatus",
+ "UpstreamConnectableSteps",
+ "ChatItem",
+ "ChatType",
+ "FormattedInput",
+ "InstructorStructuredOutputType",
+ "OutlinesStructuredOutputType",
+ "StandardInput",
+ "StructuredInput",
+ "StructuredOutputType",
+ "GeneratorStepOutput",
+ "StepColumns",
+ "StepOutput",
+]
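The new `distilabel.typing` module only re-exports names that already exist elsewhere; a minimal sketch (an assumption, not part of the diff) of how downstream code can import the public aliases from the aggregated location instead of the deeper submodules:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Every name below appears in the __all__ list of the new module.
    from distilabel.typing import (
        ChatType,
        FormattedInput,
        GenerateOutput,
        StepColumns,
        StepOutput,
    )
```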
diff --git a/src/distilabel/utils/export_components_info.py b/src/distilabel/utils/export_components_info.py
index fa1cd6556d..00144fd041 100644
--- a/src/distilabel/utils/export_components_info.py
+++ b/src/distilabel/utils/export_components_info.py
@@ -15,8 +15,8 @@
import inspect
from typing import Generator, List, Type, TypedDict, TypeVar
-from distilabel.embeddings.base import Embeddings
-from distilabel.llms.base import LLM
+from distilabel.models.embeddings.base import Embeddings
+from distilabel.models.llms.base import LLM
from distilabel.steps.base import _Step
from distilabel.steps.tasks.base import _Task
from distilabel.steps.tasks.generate_embeddings import GenerateEmbeddings
diff --git a/tests/integration/test_generator_and_sampler.py b/tests/integration/test_generator_and_sampler.py
index 1bb0a457b5..cdbeb5703a 100644
--- a/tests/integration/test_generator_and_sampler.py
+++ b/tests/integration/test_generator_and_sampler.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from distilabel.llms._dummy import DummyAsyncLLM
+from distilabel.models.llms._dummy import DummyAsyncLLM
from distilabel.pipeline import Pipeline
from distilabel.steps import CombineOutputs, LoadDataFromDicts
from distilabel.steps.generators.data_sampler import DataSampler
diff --git a/tests/integration/test_offline_batch_generation.py b/tests/integration/test_offline_batch_generation.py
index a9fe880ff7..e3dea4af56 100644
--- a/tests/integration/test_offline_batch_generation.py
+++ b/tests/integration/test_offline_batch_generation.py
@@ -16,13 +16,13 @@
from typing import TYPE_CHECKING, Any, List, Union
from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
-from distilabel.llms import LLM
+from distilabel.models.llms import LLM
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts
from distilabel.steps.tasks import TextGeneration
if TYPE_CHECKING:
- from distilabel.llms.typing import GenerateOutput
+ from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import FormattedInput
diff --git a/tests/integration/test_pipe_llms.py b/tests/integration/test_pipe_llms.py
index 47174be117..c95af1ac3f 100644
--- a/tests/integration/test_pipe_llms.py
+++ b/tests/integration/test_pipe_llms.py
@@ -15,9 +15,9 @@
import os
from typing import TYPE_CHECKING, Dict, List
-from distilabel.llms.huggingface.transformers import TransformersLLM
-from distilabel.llms.openai import OpenAILLM
from distilabel.mixins.runtime_parameters import RuntimeParameter
+from distilabel.models.llms.huggingface.transformers import TransformersLLM
+from distilabel.models.llms.openai import OpenAILLM
from distilabel.pipeline.local import Pipeline
from distilabel.steps.base import Step, StepInput
from distilabel.steps.generators.huggingface import LoadDataFromHub
diff --git a/tests/integration/test_prints.py b/tests/integration/test_prints.py
index 7db85caf8f..e7ea68a858 100644
--- a/tests/integration/test_prints.py
+++ b/tests/integration/test_prints.py
@@ -17,7 +17,7 @@
import pytest
-from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
+from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps import tasks as tasks_
from tests.unit.conftest import DummyLLM
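Note that the mixins land under `distilabel.models.mixins` (one level up from `models.llms`), as the rewritten test imports show; a minimal import check, assuming a matching install:

```python
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin

print(MagpieChatTemplateMixin.__module__)   # distilabel.models.mixins.magpie
print(CudaDevicePlacementMixin.__module__)  # distilabel.models.mixins.cuda_device_placement
```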
diff --git a/tests/unit/cli/test_pipeline.yaml b/tests/unit/cli/test_pipeline.yaml
index 3d86f5ab18..07b349334d 100644
--- a/tests/unit/cli/test_pipeline.yaml
+++ b/tests/unit/cli/test_pipeline.yaml
@@ -40,7 +40,7 @@ pipeline:
model: gpt-3.5-turbo
base_url: https://api.openai.com/v1
type_info:
- module: distilabel.llms.openai
+ module: distilabel.models.llms.openai
name: OpenAILLM
group_generations: false
num_generations: 3
@@ -94,7 +94,7 @@ pipeline:
model: gpt-3.5-turbo
base_url: https://api.openai.com/v1
type_info:
- module: distilabel.llms.openai
+ module: distilabel.models.llms.openai
name: OpenAILLM
group_generations: true
num_generations: 3
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index 8c7c240b09..b3ec2de908 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -16,12 +16,12 @@
import pytest
-from distilabel.llms.base import LLM, AsyncLLM
-from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
+from distilabel.models.llms.base import LLM, AsyncLLM
+from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.base import Task
if TYPE_CHECKING:
- from distilabel.llms.typing import GenerateOutput
+ from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import ChatType, FormattedInput
@@ -54,8 +54,8 @@ def model_name(self) -> str:
def generate( # type: ignore
self, inputs: "FormattedInput", num_generations: int = 1
- ) -> "GenerateOutput":
- return ["output" for _ in range(num_generations)]
+ ) -> List["GenerateOutput"]:
+ return [["output" for _ in range(num_generations)]]
class DummyMagpieLLM(LLM, MagpieChatTemplateMixin):
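Besides the import rewrites, the `DummyLLM.generate` fixture now wraps its result in an outer list; a tiny standalone sketch of the shape change (the string values are placeholders):

```python
num_generations = 3

old_shape = ["output" for _ in range(num_generations)]    # previous return value
new_shape = [["output" for _ in range(num_generations)]]  # now one nested list per call

assert new_shape[0] == old_shape
```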
diff --git a/tests/unit/embeddings/__init__.py b/tests/unit/models/__init__.py
similarity index 100%
rename from tests/unit/embeddings/__init__.py
rename to tests/unit/models/__init__.py
diff --git a/tests/unit/llms/mixins/__init__.py b/tests/unit/models/embeddings/__init__.py
similarity index 100%
rename from tests/unit/llms/mixins/__init__.py
rename to tests/unit/models/embeddings/__init__.py
diff --git a/tests/unit/embeddings/test_sentence_transformers.py b/tests/unit/models/embeddings/test_sentence_transformers.py
similarity index 92%
rename from tests/unit/embeddings/test_sentence_transformers.py
rename to tests/unit/models/embeddings/test_sentence_transformers.py
index 2efeabb807..0291a06263 100644
--- a/tests/unit/embeddings/test_sentence_transformers.py
+++ b/tests/unit/models/embeddings/test_sentence_transformers.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from distilabel.embeddings.sentence_transformers import SentenceTransformerEmbeddings
+from distilabel.models.embeddings.sentence_transformers import (
+ SentenceTransformerEmbeddings,
+)
class TestSentenceTransformersEmbeddings:
diff --git a/tests/unit/embeddings/test_vllm.py b/tests/unit/models/embeddings/test_vllm.py
similarity index 96%
rename from tests/unit/embeddings/test_vllm.py
rename to tests/unit/models/embeddings/test_vllm.py
index 8291f434e9..c98c6088c0 100644
--- a/tests/unit/embeddings/test_vllm.py
+++ b/tests/unit/models/embeddings/test_vllm.py
@@ -14,7 +14,7 @@
from unittest.mock import MagicMock, Mock
-from distilabel.embeddings.vllm import vLLMEmbeddings
+from distilabel.models.embeddings.vllm import vLLMEmbeddings
# @patch("vllm.entrypoints.LLM")
diff --git a/tests/unit/llms/__init__.py b/tests/unit/models/llms/__init__.py
similarity index 100%
rename from tests/unit/llms/__init__.py
rename to tests/unit/models/llms/__init__.py
diff --git a/tests/unit/llms/huggingface/__init__.py b/tests/unit/models/llms/huggingface/__init__.py
similarity index 100%
rename from tests/unit/llms/huggingface/__init__.py
rename to tests/unit/models/llms/huggingface/__init__.py
diff --git a/tests/unit/llms/huggingface/test_inference_endpoints.py b/tests/unit/models/llms/huggingface/test_inference_endpoints.py
similarity index 98%
rename from tests/unit/llms/huggingface/test_inference_endpoints.py
rename to tests/unit/models/llms/huggingface/test_inference_endpoints.py
index d820122a4d..f4054b6736 100644
--- a/tests/unit/llms/huggingface/test_inference_endpoints.py
+++ b/tests/unit/models/llms/huggingface/test_inference_endpoints.py
@@ -27,7 +27,7 @@
ChatCompletionOutputUsage,
)
-from distilabel.llms.huggingface.inference_endpoints import InferenceEndpointsLLM
+from distilabel.models.llms.huggingface.inference_endpoints import InferenceEndpointsLLM
@pytest.fixture(autouse=True)
@@ -315,7 +315,7 @@ def test_serialization(self, mock_inference_client: MagicMock) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.huggingface.inference_endpoints",
+ "module": "distilabel.models.llms.huggingface.inference_endpoints",
"name": "InferenceEndpointsLLM",
},
}
diff --git a/tests/unit/llms/huggingface/test_transformers.py b/tests/unit/models/llms/huggingface/test_transformers.py
similarity index 96%
rename from tests/unit/llms/huggingface/test_transformers.py
rename to tests/unit/models/llms/huggingface/test_transformers.py
index 97214ef5fc..a298ff737e 100644
--- a/tests/unit/llms/huggingface/test_transformers.py
+++ b/tests/unit/models/llms/huggingface/test_transformers.py
@@ -16,7 +16,7 @@
import pytest
-from distilabel.llms.huggingface.transformers import TransformersLLM
+from distilabel.models.llms.huggingface.transformers import TransformersLLM
# load the model just once for all the tests in the module
diff --git a/tests/unit/llms/test_anthropic.py b/tests/unit/models/llms/test_anthropic.py
similarity index 95%
rename from tests/unit/llms/test_anthropic.py
rename to tests/unit/models/llms/test_anthropic.py
index 11fee764c3..3051b99789 100644
--- a/tests/unit/llms/test_anthropic.py
+++ b/tests/unit/models/llms/test_anthropic.py
@@ -20,7 +20,7 @@
import nest_asyncio
import pytest
-from distilabel.llms.anthropic import AnthropicLLM
+from distilabel.models.llms.anthropic import AnthropicLLM
from .utils import DummyUserDetail
@@ -120,7 +120,7 @@ async def test_generate(self, mock_anthropic: MagicMock) -> None:
"timeout": 600.0,
"structured_output": None,
"type_info": {
- "module": "distilabel.llms.anthropic",
+ "module": "distilabel.models.llms.anthropic",
"name": "AnthropicLLM",
},
},
@@ -143,7 +143,7 @@ async def test_generate(self, mock_anthropic: MagicMock) -> None:
"max_retries": 1,
},
"type_info": {
- "module": "distilabel.llms.anthropic",
+ "module": "distilabel.models.llms.anthropic",
"name": "AnthropicLLM",
},
},
@@ -167,7 +167,7 @@ def test_serialization(
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.anthropic",
+ "module": "distilabel.models.llms.anthropic",
"name": "AnthropicLLM",
},
}
diff --git a/tests/unit/llms/test_anyscale.py b/tests/unit/models/llms/test_anyscale.py
similarity index 94%
rename from tests/unit/llms/test_anyscale.py
rename to tests/unit/models/llms/test_anyscale.py
index 178419c1b7..d12dbebd02 100644
--- a/tests/unit/llms/test_anyscale.py
+++ b/tests/unit/models/llms/test_anyscale.py
@@ -15,7 +15,7 @@
import os
from unittest import mock
-from distilabel.llms.anyscale import AnyscaleLLM
+from distilabel.models.llms.anyscale import AnyscaleLLM
class TestAnyscaleLLM:
@@ -53,7 +53,7 @@ def test_serialization(self) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.anyscale",
+ "module": "distilabel.models.llms.anyscale",
"name": "AnyscaleLLM",
},
}
diff --git a/tests/unit/llms/test_azure.py b/tests/unit/models/llms/test_azure.py
similarity index 87%
rename from tests/unit/llms/test_azure.py
rename to tests/unit/models/llms/test_azure.py
index eee3ed85fb..a2122b611f 100644
--- a/tests/unit/llms/test_azure.py
+++ b/tests/unit/models/llms/test_azure.py
@@ -18,7 +18,7 @@
import pytest
-from distilabel.llms.azure import AzureOpenAILLM
+from distilabel.models.llms.azure import AzureOpenAILLM
from .utils import DummyUserDetail
@@ -43,7 +43,7 @@ def test_azure_openai_llm(self) -> None:
assert llm.api_version == self.api_version
def test_azure_openai_llm_env_vars(self) -> None:
- from distilabel.llms.azure import (
+ from distilabel.models.llms.azure import (
_AZURE_OPENAI_API_KEY_ENV_VAR_NAME,
_AZURE_OPENAI_ENDPOINT_ENV_VAR_NAME,
)
@@ -78,7 +78,7 @@ def test_azure_openai_llm_env_vars(self) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.azure",
+ "module": "distilabel.models.llms.azure",
"name": "AzureOpenAILLM",
},
},
@@ -105,7 +105,7 @@ def test_azure_openai_llm_env_vars(self) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.azure",
+ "module": "distilabel.models.llms.azure",
"name": "AzureOpenAILLM",
},
},
@@ -122,15 +122,5 @@ def test_serialization(
structured_output=structured_output,
)
- # _dump = {
- # "generation_kwargs": {},
- # "model": "gpt-4",
- # "base_url": "https://example-resource.azure.openai.com/",
- # "max_retries": 6,
- # "timeout": 120,
- # "api_version": "preview",
- # "structured_output": None,
- # "type_info": {"module": "distilabel.llms.azure", "name": "AzureOpenAILLM"},
- # }
assert llm.dump() == dump
assert isinstance(AzureOpenAILLM.from_dict(dump), AzureOpenAILLM)
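These fixture updates all follow from the serialization round trip recording the defining module, which now lives under `distilabel.models.llms.*`; a hedged sketch of that round trip (the model id is a placeholder):

```python
from distilabel.models.llms.openai import OpenAILLM

llm = OpenAILLM(model="gpt-4o-mini")  # placeholder model id
dump = llm.dump()

assert dump["type_info"]["module"] == "distilabel.models.llms.openai"
assert isinstance(OpenAILLM.from_dict(dump), OpenAILLM)
```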
diff --git a/tests/unit/llms/test_base.py b/tests/unit/models/llms/test_base.py
similarity index 100%
rename from tests/unit/llms/test_base.py
rename to tests/unit/models/llms/test_base.py
diff --git a/tests/unit/llms/test_cohere.py b/tests/unit/models/llms/test_cohere.py
similarity index 97%
rename from tests/unit/llms/test_cohere.py
rename to tests/unit/models/llms/test_cohere.py
index 2e398e01cf..4b0a83cbb3 100644
--- a/tests/unit/llms/test_cohere.py
+++ b/tests/unit/models/llms/test_cohere.py
@@ -20,7 +20,7 @@
import nest_asyncio
import pytest
-from distilabel.llms.cohere import CohereLLM
+from distilabel.models.llms.cohere import CohereLLM
from .utils import DummyUserDetail
@@ -145,7 +145,7 @@ async def test_generate(self, mock_async_client: mock.MagicMock) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.cohere",
+ "module": "distilabel.models.llms.cohere",
"name": "CohereLLM",
},
},
@@ -171,7 +171,7 @@ async def test_generate(self, mock_async_client: mock.MagicMock) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.cohere",
+ "module": "distilabel.models.llms.cohere",
"name": "CohereLLM",
},
},
diff --git a/tests/unit/llms/test_groq.py b/tests/unit/models/llms/test_groq.py
similarity index 97%
rename from tests/unit/llms/test_groq.py
rename to tests/unit/models/llms/test_groq.py
index f137750292..ce80c02c8a 100644
--- a/tests/unit/llms/test_groq.py
+++ b/tests/unit/models/llms/test_groq.py
@@ -20,7 +20,7 @@
import nest_asyncio
import pytest
-from distilabel.llms.groq import GroqLLM
+from distilabel.models.llms.groq import GroqLLM
from .utils import DummyUserDetail
@@ -123,7 +123,7 @@ async def test_generate(self, mock_groq: MagicMock) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.groq",
+ "module": "distilabel.models.llms.groq",
"name": "GroqLLM",
},
},
@@ -149,7 +149,7 @@ async def test_generate(self, mock_groq: MagicMock) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.groq",
+ "module": "distilabel.models.llms.groq",
"name": "GroqLLM",
},
},
diff --git a/tests/unit/llms/test_litellm.py b/tests/unit/models/llms/test_litellm.py
similarity index 96%
rename from tests/unit/llms/test_litellm.py
rename to tests/unit/models/llms/test_litellm.py
index 56be99e028..60dfaacbb0 100644
--- a/tests/unit/llms/test_litellm.py
+++ b/tests/unit/models/llms/test_litellm.py
@@ -17,7 +17,7 @@
import nest_asyncio
import pytest
-from distilabel.llms.litellm import LiteLLM
+from distilabel.models.llms.litellm import LiteLLM
@pytest.fixture(params=["mistral/mistral-tiny", "gpt-4"])
@@ -87,7 +87,7 @@ def test_serialization(self, _: MagicMock, model: str) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.litellm",
+ "module": "distilabel.models.llms.litellm",
"name": "LiteLLM",
},
"generation_kwargs": {},
diff --git a/tests/unit/llms/test_llamacpp.py b/tests/unit/models/llms/test_llamacpp.py
similarity index 95%
rename from tests/unit/llms/test_llamacpp.py
rename to tests/unit/models/llms/test_llamacpp.py
index 35c611722d..19cdcd929b 100644
--- a/tests/unit/llms/test_llamacpp.py
+++ b/tests/unit/models/llms/test_llamacpp.py
@@ -18,7 +18,7 @@
import pytest
-from distilabel.llms.llamacpp import LlamaCppLLM
+from distilabel.models.llms.llamacpp import LlamaCppLLM
from .utils import DummyUserDetail
@@ -76,7 +76,7 @@ def test_generate(self, llm: LlamaCppLLM) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.llamacpp",
+ "module": "distilabel.models.llms.llamacpp",
"name": "LlamaCppLLM",
},
"verbose": False,
@@ -103,7 +103,7 @@ def test_generate(self, llm: LlamaCppLLM) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.llamacpp",
+ "module": "distilabel.models.llms.llamacpp",
"name": "LlamaCppLLM",
},
"verbose": False,
diff --git a/tests/unit/llms/test_mistral.py b/tests/unit/models/llms/test_mistral.py
similarity index 96%
rename from tests/unit/llms/test_mistral.py
rename to tests/unit/models/llms/test_mistral.py
index f1b7b4b28f..a0095b3d73 100644
--- a/tests/unit/llms/test_mistral.py
+++ b/tests/unit/models/llms/test_mistral.py
@@ -23,7 +23,7 @@
from .utils import DummyUserDetail
try:
- from distilabel.llms.mistral import MistralLLM
+ from distilabel.models.llms.mistral import MistralLLM
except ImportError:
MistralLLM = None
@@ -132,7 +132,7 @@ async def test_generate(self, mock_mistral: MagicMock) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.mistral",
+ "module": "distilabel.models.llms.mistral",
"name": "MistralLLM",
},
},
@@ -159,7 +159,7 @@ async def test_generate(self, mock_mistral: MagicMock) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.mistral",
+ "module": "distilabel.models.llms.mistral",
"name": "MistralLLM",
},
},
@@ -184,7 +184,7 @@ def test_serialization(
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.mistral",
+ "module": "distilabel.models.llms.mistral",
"name": "MistralLLM",
},
}
diff --git a/tests/unit/llms/test_moa.py b/tests/unit/models/llms/test_moa.py
similarity index 96%
rename from tests/unit/llms/test_moa.py
rename to tests/unit/models/llms/test_moa.py
index 7efd039b7a..b903f5a980 100644
--- a/tests/unit/llms/test_moa.py
+++ b/tests/unit/models/llms/test_moa.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from distilabel.llms.moa import MOA_SYSTEM_PROMPT, MixtureOfAgentsLLM
+from distilabel.models.llms.moa import MOA_SYSTEM_PROMPT, MixtureOfAgentsLLM
from tests.unit.conftest import DummyAsyncLLM
diff --git a/tests/unit/llms/test_ollama.py b/tests/unit/models/llms/test_ollama.py
similarity index 96%
rename from tests/unit/llms/test_ollama.py
rename to tests/unit/models/llms/test_ollama.py
index db31d9cb07..137ea8adf9 100644
--- a/tests/unit/llms/test_ollama.py
+++ b/tests/unit/models/llms/test_ollama.py
@@ -17,7 +17,7 @@
import nest_asyncio
import pytest
-from distilabel.llms.ollama import OllamaLLM
+from distilabel.models.llms.ollama import OllamaLLM
@patch("ollama.AsyncClient")
@@ -86,7 +86,7 @@ def test_serialization(self, _: MagicMock) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.ollama",
+ "module": "distilabel.models.llms.ollama",
"name": "OllamaLLM",
},
}
diff --git a/tests/unit/llms/test_openai.py b/tests/unit/models/llms/test_openai.py
similarity index 98%
rename from tests/unit/llms/test_openai.py
rename to tests/unit/models/llms/test_openai.py
index 03fb94c1d3..30caaa86ad 100644
--- a/tests/unit/llms/test_openai.py
+++ b/tests/unit/models/llms/test_openai.py
@@ -25,7 +25,7 @@
from openai.types import Batch
from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
-from distilabel.llms.openai import _OPENAI_BATCH_API_MAX_FILE_SIZE, OpenAILLM
+from distilabel.models.llms.openai import _OPENAI_BATCH_API_MAX_FILE_SIZE, OpenAILLM
from .utils import DummyUserDetail
@@ -461,7 +461,7 @@ def test_create_jsonl_row(
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.openai",
+ "module": "distilabel.models.llms.openai",
"name": "OpenAILLM",
},
},
@@ -487,7 +487,7 @@ def test_create_jsonl_row(
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.openai",
+ "module": "distilabel.models.llms.openai",
"name": "OpenAILLM",
},
},
diff --git a/tests/unit/llms/test_together.py b/tests/unit/models/llms/test_together.py
similarity index 94%
rename from tests/unit/llms/test_together.py
rename to tests/unit/models/llms/test_together.py
index 409f34866f..88208bf6c6 100644
--- a/tests/unit/llms/test_together.py
+++ b/tests/unit/models/llms/test_together.py
@@ -15,7 +15,7 @@
import os
from unittest import mock
-from distilabel.llms.together import TogetherLLM
+from distilabel.models.llms.together import TogetherLLM
class TestTogetherLLM:
@@ -53,7 +53,7 @@ def test_serialization(self) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.together",
+ "module": "distilabel.models.llms.together",
"name": "TogetherLLM",
},
}
diff --git a/tests/unit/llms/test_vertexai.py b/tests/unit/models/llms/test_vertexai.py
similarity index 97%
rename from tests/unit/llms/test_vertexai.py
rename to tests/unit/models/llms/test_vertexai.py
index 38f5933849..d32f773a3c 100644
--- a/tests/unit/llms/test_vertexai.py
+++ b/tests/unit/models/llms/test_vertexai.py
@@ -22,7 +22,7 @@
Part,
)
-from distilabel.llms.vertexai import VertexAILLM
+from distilabel.models.llms.vertexai import VertexAILLM
@patch("vertexai.generative_models.GenerativeModel.generate_content_async")
@@ -120,7 +120,7 @@ def test_serialization(self, _: MagicMock) -> None:
"offline_batch_generation_block_until_done": None,
"use_offline_batch_generation": False,
"type_info": {
- "module": "distilabel.llms.vertexai",
+ "module": "distilabel.models.llms.vertexai",
"name": "VertexAILLM",
},
}
diff --git a/tests/unit/llms/test_vllm.py b/tests/unit/models/llms/test_vllm.py
similarity index 98%
rename from tests/unit/llms/test_vllm.py
rename to tests/unit/models/llms/test_vllm.py
index c1df505126..07c561af86 100644
--- a/tests/unit/llms/test_vllm.py
+++ b/tests/unit/models/llms/test_vllm.py
@@ -23,8 +23,8 @@
from openai.types.completion_choice import CompletionChoice
from pydantic import BaseModel
-from distilabel.llms import vLLM
-from distilabel.llms.vllm import ClientvLLM, _sort_batches
+from distilabel.models.llms import vLLM
+from distilabel.models.llms.vllm import ClientvLLM, _sort_batches
class Character(BaseModel):
diff --git a/tests/unit/llms/utils.py b/tests/unit/models/llms/utils.py
similarity index 100%
rename from tests/unit/llms/utils.py
rename to tests/unit/models/llms/utils.py
diff --git a/tests/unit/models/mixins/__init__.py b/tests/unit/models/mixins/__init__.py
new file mode 100644
index 0000000000..20ce00bda7
--- /dev/null
+++ b/tests/unit/models/mixins/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/tests/unit/llms/mixins/test_cuda_device_placement.py b/tests/unit/models/mixins/test_cuda_device_placement.py
similarity index 97%
rename from tests/unit/llms/mixins/test_cuda_device_placement.py
rename to tests/unit/models/mixins/test_cuda_device_placement.py
index eb6c178667..bdddabf83e 100644
--- a/tests/unit/llms/mixins/test_cuda_device_placement.py
+++ b/tests/unit/models/mixins/test_cuda_device_placement.py
@@ -19,8 +19,8 @@
import pytest
-from distilabel.llms.base import LLM
-from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
+from distilabel.models.llms.base import LLM
+from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
if TYPE_CHECKING:
from distilabel.steps.tasks.typing import ChatType
diff --git a/tests/unit/llms/mixins/test_magpie.py b/tests/unit/models/mixins/test_magpie.py
similarity index 96%
rename from tests/unit/llms/mixins/test_magpie.py
rename to tests/unit/models/mixins/test_magpie.py
index a470cd1287..9a6f5b2ffa 100644
--- a/tests/unit/llms/mixins/test_magpie.py
+++ b/tests/unit/models/mixins/test_magpie.py
@@ -14,7 +14,7 @@
import pytest
-from distilabel.llms.mixins.magpie import MAGPIE_PRE_QUERY_TEMPLATES
+from distilabel.models.mixins.magpie import MAGPIE_PRE_QUERY_TEMPLATES
from tests.unit.conftest import DummyMagpieLLM
diff --git a/tests/unit/pipeline/test_base.py b/tests/unit/pipeline/test_base.py
index 86db3f5cfb..77faf25d14 100644
--- a/tests/unit/pipeline/test_base.py
+++ b/tests/unit/pipeline/test_base.py
@@ -21,10 +21,12 @@
from unittest import mock
import pytest
+from datasets import Dataset
from fsspec.implementations.local import LocalFileSystem
from pydantic import Field
from upath import UPath
+from distilabel import constants
from distilabel.constants import (
INPUT_QUEUE_ATTR_NAME,
LAST_BATCH_SENT_FLAG,
@@ -126,6 +128,19 @@ def test_context_manager(self) -> None:
assert _GlobalPipelineManager.get_pipeline() is None
+ def test_add_dataset_generator_step(self) -> None:
+ with DummyPipeline() as pipeline:
+ step_1 = DummyStep1()
+
+ dataset = Dataset.from_list(
+ [{"instruction": "Hello"}, {"instruction": "Hello again"}]
+ )
+ pipeline._add_dataset_generator_step(dataset, 123)
+ step = pipeline.dag.get_step("load_data_from_hub_0")[constants.STEP_ATTR_NAME]
+
+ assert step.name in pipeline.dag.get_step_predecessors(step_1.name) # type: ignore
+ assert step.batch_size == 123 # type: ignore
+
@pytest.mark.parametrize("use_cache", [False, True])
def test_load_batch_manager(self, use_cache: bool) -> None:
with tempfile.TemporaryDirectory() as temp_dir:
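The new `test_add_dataset_generator_step` builds its fixture with `datasets.Dataset.from_list`; a small standalone sketch of that fixture (the pipeline wiring itself relies on a private helper and is omitted here):

```python
from datasets import Dataset

dataset = Dataset.from_list(
    [{"instruction": "Hello"}, {"instruction": "Hello again"}]
)

print(len(dataset), dataset.column_names)  # 2 ['instruction']
```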
diff --git a/tests/unit/pipeline/test_ray.py b/tests/unit/pipeline/test_ray.py
index 610f272196..3b4c9f186d 100644
--- a/tests/unit/pipeline/test_ray.py
+++ b/tests/unit/pipeline/test_ray.py
@@ -17,7 +17,7 @@
import pytest
from distilabel.errors import DistilabelUserError
-from distilabel.llms.vllm import vLLM
+from distilabel.models.llms.vllm import vLLM
from distilabel.pipeline.ray import RayPipeline
from distilabel.steps.base import StepResources
from distilabel.steps.tasks.text_generation import TextGeneration
diff --git a/tests/unit/steps/clustering/test_text_clustering.py b/tests/unit/steps/clustering/test_text_clustering.py
index 4b2da96d40..0659da71ec 100644
--- a/tests/unit/steps/clustering/test_text_clustering.py
+++ b/tests/unit/steps/clustering/test_text_clustering.py
@@ -21,7 +21,7 @@
from tests.unit.conftest import DummyAsyncLLM
if TYPE_CHECKING:
- from distilabel.llms.typing import GenerateOutput
+ from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import FormattedInput
diff --git a/tests/unit/steps/embeddings/test_embedding_generation.py b/tests/unit/steps/embeddings/test_embedding_generation.py
index 66284e0ed9..71264b298b 100644
--- a/tests/unit/steps/embeddings/test_embedding_generation.py
+++ b/tests/unit/steps/embeddings/test_embedding_generation.py
@@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from distilabel.embeddings.sentence_transformers import SentenceTransformerEmbeddings
+from distilabel.models.embeddings.sentence_transformers import (
+ SentenceTransformerEmbeddings,
+)
from distilabel.steps.embeddings.embedding_generation import EmbeddingGeneration
diff --git a/tests/unit/steps/tasks/apigen/test_generator.py b/tests/unit/steps/tasks/apigen/test_generator.py
index a290666a60..efe14ff12f 100644
--- a/tests/unit/steps/tasks/apigen/test_generator.py
+++ b/tests/unit/steps/tasks/apigen/test_generator.py
@@ -21,7 +21,7 @@
from tests.unit.conftest import DummyLLM
if TYPE_CHECKING:
- from distilabel.llms.typing import GenerateOutput
+ from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import FormattedInput
import json
diff --git a/tests/unit/steps/tasks/evol_instruct/evol_complexity.py/test_base.py b/tests/unit/steps/tasks/evol_instruct/evol_complexity.py/test_base.py
index 54d7b85d43..282b2987f1 100644
--- a/tests/unit/steps/tasks/evol_instruct/evol_complexity.py/test_base.py
+++ b/tests/unit/steps/tasks/evol_instruct/evol_complexity.py/test_base.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from distilabel.llms.base import LLM
+from distilabel.models.llms.base import LLM
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.evol_instruct.evol_complexity.base import (
EvolComplexity,
diff --git a/tests/unit/steps/tasks/evol_instruct/evol_complexity.py/test_generator.py b/tests/unit/steps/tasks/evol_instruct/evol_complexity.py/test_generator.py
index 60d3a9b1a3..35a6d3b22f 100644
--- a/tests/unit/steps/tasks/evol_instruct/evol_complexity.py/test_generator.py
+++ b/tests/unit/steps/tasks/evol_instruct/evol_complexity.py/test_generator.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from distilabel.llms.base import LLM
+from distilabel.models.llms.base import LLM
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.evol_instruct.evol_complexity.generator import (
EvolComplexityGenerator,
diff --git a/tests/unit/steps/tasks/evol_instruct/test_base.py b/tests/unit/steps/tasks/evol_instruct/test_base.py
index 66f67347b1..053bac0a4f 100644
--- a/tests/unit/steps/tasks/evol_instruct/test_base.py
+++ b/tests/unit/steps/tasks/evol_instruct/test_base.py
@@ -15,7 +15,7 @@
import pytest
from pydantic import ValidationError
-from distilabel.llms.base import LLM
+from distilabel.models.llms.base import LLM
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.evol_instruct.base import (
EvolInstruct,
diff --git a/tests/unit/steps/tasks/evol_instruct/test_generator.py b/tests/unit/steps/tasks/evol_instruct/test_generator.py
index 8f86b94908..e87d09a9ce 100644
--- a/tests/unit/steps/tasks/evol_instruct/test_generator.py
+++ b/tests/unit/steps/tasks/evol_instruct/test_generator.py
@@ -15,7 +15,7 @@
import pytest
from pydantic import ValidationError
-from distilabel.llms.base import LLM
+from distilabel.models.llms.base import LLM
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.evol_instruct.generator import (
EvolInstructGenerator,
diff --git a/tests/unit/steps/tasks/evol_quality/test_base.py b/tests/unit/steps/tasks/evol_quality/test_base.py
index 2ac460afc4..c77df8d8ad 100644
--- a/tests/unit/steps/tasks/evol_quality/test_base.py
+++ b/tests/unit/steps/tasks/evol_quality/test_base.py
@@ -15,7 +15,7 @@
import pytest
from pydantic import ValidationError
-from distilabel.llms.base import LLM
+from distilabel.models.llms.base import LLM
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.evol_quality.base import (
EvolQuality,
diff --git a/tests/unit/steps/tasks/magpie/test_base.py b/tests/unit/steps/tasks/magpie/test_base.py
index cc13681f9f..aac4e504f9 100644
--- a/tests/unit/steps/tasks/magpie/test_base.py
+++ b/tests/unit/steps/tasks/magpie/test_base.py
@@ -18,7 +18,7 @@
import pytest
-from distilabel.llms.openai import OpenAILLM
+from distilabel.models.llms.openai import OpenAILLM
from distilabel.steps.tasks.magpie.base import MAGPIE_MULTI_TURN_SYSTEM_PROMPT, Magpie
from tests.unit.conftest import DummyMagpieLLM
diff --git a/tests/unit/steps/tasks/magpie/test_generator.py b/tests/unit/steps/tasks/magpie/test_generator.py
index d1d1426351..22d22e60a2 100644
--- a/tests/unit/steps/tasks/magpie/test_generator.py
+++ b/tests/unit/steps/tasks/magpie/test_generator.py
@@ -14,7 +14,7 @@
import pytest
-from distilabel.llms.openai import OpenAILLM
+from distilabel.models.llms.openai import OpenAILLM
from distilabel.steps.tasks.magpie.generator import MagpieGenerator
from tests.unit.conftest import DummyMagpieLLM
diff --git a/tests/unit/steps/tasks/structured_outputs/test_outlines.py b/tests/unit/steps/tasks/structured_outputs/test_outlines.py
index d2be053aa5..a535081e65 100644
--- a/tests/unit/steps/tasks/structured_outputs/test_outlines.py
+++ b/tests/unit/steps/tasks/structured_outputs/test_outlines.py
@@ -17,7 +17,7 @@
import pytest
from pydantic import BaseModel
-from distilabel.llms.huggingface.transformers import TransformersLLM
+from distilabel.models.llms.huggingface.transformers import TransformersLLM
from distilabel.steps.tasks.structured_outputs.outlines import (
# StructuredOutputType,
model_to_schema,
@@ -65,7 +65,7 @@ class DummyUserTest(BaseModel):
"use_magpie_template": False,
"disable_cuda_device_placement": False,
"type_info": {
- "module": "distilabel.llms.huggingface.transformers",
+ "module": "distilabel.models.llms.huggingface.transformers",
"name": "TransformersLLM",
},
}
@@ -95,7 +95,7 @@ class DummyUserTest(BaseModel):
"use_magpie_template": False,
"disable_cuda_device_placement": False,
"type_info": {
- "module": "distilabel.llms.huggingface.transformers",
+ "module": "distilabel.models.llms.huggingface.transformers",
"name": "TransformersLLM",
},
}
diff --git a/tests/unit/steps/tasks/test_argilla_labeller.py b/tests/unit/steps/tasks/test_argilla_labeller.py
index 926118dd6c..9418e899a5 100644
--- a/tests/unit/steps/tasks/test_argilla_labeller.py
+++ b/tests/unit/steps/tasks/test_argilla_labeller.py
@@ -28,8 +28,6 @@ def fields() -> Dict[str, Any]:
return [
{
"name": "text",
- "description": "The text of the question",
- "title": "The text of the question",
"settings": {"type": "text"},
}
]
@@ -40,8 +38,6 @@ def questions() -> List[Dict[str, Any]]:
return [
{
"name": "label_selection",
- "description": "The class of the question",
- "title": "Is the question a question?",
"settings": {
"type": "label_selection",
"options": [
@@ -52,8 +48,6 @@ def questions() -> List[Dict[str, Any]]:
},
{
"name": "multi_label_selection",
- "description": "The class of the question",
- "title": "Is the question a question?",
"settings": {
"type": "multi_label_selection",
"options": [
@@ -64,8 +58,6 @@ def questions() -> List[Dict[str, Any]]:
},
{
"name": "rating",
- "description": "The class of the question",
- "title": "Is the question a question?",
"settings": {
"type": "rating",
"options": [
@@ -75,8 +67,6 @@ def questions() -> List[Dict[str, Any]]:
},
{
"name": "text",
- "description": "The class of the question",
- "title": "Is the question a question?",
"settings": {
"type": "text",
},
@@ -141,12 +131,9 @@ def test_format_input(
"record": records[0],
}
)
- assert question["description"] in result[-1]["content"]
- assert question["title"] in result[-1]["content"]
if question["settings"]["type"] in [
"label_selection",
"multi_label_selection",
- "span",
"rating",
]:
assert (
diff --git a/tests/unit/steps/tasks/test_decorator.py b/tests/unit/steps/tasks/test_decorator.py
new file mode 100644
index 0000000000..085153c1f8
--- /dev/null
+++ b/tests/unit/steps/tasks/test_decorator.py
@@ -0,0 +1,200 @@
+# Copyright 2023-present, Argilla, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Union
+
+import pytest
+
+from distilabel.errors import DistilabelUserError
+from distilabel.steps.tasks.decorator import task
+from tests.unit.conftest import DummyLLM
+
+
+class TestTaskDecorator:
+ def test_decorator_raise_if_no_docstring(self) -> None:
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"When using the `task` decorator, including a docstring in the formatting function is mandatory",
+ ):
+
+ @task(inputs=["instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ return {"response": output}
+
+ def test_decorator_raise_if_docstring_invalid(self) -> None:
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"Formatting function decorated with `task` doesn't follow the expected format.",
+ ):
+
+ @task(inputs=["instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """This is not valid"""
+ return {"response": output}
+
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"Formatting function decorated with `task` doesn't follow the expected format.",
+ ):
+
+ @task(inputs=["instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """
+ ---
+ - this
+ - is
+ - a
+ - list
+ ---
+ """
+ return {"response": output}
+
+ def test_decorator_raise_if_no_system_prompt_or_user_message_template(self) -> None:
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"The formatting function decorated with `task` must include both the `system_prompt` and `user_message_template` keys in the docstring",
+ ):
+
+ @task(inputs=["instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """
+ ---
+ system_prompt: prompt
+ ---
+ """
+ return {"response": output}
+
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"The formatting function decorated with `task` must include both the `system_prompt` and `user_message_template` keys in the docstring",
+ ):
+
+ @task(inputs=["instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """
+ ---
+ user_message_template: prompt
+ ---
+ """
+ return {"response": output}
+
+ def test_decorator_raise_if_template_invalid_placeholders(self) -> None:
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"The formatting function decorated with `task` includes invalid placeholders in the extracted `system_prompt`",
+ ):
+
+ @task(inputs=["instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """
+ ---
+ system_prompt: |
+ You are an AI assistant designed to {task}
+
+ user_message_template: |
+ {instruction}
+ ---
+ """
+ return {"response": output}
+
+ with pytest.raises(
+ DistilabelUserError,
+ match=r"The formatting function decorated with `task` includes invalid placeholders in the extracted `user_message_template`",
+ ):
+
+ @task(inputs=["task"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """
+ ---
+ system_prompt: |
+ You are an AI assistant designed to {task}
+
+ user_message_template: |
+ {instruction}
+ ---
+ """
+ return {"response": output}
+
+ def test_decorator_task(self) -> None:
+ @task(inputs=["task", "instruction"], outputs=["response"])
+ def MyTask(
+ output: Union[str, None], input: Union[Dict[str, Any], None] = None
+ ) -> Dict[str, Any]:
+ """
+ `MyTask` is a simple `Task` for bla bla bla
+
+ ---
+ system_prompt: |
+ You are an AI assistant designed to {task}
+
+ user_message_template: |
+ Text: {instruction}
+ ---
+ """
+ return {"response": output}
+
+ my_task = MyTask(llm=DummyLLM())
+
+ my_task.load()
+
+ assert my_task.inputs == ["task", "instruction"]
+ assert my_task.outputs == ["response"]
+ assert my_task.format_input(
+ {"task": "summarize", "instruction": "The cell..."}
+ ) == [
+ {
+ "role": "system",
+ "content": "You are an AI assistant designed to summarize",
+ },
+ {"role": "user", "content": "Text: The cell..."},
+ ]
+ assert next(
+ my_task.process_applying_mappings(
+ [{"task": "summarize", "instruction": "The cell..."}]
+ )
+ ) == [
+ {
+ "task": "summarize",
+ "instruction": "The cell...",
+ "response": "output",
+ "model_name": "test",
+ "distilabel_metadata": {
+ "raw_input_my_task_0": [
+ {
+ "content": "You are an AI assistant designed to summarize",
+ "role": "system",
+ },
+ {
+ "content": "Text: The cell...",
+ "role": "user",
+ },
+ ],
+ "raw_output_my_task_0": "output",
+ },
+ }
+ ]
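The passing case in the new decorator tests doubles as a usage recipe: a short description, then a YAML block between `---` markers defining `system_prompt` and `user_message_template`, whose placeholders must match the declared inputs. A sketch mirroring that case (the task and argument names are illustrative):

```python
from typing import Any, Dict, Union

from distilabel.steps.tasks.decorator import task


@task(inputs=["task", "instruction"], outputs=["response"])
def SummarizeTask(
    output: Union[str, None], input: Union[Dict[str, Any], None] = None
) -> Dict[str, Any]:
    """
    `SummarizeTask` is a minimal example task.

    ---
    system_prompt: |
        You are an AI assistant designed to {task}

    user_message_template: |
        Text: {instruction}
    ---
    """
    return {"response": output}
```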
diff --git a/tests/unit/steps/tasks/test_generate_embeddings.py b/tests/unit/steps/tasks/test_generate_embeddings.py
index 4cf62f21c8..6318f323db 100644
--- a/tests/unit/steps/tasks/test_generate_embeddings.py
+++ b/tests/unit/steps/tasks/test_generate_embeddings.py
@@ -16,7 +16,7 @@
import pytest
-from distilabel.llms.huggingface.transformers import TransformersLLM
+from distilabel.models.llms.huggingface.transformers import TransformersLLM
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.generate_embeddings import GenerateEmbeddings
diff --git a/tests/unit/steps/tasks/test_improving_text_embeddings.py b/tests/unit/steps/tasks/test_improving_text_embeddings.py
index dfaa247b91..0a153034e9 100644
--- a/tests/unit/steps/tasks/test_improving_text_embeddings.py
+++ b/tests/unit/steps/tasks/test_improving_text_embeddings.py
@@ -17,8 +17,8 @@
import pytest
-from distilabel.llms import LLM
-from distilabel.llms.typing import GenerateOutput
+from distilabel.models.llms.base import LLM
+from distilabel.models.llms.typing import GenerateOutput
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.improving_text_embeddings import (
BitextRetrievalGenerator,
diff --git a/tests/unit/steps/tasks/test_instruction_backtranslation.py b/tests/unit/steps/tasks/test_instruction_backtranslation.py
index 1b2f9adffa..405195ef02 100644
--- a/tests/unit/steps/tasks/test_instruction_backtranslation.py
+++ b/tests/unit/steps/tasks/test_instruction_backtranslation.py
@@ -14,8 +14,8 @@
from typing import Any, List
-from distilabel.llms.base import LLM
-from distilabel.llms.typing import GenerateOutput
+from distilabel.models.llms.base import LLM
+from distilabel.models.llms.typing import GenerateOutput
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.instruction_backtranslation import (
InstructionBacktranslation,
diff --git a/tests/unit/steps/tasks/test_sentence_transformers.py b/tests/unit/steps/tasks/test_sentence_transformers.py
index 9dc6b38ae1..8df92e903d 100644
--- a/tests/unit/steps/tasks/test_sentence_transformers.py
+++ b/tests/unit/steps/tasks/test_sentence_transformers.py
@@ -26,27 +26,6 @@
)
from tests.unit.conftest import DummyAsyncLLM
-# from distilabel.llms.base import LLM, AsyncLLM
-
-# if TYPE_CHECKING:
-# from distilabel.llms.typing import GenerateOutput
-# from distilabel.steps.tasks.typing import FormattedInput
-
-# # Defined here too, so that the serde still works
-# class DummyStructuredLLM(LLM):
-# structured_output: Any = None
-# def load(self) -> None:
-# pass
-
-# @property
-# def model_name(self) -> str:
-# return "test"
-
-# def generate(
-# self, input: "FormattedInput", num_generations: int = 1
-# ) -> "GenerateOutput":
-# return ['{ \n "negative": "negative",\n "positive": "positive"\n}' for _ in range(num_generations)]
-
class TestGenerateSentencePair:
@pytest.mark.parametrize(
diff --git a/tests/unit/steps/tasks/test_structured_generation.py b/tests/unit/steps/tasks/test_structured_generation.py
index a57d0da7df..82b86ee93d 100644
--- a/tests/unit/steps/tasks/test_structured_generation.py
+++ b/tests/unit/steps/tasks/test_structured_generation.py
@@ -17,8 +17,8 @@
from typing_extensions import override
-from distilabel.llms.base import LLM
-from distilabel.llms.typing import GenerateOutput
+from distilabel.models.llms.base import LLM
+from distilabel.models.llms.typing import GenerateOutput
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.structured_generation import StructuredGeneration
from distilabel.steps.tasks.typing import StructuredInput
diff --git a/tests/unit/steps/tasks/test_text_classification.py b/tests/unit/steps/tasks/test_text_classification.py
index e5af171b33..d9c36f58a5 100644
--- a/tests/unit/steps/tasks/test_text_classification.py
+++ b/tests/unit/steps/tasks/test_text_classification.py
@@ -21,7 +21,7 @@
from tests.unit.conftest import DummyAsyncLLM
if TYPE_CHECKING:
- from distilabel.llms.typing import GenerateOutput
+ from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import FormattedInput
diff --git a/tests/unit/steps/tasks/test_ultrafeedback.py b/tests/unit/steps/tasks/test_ultrafeedback.py
index 5565065d61..46ed061838 100644
--- a/tests/unit/steps/tasks/test_ultrafeedback.py
+++ b/tests/unit/steps/tasks/test_ultrafeedback.py
@@ -16,8 +16,8 @@
import pytest
-from distilabel.llms.base import LLM
-from distilabel.llms.typing import GenerateOutput
+from distilabel.models.llms.base import LLM
+from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import ChatType
from distilabel.steps.tasks.ultrafeedback import UltraFeedback
diff --git a/tests/unit/test_imports.py b/tests/unit/test_imports.py
index bcede6a03e..a836cceb15 100644
--- a/tests/unit/test_imports.py
+++ b/tests/unit/test_imports.py
@@ -15,7 +15,7 @@
def test_imports() -> None:
# ruff: noqa
- from distilabel.llms import (
+ from distilabel.models.llms import (
AnthropicLLM,
AnyscaleLLM,
AsyncLLM,