Skip to content

Commit

Permalink
Generate deterministic pipeline name when it's not given (#878)
Browse files Browse the repository at this point in the history
* Generate deterministic pipeline name when it's not given

* Use the names of the steps to generate the default pipeline name

* Update test with the steps names

* Add suggestion from code review
  • Loading branch information
plaguss authored Aug 22, 2024
1 parent 6576d1a commit fc5d070
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 29 deletions.
17 changes: 15 additions & 2 deletions src/distilabel/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import signal
import threading
import time
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import (
Expand Down Expand Up @@ -131,6 +130,7 @@ def get_pipeline(cls) -> Union["BasePipeline", None]:
_STEP_NOT_LOADED_CODE = -999

_ATTRIBUTES_IGNORED_CACHE = ("disable_cuda_device_placement",)
_PIPELINE_DEFAULT_NAME = "__default_pipeline_name__"


class BasePipeline(ABC, RequirementsMixin, _Serializable):
Expand Down Expand Up @@ -189,7 +189,7 @@ def __init__(
Defaults to `None`, but can be helpful to inform in a pipeline to be shared
that this requirements must be installed.
"""
self.name = name or f"pipeline_{str(uuid.uuid4())[:8]}"
self.name = name or _PIPELINE_DEFAULT_NAME
self.description = description
self._enable_metadata = enable_metadata
self.dag = DAG()
Expand Down Expand Up @@ -235,6 +235,12 @@ def __enter__(self) -> Self:
def __exit__(self, exc_type, exc_value, traceback) -> None:
"""Unset the global pipeline instance when exiting a pipeline context."""
_GlobalPipelineManager.set_pipeline(None)
self._set_pipeline_name()

def _set_pipeline_name(self) -> None:
"""Creates a name for the pipeline if it's the default one (if hasn't been set)."""
if self.name == _PIPELINE_DEFAULT_NAME:
self.name = f"pipeline_{'_'.join(self.dag)}"

def _create_signature(self) -> str:
"""Makes a signature (hash) of a pipeline, using the step ids and the adjacency between them.
Expand Down Expand Up @@ -351,6 +357,13 @@ def run(
log_queue=self._log_queue, filename=str(self._cache_location["log_file"])
)

# Set the name of the pipeline if it's the default one. This should be called
# if the pipeline is defined within the context manager, and the run is called
# outside of it. Is here in the following case:
# with Pipeline() as pipeline:
# pipeline.run()
self._set_pipeline_name()

# Validate the pipeline DAG to check that all the steps are chainable, there are
# no missing runtime parameters, batch sizes are correct, etc.
self.dag.validate()
Expand Down
58 changes: 31 additions & 27 deletions tests/unit/pipeline/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,28 +105,29 @@ def test_context_manager(self) -> None:

@pytest.mark.parametrize("use_cache", [False, True])
def test_load_batch_manager(self, use_cache: bool) -> None:
pipeline = DummyPipeline(name="unit-test-pipeline")
pipeline._load_batch_manager(use_cache=True)
pipeline._cache()

with (
mock.patch(
"distilabel.pipeline.base._BatchManager.load_from_cache"
) as mock_load_from_cache,
mock.patch(
"distilabel.pipeline.base._BatchManager.from_dag"
) as mock_from_dag,
):
pipeline._load_batch_manager(use_cache=use_cache)

if use_cache:
mock_load_from_cache.assert_called_once_with(
pipeline._cache_location["batch_manager"]
)
mock_from_dag.assert_not_called()
else:
mock_load_from_cache.assert_not_called()
mock_from_dag.assert_called_once_with(pipeline.dag)
with tempfile.TemporaryDirectory() as temp_dir:
pipeline = DummyPipeline(name="unit-test-pipeline", cache_dir=temp_dir)
pipeline._load_batch_manager(use_cache=True)
pipeline._cache()

with (
mock.patch(
"distilabel.pipeline.base._BatchManager.load_from_cache"
) as mock_load_from_cache,
mock.patch(
"distilabel.pipeline.base._BatchManager.from_dag"
) as mock_from_dag,
):
pipeline._load_batch_manager(use_cache=use_cache)

if use_cache:
mock_load_from_cache.assert_called_once_with(
pipeline._cache_location["batch_manager"]
)
mock_from_dag.assert_not_called()
else:
mock_load_from_cache.assert_not_called()
mock_from_dag.assert_called_once_with(pipeline.dag)

def test_setup_write_buffer(self) -> None:
pipeline = DummyPipeline(name="unit-test-pipeline")
Expand Down Expand Up @@ -1187,13 +1188,16 @@ def test_pipeline_with_dataset_and_generator_step(self):
)

def test_optional_name(self):
import random
from distilabel.pipeline.base import _PIPELINE_DEFAULT_NAME

assert DummyPipeline().name == _PIPELINE_DEFAULT_NAME

random.seed(42)
with DummyPipeline() as pipeline:
name = pipeline.name
assert name.startswith("pipeline")
assert len(name.split("_")[-1]) == 8
gen_step = DummyGeneratorStep()
step1_0 = DummyStep1()
gen_step >> step1_0

assert pipeline.name == "pipeline_dummy_generator_step_0_dummy_step1_0"


class TestPipelineSerialization:
Expand Down

0 comments on commit fc5d070

Please sign in to comment.