Skip to content

Commit

Permalink
Fix empty load stage when two GlobalSteps are chained (#945)
Browse files: browse the repository at this point in the history
  • Loading branch information
gabrielmbmb authored Sep 3, 2024
1 parent 56b4036 commit ebd2bb7
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/distilabel/pipeline/_dag.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -294,13 +294,19 @@ def _get_stage_last_steps(stage_steps: List[str]) -> List[str]:
current_stage = []
stages_last_steps = []

for step_name in nx.topological_sort(self.G):
steps_sorted = list(nx.topological_sort(self.G))
for i, step_name in enumerate(steps_sorted):
step: "_Step" = self.get_step(step_name)[STEP_ATTR_NAME]
if not step.is_global:
current_stage.append(step_name)
else:
stages.append(current_stage)
stages_last_steps.append(_get_stage_last_steps(current_stage))
previous_step = None
if i > 0:
previous_step_name = steps_sorted[i - 1]
previous_step = self.get_step(previous_step_name)[STEP_ATTR_NAME]
if not previous_step or not previous_step.is_global:
stages.append(current_stage)
stages_last_steps.append(_get_stage_last_steps(current_stage))
stages.append([step_name])
stages_last_steps.append([step_name])
current_stage = []
Expand Down
31 changes: 31 additions & 0 deletions tests/unit/pipeline/test_dag.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -276,6 +276,37 @@ def test_get_steps_load_stages(self) -> None:
],
)

def test_get_steps_load_stages_global_steps_chained(self) -> None:
    """Check the load stages when one global step feeds directly into another.

    The pipeline is a generator fanning out to three ``DummyStep1`` steps,
    followed by two ``DummyGlobalStep`` steps chained back to back. Each
    global step is expected to get a stage of its own, with no empty stage
    between ``global_0`` and ``global_1``.
    """
    with Pipeline(name="dummy") as pipeline:
        source = DummyGeneratorStep(name="dummy_generator_step")
        workers = [DummyStep1(name=f"dummy_step_0_{idx}") for idx in range(3)]
        first_global = DummyGlobalStep(name="global_0")
        second_global = DummyGlobalStep(name="global_1")

        source >> workers >> first_global >> second_global

    worker_names = ["dummy_step_0_0", "dummy_step_0_1", "dummy_step_0_2"]

    # First element of the result: the steps that make up each stage.
    expected_stages = [
        ["dummy_generator_step", *worker_names],
        ["global_0"],
        ["global_1"],
    ]
    # Second element of the result: the last steps of each stage.
    expected_last_steps = [
        worker_names,
        ["global_0"],
        ["global_1"],
    ]

    assert pipeline.dag.get_steps_load_stages() == (
        expected_stages,
        expected_last_steps,
    )

def test_get_steps_load_stages_simple(self) -> None:
with Pipeline(name="dummy") as pipeline:
generator = DummyGeneratorStep(name="dummy_generator_step")
Expand Down

0 comments on commit ebd2bb7

Please sign in to comment.