Fix prepend batches (#696)
* Add `built_batches` attribute

* Fix saving `built_batches` and tests
gabrielmbmb authored Jun 5, 2024
1 parent e4a9609 commit 062f4fb
Showing 2 changed files with 263 additions and 38 deletions.
95 changes: 67 additions & 28 deletions src/distilabel/pipeline/base.py
@@ -827,6 +827,9 @@ class _BatchManagerStep(_Serializable):
If `None`, then `accumulate` must be `True`. Defaults to `None`.
data: A dictionary with the predecessor step name as the key and a list of
dictionaries (rows) as the value.
built_batches: A list with batches that were already built and sent to the step
queue but not processed before the step was stopped, so they don't get
lost. Defaults to an empty list.
seq_no: The sequence number of the next batch to be created. It will be
incremented for each batch created.
last_batch_received: A list with the names of the steps that sent the last
@@ -850,6 +853,7 @@ class _BatchManagerStep(_Serializable):
accumulate: bool
input_batch_size: Union[int, None] = None
data: Dict[str, List[_Batch]] = field(default_factory=dict)
built_batches: List[_Batch] = field(default_factory=list)
seq_no: int = 0
last_batch_received: List[str] = field(default_factory=list)
convergence_step: bool = False
@@ -864,13 +868,15 @@ def add_batch(self, batch: _Batch, prepend: bool = False) -> None:
Args:
batch: The output batch of a step to be processed by the step.
prepend: If `True`, the content of the batch will be added at the start of
the buffer.
prepend: If `True`, the batch will be added to the `built_batches`
list. This is done so that if a `_Batch` was already built and sent to the step
queue, and the step is stopped before processing it, the batch doesn't
get lost. Defaults to `False`.
"""
from_step = batch.step_name

if prepend:
self.data[from_step].insert(0, batch)
self.built_batches.append(batch)
else:
self.data[from_step].append(batch)

@@ -888,6 +894,11 @@ def get_batch(self) -> Union[_Batch, None]:
if not self._ready_to_create_batch():
return None

# If there are batches in the `built_batches` list, then return the first one
# and remove it from the list.
if self.built_batches:
return self.built_batches.pop(0)

# `_last_batch` must be called before `_get_data`, as `_get_data` will update the
# list of data which is used to determine if the batch to be created is the last one.
# TODO: remove `_last_batch` method and integrate logic in `_get_data`
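Taken together, the two hunks above change both where a re-queued batch is stored and where `get_batch` reads from: a batch added with `prepend=True` now lands in `built_batches`, and `get_batch` drains that list before building anything new. A minimal, self-contained sketch of that behavior, with `Batch` and `BatchManagerStep` as simplified stand-ins for distilabel's `_Batch` and `_BatchManagerStep` (not the real classes):

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class Batch:
    # Simplified stand-in for `_Batch`; only the fields used below.
    step_name: str
    seq_no: int
    data: List[dict]


@dataclass
class BatchManagerStep:
    # Simplified stand-in for `_BatchManagerStep`.
    data: Dict[str, List[Batch]] = field(default_factory=dict)
    built_batches: List[Batch] = field(default_factory=list)

    def add_batch(self, batch: Batch, prepend: bool = False) -> None:
        if prepend:
            # An already-built batch the step never processed: keep it whole.
            self.built_batches.append(batch)
        else:
            self.data.setdefault(batch.step_name, []).append(batch)

    def get_batch(self) -> Optional[Batch]:
        # Re-queued batches are handed out first, so nothing is lost on restart.
        if self.built_batches:
            return self.built_batches.pop(0)
        return None  # the real method builds a new batch from `self.data`


step = BatchManagerStep()
step.add_batch(Batch("load_data", seq_no=0, data=[{"x": 1}]), prepend=True)
assert step.get_batch().seq_no == 0
```

Before this commit, the prepended batch was inserted back into `self.data[from_step]` (the removed `insert(0, batch)` line above); keeping it whole in a separate `built_batches` list is the behavioral change the commit title describes.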
@@ -1326,6 +1337,7 @@ def _model_dump(self, obj: Any, **kwargs: Any) -> Dict[str, Any]:
step_name: [batch.dump(**kwargs) for batch in batches]
for step_name, batches in self.data.items()
},
"built_batches": [batch.dump(**kwargs) for batch in self.built_batches],
"seq_no": self.seq_no,
"last_batch_received": self.last_batch_received,
"convergence_step": self.convergence_step,
@@ -1549,6 +1561,29 @@ def cache(self, path: "StrOrPath") -> None:
path: The path to the file where the `_BatchManager` will be cached. If `None`,
then the `_BatchManager` will be cached in the default cache folder.
"""

def save_batch(
batches_dir: Path, batch_dump: Dict[str, Any], batch_list: List[_Batch]
) -> Path:
seq_no = batch_dump["seq_no"]
data_hash = batch_dump["data_hash"]
batch_file = batches_dir / f"batch_{seq_no}_{data_hash}.json"

# Save the batch if it doesn't exist
if not batch_file.exists():
# Get the data of the batch before saving it
batch = next(batch for batch in batch_list if batch.seq_no == seq_no)
batch_dump["data"] = batch.data
self.save(path=batch_file, format="json", dump=batch_dump)

return batch_file

def remove_files(keep_files: List[str], dir: Path) -> None:
files = list_files_in_dir(dir, key=None)
remove = set(files) - {Path(file) for file in keep_files}
for file in remove:
file.unlink()

path = Path(path)

# Do not include `_Batch` data so `dump` is fast
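The `save_batch` helper defined above names each file after the batch's sequence number and data hash, so a batch whose data has not changed since the previous `cache` call is never rewritten. A rough, self-contained sketch of that idea, under the assumption (not shown in this hunk) that the hash is computed over the serialized row data:

```python
import hashlib
import json
from pathlib import Path


def batch_filename(seq_no: int, data: list) -> str:
    # Hypothetical hash; the real `data_hash` computation lives elsewhere in distilabel.
    data_hash = hashlib.sha1(json.dumps(data, sort_keys=True).encode()).hexdigest()
    return f"batch_{seq_no}_{data_hash}.json"


def save_if_missing(batches_dir: Path, seq_no: int, data: list) -> Path:
    # Same idea as `save_batch` above: identical data, identical name, no rewrite.
    batches_dir.mkdir(parents=True, exist_ok=True)
    batch_file = batches_dir / batch_filename(seq_no, data)
    if not batch_file.exists():
        batch_file.write_text(json.dumps({"seq_no": seq_no, "data": data}))
    return batch_file
```

Calling it twice with identical data leaves the existing file untouched, while changed data yields a new hash and a new file; `remove_files` then prunes files that are no longer referenced.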
@@ -1564,41 +1599,41 @@ def cache(self, path: "StrOrPath") -> None:
batch_manager_step_dir = path.parent / "batch_manager_steps" / step_name
batch_manager_step_dir.mkdir(parents=True, exist_ok=True)

# Store each built `_Batch` in a separate file
built_batches_dir = batch_manager_step_dir / "built_batches"
built_batches_dir.mkdir(parents=True, exist_ok=True)
step_dump["built_batches"] = [
str(
save_batch(
batches_dir=built_batches_dir,
batch_dump=batch_dump,
batch_list=self._steps[step_name].built_batches,
)
)
for batch_dump in step_dump["built_batches"]
]
# Remove built `_Batch`es that were consumed from cache
remove_files(step_dump["built_batches"], built_batches_dir)

# Store each `_BatchManagerStep` `_Batch`es in a separate file
for buffered_step_name in step_dump["data"]:
step_batches_dir = batch_manager_step_dir / buffered_step_name
step_batches_dir.mkdir(parents=True, exist_ok=True)

# Store each `_Batch` in a separate file
keep_batches = []
for batch_dump in step_dump["data"][buffered_step_name]:
# Generate a hash for the data of the batch
seq_no = batch_dump["seq_no"]
data_hash = batch_dump["data_hash"]
batch_file = step_batches_dir / f"batch_{seq_no}_{data_hash}.json"

# Save the batch if it doesn't exist
if not batch_file.exists():
# Get the data of the batch before saving it
batch = next(
batch
for batch in self._steps[step_name].data[buffered_step_name]
if batch.seq_no == seq_no
)
batch_dump["data"] = batch.data
self.save(path=batch_file, format="json", dump=batch_dump)

keep_batches.append(batch_file)

step_dump["data"][buffered_step_name] = [
str(file_batch) for file_batch in keep_batches
str(
save_batch(
batches_dir=step_batches_dir,
batch_dump=batch_dump,
batch_list=self._steps[step_name].data[buffered_step_name],
)
)
for batch_dump in step_dump["data"][buffered_step_name]
]

# Remove `_Batch`es that were consumed from cache
files = list_files_in_dir(step_batches_dir, key=None)
remove = set(files) - set(keep_batches)
for file in remove:
file.unlink()
remove_files(step_dump["data"][buffered_step_name], step_batches_dir)

# Store the `_BatchManagerStep` info
batch_manager_step_file = str(
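Putting the path expressions from this hunk together, the cache written by `cache` is laid out roughly as below. The step and predecessor names are made up, and the per-step info file whose name is truncated above is omitted:

```python
from pathlib import Path

# Illustrative reconstruction of the on-disk layout written by `cache`
# (step/predecessor names are hypothetical; the per-step info file is cut off above).
cache_file = Path("batch_manager.json")  # the `path` argument passed to `cache`
step_dir = cache_file.parent / "batch_manager_steps" / "generate_responses"
built_batch = step_dir / "built_batches" / "batch_0_<data_hash>.json"
buffered_batch = step_dir / "load_data" / "batch_3_<data_hash>.json"
```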
@@ -1628,6 +1663,10 @@ def load_from_cache(cls, path: "StrOrPath") -> "_BatchManager":
steps[step_name] = read_json(step_file)

# Read each `_Batch` from file
steps[step_name]["built_batches"] = [
read_json(batch) for batch in steps[step_name]["built_batches"]
]

for buffered_step_name, batch_files in steps[step_name]["data"].items():
steps[step_name]["data"][buffered_step_name] = [
read_json(batch_file) for batch_file in batch_files
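On the way back in, `load_from_cache` swaps each stored file path for the parsed content of that file, for `built_batches` as well as for every predecessor buffer in `data`. A simplified stand-in for that step, using plain `json` in place of distilabel's `read_json` helper:

```python
import json
from pathlib import Path


def inline_batch_files(step_dump: dict) -> dict:
    # The cached step dump stores batches as file paths; loading replaces each
    # path with the JSON content of the corresponding batch file.
    step_dump["built_batches"] = [
        json.loads(Path(f).read_text()) for f in step_dump["built_batches"]
    ]
    for predecessor, batch_files in step_dump["data"].items():
        step_dump["data"][predecessor] = [
            json.loads(Path(f).read_text()) for f in batch_files
        ]
    return step_dump
```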