Skip to content

Commit

Permalink
Send as many batches as possible to input queues (#895)
Browse files Browse the repository at this point in the history
* Update `_manage_batch_flow` to send as many batches as can be built

* Fix load stages

* Fix unit test

* Fix `argilla` unit test after release `2.0.1`

* Mark `test_multiple_replicas` as `xfail` since it can fail intermittently
  • Loading branch information
gabrielmbmb authored Aug 13, 2024
1 parent 75baf64 commit 7ff4d20
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 30 deletions.
43 changes: 28 additions & 15 deletions src/distilabel/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,9 @@ def _should_continue_processing(self) -> bool:
"""
return self._batch_manager.can_generate() and not self._stop_called # type: ignore

def _process_batch(self, batch: "_Batch") -> None:
def _process_batch(
self, batch: "_Batch", send_last_batch_flag: bool = True
) -> None:
"""Process a batch consumed from the `output_queue`.
Args:
Expand All @@ -798,18 +800,28 @@ def _process_batch(self, batch: "_Batch") -> None:
self._write_buffer.add_batch(batch) # type: ignore

if batch.last_batch:
_, stages_last_steps = self.dag.get_steps_load_stages()
stage_last_steps = stages_last_steps[self._current_stage]
if batch.step_name in stage_last_steps:
self._stages_last_batch[self._current_stage].append(batch.step_name)
self._stages_last_batch[self._current_stage].sort()
self._register_stages_last_batch(batch)

# Make sure to send the `LAST_BATCH_SENT_FLAG` to the predecessors of the step
# if the batch is the last one, so they stop their processing loop even if they
# haven't received the last batch because of the routing function.
for step_name in self.dag.get_step_predecessors(batch.step_name):
if self._is_step_running(step_name):
self._send_last_batch_flag_to_step(step_name)
if send_last_batch_flag:
for step_name in self.dag.get_step_predecessors(batch.step_name):
if self._is_step_running(step_name):
self._send_last_batch_flag_to_step(step_name)

def _register_stages_last_batch(self, batch: "_Batch") -> None:
    """Registers the last batch received from a step in the `_stages_last_batch`
    dictionary.

    Args:
        batch: The last batch received from a step.
    """
    # Second element of the tuple maps each stage to the steps that end it.
    last_steps_per_stage = self.dag.get_steps_load_stages()[1]
    if batch.step_name in last_steps_per_stage[self._current_stage]:
        received = self._stages_last_batch[self._current_stage]
        received.append(batch.step_name)
        # Keep the list sorted so it can be compared against the stage's
        # expected step names regardless of arrival order.
        received.sort()

def _update_stage(self) -> bool:
"""Checks if the steps of next stage should be loaded and updates `_current_stage`
Expand Down Expand Up @@ -979,6 +991,9 @@ def _handle_stop(self) -> None:

self._consume_output_queue()

if self._should_load_next_stage():
self._current_stage += 1

def _wait_step_input_queue_empty(self, step_name: str) -> Union["Queue[Any]", None]:
"""Waits for the input queue of a step to be empty.
Expand Down Expand Up @@ -1101,10 +1116,7 @@ def _consume_output_queue(self) -> None:
batch = self._output_queue.get()
if batch is None:
continue

if batch.step_name in self.dag.leaf_steps:
self._write_buffer.add_batch(batch) # type: ignore

self._process_batch(batch, send_last_batch_flag=False)
self._handle_batch_on_stop(batch)

def _manage_batch_flow(self, batch: "_Batch") -> None:
Expand Down Expand Up @@ -1153,13 +1165,14 @@ def _manage_batch_flow(self, batch: "_Batch") -> None:

# If successor step has enough data in its buffer to create a new batch, then
# send the batch to the step.
if new_batch := self._batch_manager.get_batch(successor):
while new_batch := self._batch_manager.get_batch(successor):
self._send_batch_to_step(new_batch)

if not step.is_generator:
# Step ("this", the one from which the batch was received) has enough data on its
# buffers to create a new batch
if new_batch := self._batch_manager.get_batch(step.name): # type: ignore
while new_batch := self._batch_manager.get_batch(step.name): # type: ignore
# if new_batch := self._batch_manager.get_batch(step.name): # type: ignore
self._send_batch_to_step(new_batch)
else:
self._request_more_batches_if_needed(step)
Expand Down
10 changes: 4 additions & 6 deletions tests/integration/test_load_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,12 @@ def test_load_stages_status_load_from_cache() -> None:

original_process_batch = pipeline._process_batch

def _process_batch_wrapper(batch: "_Batch") -> None:
def _process_batch_wrapper(
    batch: "_Batch", send_last_batch_flag: bool = True
) -> None:
    # Simulate a CTRL + C once the chosen batch arrives, then delegate to the
    # real `_process_batch` implementation.
    hit_target = batch.step_name == group_1.name and batch.seq_no == 10
    if hit_target:
        pipeline._stop_called = True
    original_process_batch(batch, send_last_batch_flag)

# Run first time and stop the pipeline when specific batch received (simulate CTRL + C)
with mock.patch.object(pipeline, "_process_batch", _process_batch_wrapper):
Expand All @@ -167,7 +169,3 @@ def _process_batch_wrapper(batch: "_Batch") -> None:
distiset = pipeline.run(use_cache=True)

assert len(distiset["default"]["train"]) == 1000


if __name__ == "__main__":
test_load_stages_status_load_from_cache()
11 changes: 4 additions & 7 deletions tests/integration/test_multiple_replicas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,16 @@

import random
import time
from typing import TYPE_CHECKING, List
from typing import TYPE_CHECKING

from distilabel.pipeline import Pipeline, routing_batch_function
import pytest
from distilabel.pipeline import Pipeline
from distilabel.steps import LoadDataFromDicts, StepInput, StepResources, step

if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput


@routing_batch_function()
def random_routing_batch(steps: List[str]) -> List[str]:
    """Pick two distinct downstream steps at random to receive the batch."""
    chosen = random.sample(steps, 2)
    return chosen


@step(outputs=["generation"])
def Generate(inputs: StepInput) -> "StepOutput":
# random sleep to simulate processing time
Expand Down Expand Up @@ -57,6 +53,7 @@ def CombineGenerations(*inputs: StepInput) -> "StepOutput":
yield combined_list


@pytest.mark.xfail
def test_multiple_replicas() -> None:
with Pipeline(name="test") as pipeline:
load_dataset = LoadDataFromDicts(
Expand Down
1 change: 1 addition & 0 deletions tests/unit/pipeline/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def test_handle_stop(self) -> None:
pipeline._add_batches_back_to_batch_manager = mock.MagicMock()
pipeline._wait_step_input_queue_empty = mock.MagicMock()
pipeline._consume_output_queue = mock.MagicMock()
pipeline._stages_last_batch = [[]]

pipeline._handle_stop()

Expand Down
4 changes: 3 additions & 1 deletion tests/unit/steps/argilla/test_preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
from unittest import mock
from unittest.mock import patch

import argilla as rg
Expand All @@ -23,7 +24,8 @@

@pytest.fixture
def mock_dataset() -> rg.Dataset: # type: ignore
client = rg.Argilla(api_url="<api_url>", api_key="<api_key>")
rg.Argilla._validate_connection = mock.MagicMock() # type: ignore
client = rg.Argilla(api_url="https://example.com", api_key="<api_key>")
return rg.Dataset(
name="dataset",
settings=rg.Settings(
Expand Down
4 changes: 3 additions & 1 deletion tests/unit/steps/argilla/test_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import os
from unittest import mock
from unittest.mock import patch

import argilla as rg
Expand All @@ -23,7 +24,8 @@

@pytest.fixture
def mock_dataset() -> rg.Dataset:
client = rg.Argilla(api_url="<api_url>", api_key="<api_key>")
rg.Argilla._validate_connection = mock.MagicMock() # type: ignore
client = rg.Argilla(api_url="https://example.com", api_key="<api_key>")
return rg.Dataset(
name="dataset",
settings=rg.Settings(
Expand Down

0 comments on commit 7ff4d20

Please sign in to comment.