Skip to content

Commit

Permalink
Refactor distilabel to import from distilabel.models instead of disti…
Browse files Browse the repository at this point in the history
…label.llms
  • Loading branch information
plaguss committed Oct 24, 2024
1 parent 5b519ac commit 470bf97
Show file tree
Hide file tree
Showing 33 changed files with 70 additions and 70 deletions.
2 changes: 1 addition & 1 deletion src/distilabel/pipeline/ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from distilabel.constants import INPUT_QUEUE_ATTR_NAME, STEP_ATTR_NAME
from distilabel.distiset import create_distiset
from distilabel.errors import DistilabelUserError
from distilabel.llms.vllm import vLLM
from distilabel.models.llms.vllm import vLLM
from distilabel.pipeline.base import BasePipeline, set_pipeline_running_env_variables
from distilabel.pipeline.step_wrapper import _StepWrapper
from distilabel.utils.logging import setup_logging, stop_logging
Expand Down
4 changes: 2 additions & 2 deletions src/distilabel/pipeline/routing_batch_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def routing_batch_function(
Example:
```python
from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline, routing_batch_function
from distilabel.steps import LoadDataFromHub, GroupColumns
Expand Down Expand Up @@ -337,7 +337,7 @@ def sample_n_steps(n: int) -> RoutingBatchFunction:
Example:
```python
from distilabel.llms import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.models import MistralLLM, OpenAILLM, VertexAILLM
from distilabel.pipeline import Pipeline, sample_n_steps
from distilabel.steps import LoadDataFromHub, GroupColumns
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/pipeline/step_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from distilabel.constants import LAST_BATCH_SENT_FLAG
from distilabel.errors import DISTILABEL_DOCS_URL
from distilabel.exceptions import DistilabelOfflineBatchGenerationNotFinishedException
from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.pipeline.batch import _Batch
from distilabel.pipeline.typing import StepLoadStatus
from distilabel.steps.base import GeneratorStep, Step, _Step
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/steps/clustering/text_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class TextClustering(TextClassification, GlobalTask):
Generate labels for a set of texts using clustering:
```python
from distilabel.llms import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
from distilabel.steps import UMAP, DBSCAN, TextClustering
from distilabel.pipeline import Pipeline
Expand Down
4 changes: 2 additions & 2 deletions src/distilabel/steps/embeddings/embedding_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from typing import TYPE_CHECKING

from distilabel.embeddings.base import Embeddings
from distilabel.models.embeddings.base import Embeddings
from distilabel.steps.base import Step, StepInput

if TYPE_CHECKING:
Expand Down Expand Up @@ -43,7 +43,7 @@ class EmbeddingGeneration(Step):
Generate sentence embeddings with Sentence Transformers:
```python
from distilabel.embeddings import SentenceTransformerEmbeddings
from distilabel.models import SentenceTransformerEmbeddings
from distilabel.steps import EmbeddingGeneration
embedding_generation = EmbeddingGeneration(
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/steps/embeddings/nearest_neighbour.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class FaissNearestNeighbour(GlobalStep):
Generating embeddings and getting the nearest neighbours:
```python
from distilabel.embeddings.sentence_transformers import SentenceTransformerEmbeddings
from distilabel.models import SentenceTransformerEmbeddings
from distilabel.pipeline import Pipeline
from distilabel.steps import EmbeddingGeneration, FaissNearestNeighbour, LoadDataFromHub
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/steps/reward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from pydantic import Field, PrivateAttr, SecretStr

from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin
from distilabel.steps.base import Step, StepInput
from distilabel.utils.huggingface import HF_TOKEN_ENV_VAR

Expand Down
4 changes: 2 additions & 2 deletions src/distilabel/steps/tasks/apigen/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class APIGenGenerator(Task):
```python
from distilabel.steps.tasks import ApiGenGenerator
from distilabel.llms import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
Expand Down Expand Up @@ -138,7 +138,7 @@ class APIGenGenerator(Task):
```python
from distilabel.steps.tasks import ApiGenGenerator
from distilabel.llms import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
Expand Down
4 changes: 2 additions & 2 deletions src/distilabel/steps/tasks/apigen/semantic_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class APIGenSemanticChecker(Task):
```python
from distilabel.steps.tasks import APIGenSemanticChecker
from distilabel.llms import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
Expand Down Expand Up @@ -125,7 +125,7 @@ class APIGenSemanticChecker(Task):
```python
from distilabel.steps.tasks import APIGenSemanticChecker
from distilabel.llms import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
Expand Down
6 changes: 3 additions & 3 deletions src/distilabel/steps/tasks/argilla_labeller.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class ArgillaLabeller(Task):
import argilla as rg
from argilla import Suggestion
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
# Get information from Argilla dataset definition
dataset = rg.Dataset("my_dataset")
Expand Down Expand Up @@ -138,7 +138,7 @@ class ArgillaLabeller(Task):
```python
import argilla as rg
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
# Get information from Argilla dataset definition
dataset = rg.Dataset("my_dataset")
Expand Down Expand Up @@ -186,7 +186,7 @@ class ArgillaLabeller(Task):
```python
import argilla as rg
from distilabel.steps.tasks import ArgillaLabeller
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
# Overwrite default prompts and instructions
labeller = ArgillaLabeller(
Expand Down
10 changes: 5 additions & 5 deletions src/distilabel/steps/tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from distilabel.constants import DISTILABEL_METADATA_KEY
from distilabel.errors import DistilabelUserError
from distilabel.llms.base import LLM
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.llms.base import LLM
from distilabel.steps.base import (
GeneratorStep,
GlobalStep,
Expand All @@ -33,7 +33,7 @@
from distilabel.utils.dicts import group_dicts

if TYPE_CHECKING:
from distilabel.llms.typing import GenerateOutput
from distilabel.models.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import ChatType, FormattedInput
from distilabel.steps.typing import StepOutput

Expand Down Expand Up @@ -245,8 +245,8 @@ def _set_default_structured_output(self) -> None:
if self.use_default_structured_output and not self.llm.structured_output:
# In case the default structured output is required, we have to set it before
# the LLM is loaded
from distilabel.llms import InferenceEndpointsLLM
from distilabel.llms.base import AsyncLLM
from distilabel.models.llms import InferenceEndpointsLLM
from distilabel.models.llms.base import AsyncLLM

def check_dependency(module_name: str) -> None:
if not importlib.util.find_spec(module_name):
Expand Down Expand Up @@ -301,7 +301,7 @@ def print(self, sample_input: Optional["ChatType"] = None) -> None:
```python
from distilabel.steps.tasks import URIAL
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models.llms.huggingface import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
urial = URIAL(
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/steps/tasks/clair.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class CLAIR(Task):
```python
from distilabel.steps.tasks import CLAIR
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
llm=InferenceEndpointsLLM(
model_id="meta-llama/Meta-Llama-3.1-70B-Instruct",
Expand Down
4 changes: 2 additions & 2 deletions src/distilabel/steps/tasks/complexity_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class ComplexityScorer(Task):
```python
from distilabel.steps.tasks import ComplexityScorer
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
scorer = ComplexityScorer(
Expand All @@ -91,7 +91,7 @@ class ComplexityScorer(Task):
```python
from distilabel.steps.tasks import ComplexityScorer
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
scorer = ComplexityScorer(
Expand Down
6 changes: 3 additions & 3 deletions src/distilabel/steps/tasks/evol_instruct/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class EvolInstruct(Task):
```python
from distilabel.steps.tasks import EvolInstruct
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_instruct = EvolInstruct(
Expand All @@ -96,7 +96,7 @@ class EvolInstruct(Task):
```python
from distilabel.steps.tasks import EvolInstruct
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_instruct = EvolInstruct(
Expand Down Expand Up @@ -124,7 +124,7 @@ class EvolInstruct(Task):
```python
from distilabel.steps.tasks import EvolInstruct
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_instruct = EvolInstruct(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class EvolComplexity(EvolInstruct):
```python
from distilabel.steps.tasks import EvolComplexity
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_complexity = EvolComplexity(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class EvolComplexityGenerator(EvolInstructGenerator):
```python
from distilabel.steps.tasks import EvolComplexityGenerator
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_complexity_generator = EvolComplexityGenerator(
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/steps/tasks/evol_instruct/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class EvolInstructGenerator(GeneratorTask):
```python
from distilabel.steps.tasks import EvolInstructGenerator
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_instruct_generator = EvolInstructGenerator(
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/steps/tasks/evol_quality/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class EvolQuality(Task):
```python
from distilabel.steps.tasks import EvolQuality
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
evol_quality = EvolQuality(
Expand Down
4 changes: 2 additions & 2 deletions src/distilabel/steps/tasks/generate_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import TYPE_CHECKING, Any, Dict

from distilabel.errors import DistilabelUserError
from distilabel.llms.base import LLM
from distilabel.models.llms.base import LLM
from distilabel.steps.base import Step, StepInput
from distilabel.utils.chat import is_openai_format

Expand Down Expand Up @@ -54,7 +54,7 @@ class GenerateEmbeddings(Step):
```python
from distilabel.steps.tasks import GenerateEmbeddings
from distilabel.llms.huggingface import TransformersLLM
from distilabel.models.llms.huggingface import TransformersLLM
# Consider this as a placeholder for your actual LLM.
embedder = GenerateEmbeddings(
Expand Down
2 changes: 1 addition & 1 deletion src/distilabel/steps/tasks/genstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class Genstruct(Task):
```python
from distilabel.steps.tasks import Genstruct
from distilabel.llms.huggingface import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
# Consider this as a placeholder for your actual LLM.
genstruct = Genstruct(
Expand Down
8 changes: 4 additions & 4 deletions src/distilabel/steps/tasks/magpie/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@
from pydantic import Field, PositiveInt, field_validator

from distilabel.errors import DistilabelUserError
from distilabel.llms.base import LLM
from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
from distilabel.mixins.runtime_parameters import (
RuntimeParameter,
RuntimeParametersMixin,
)
from distilabel.models.llms.base import LLM
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.base import StepInput
from distilabel.steps.tasks.base import Task

Expand Down Expand Up @@ -404,7 +404,7 @@ class Magpie(Task, MagpieBase):
Generating instructions with Llama 3 8B Instruct and TransformersLLM:
```python
from distilabel.llms import TransformersLLM
from distilabel.models import TransformersLLM
from distilabel.steps.tasks import Magpie
magpie = Magpie(
Expand Down Expand Up @@ -443,7 +443,7 @@ class Magpie(Task, MagpieBase):
Generating conversations with Llama 3 8B Instruct and TransformersLLM:
```python
from distilabel.llms import TransformersLLM
from distilabel.models import TransformersLLM
from distilabel.steps.tasks import Magpie
magpie = Magpie(
Expand Down
8 changes: 4 additions & 4 deletions src/distilabel/steps/tasks/magpie/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from typing_extensions import override

from distilabel.errors import DistilabelUserError
from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.mixins.magpie import MagpieChatTemplateMixin
from distilabel.steps.tasks.base import GeneratorTask
from distilabel.steps.tasks.magpie.base import MagpieBase

Expand Down Expand Up @@ -98,7 +98,7 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
Generating instructions with Llama 3 8B Instruct and TransformersLLM:
```python
from distilabel.llms import TransformersLLM
from distilabel.models import TransformersLLM
from distilabel.steps.tasks import MagpieGenerator
generator = MagpieGenerator(
Expand Down Expand Up @@ -130,7 +130,7 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
Generating a conversation with Llama 3 8B Instruct and TransformersLLM:
```python
from distilabel.llms import TransformersLLM
from distilabel.models import TransformersLLM
from distilabel.steps.tasks import MagpieGenerator
generator = MagpieGenerator(
Expand Down Expand Up @@ -210,7 +210,7 @@ class MagpieGenerator(GeneratorTask, MagpieBase):
Generating with system prompts with probabilities:
```python
from distilabel.llms import InferenceEndpointsLLM
from distilabel.models import InferenceEndpointsLLM
from distilabel.steps.tasks import MagpieGenerator
magpie = MagpieGenerator(
Expand Down
Loading

0 comments on commit 470bf97

Please sign in to comment.