Skip to content

Commit

Permalink
Fix saving built_batches and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Jun 5, 2024
1 parent 0e5a4ee commit 8e26025
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 26 deletions.
56 changes: 30 additions & 26 deletions src/distilabel/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1578,6 +1578,12 @@ def save_batch(

return batch_file

def remove_files(keep_files: List[str], dir: Path) -> None:
    """Delete every file found in `dir` that is not listed in `keep_files`.

    Args:
        keep_files: paths (as strings) of the files that must be preserved.
        dir: directory whose contents are inspected via `list_files_in_dir`.
    """
    # NOTE(review): entries returned by `list_files_in_dir` are presumably
    # `Path` objects (they are compared against `Path`s and `.unlink()`ed).
    present = set(list_files_in_dir(dir, key=None))
    preserved = {Path(kept) for kept in keep_files}
    for stale in present.difference(preserved):
        stale.unlink()

path = Path(path)

# Do not include `_Batch` data so `dump` is fast
Expand All @@ -1593,43 +1599,41 @@ def save_batch(
batch_manager_step_dir = path.parent / "batch_manager_steps" / step_name
batch_manager_step_dir.mkdir(parents=True, exist_ok=True)

# Store each `_BatchManagerStep` `_Batch`es in a separate file
for buffered_step_name in step_dump["data"]:
step_batches_dir = batch_manager_step_dir / buffered_step_name
step_batches_dir.mkdir(parents=True, exist_ok=True)

# Store each `_Batch` in a separate file
built_batches_keep = [
# Store each built `_Batch` in a separate file
built_batches_dir = batch_manager_step_dir / "built_batches"
built_batches_dir.mkdir(parents=True, exist_ok=True)
step_dump["built_batches"] = [
str(
save_batch(
batches_dir=step_batches_dir,
batches_dir=built_batches_dir,
batch_dump=batch_dump,
batch_list=self._steps[step_name].built_batches,
)
for batch_dump in step_dump["built_batches"]
]
)
for batch_dump in step_dump["built_batches"]
]
# Remove built `_Batch`es that were consumed from cache
remove_files(step_dump["built_batches"], built_batches_dir)

step_dump["built_batches"] = [
str(file_batch) for file_batch in built_batches_keep
]
# Store each `_BatchManagerStep` `_Batch`es in a separate file
for buffered_step_name in step_dump["data"]:
step_batches_dir = batch_manager_step_dir / buffered_step_name
step_batches_dir.mkdir(parents=True, exist_ok=True)

data_keep_batches = [
save_batch(
batches_dir=step_batches_dir,
batch_dump=batch_dump,
batch_list=self._steps[step_name].data[buffered_step_name],
# Store each `_Batch` in a separate file
step_dump["data"][buffered_step_name] = [
str(
save_batch(
batches_dir=step_batches_dir,
batch_dump=batch_dump,
batch_list=self._steps[step_name].data[buffered_step_name],
)
)
for batch_dump in step_dump["data"][buffered_step_name]
]

step_dump["data"][buffered_step_name] = [
str(file_batch) for file_batch in data_keep_batches
]

# Remove `_Batch`es that were consumed from cache
files = list_files_in_dir(step_batches_dir, key=None)
remove = set(files) - set(built_batches_keep + data_keep_batches)
for file in remove:
file.unlink()
remove_files(step_dump["data"][buffered_step_name], step_batches_dir)

# Store the `_BatchManagerStep` info
batch_manager_step_file = str(
Expand Down
162 changes: 162 additions & 0 deletions tests/unit/pipeline/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,13 +1544,22 @@ def test_dump(self) -> None:
data_hash="hash1",
size=7,
)
batch_step_3 = _Batch(
seq_no=0,
step_name="step3",
last_batch=True,
data=[[{"c": 1}, {"c": 2}, {"c": 3}, {"c": 4}, {"c": 5}]],
data_hash="hash2",
size=5,
)
batch_manager_step = _BatchManagerStep(
step_name="step3",
accumulate=True,
data={
"step1": [batch_step_1],
"step2": [batch_step_2],
},
built_batches=[batch_step_3],
)
assert batch_manager_step.dump() == {
"step_name": "step3",
Expand Down Expand Up @@ -1613,6 +1622,23 @@ def test_dump(self) -> None:
}
],
},
"built_batches": [
{
"seq_no": 0,
"step_name": "step3",
"last_batch": True,
"data": [[{"c": 1}, {"c": 2}, {"c": 3}, {"c": 4}, {"c": 5}]],
"data_hash": "hash2",
"size": 5,
"accumulated": False,
"batch_routed_to": [],
"created_from": {},
"type_info": {
"module": "distilabel.pipeline.base",
"name": "_Batch",
},
}
],
"seq_no": 0,
"last_batch_received": [],
"next_expected_created_from_batch_seq_no": 0,
Expand Down Expand Up @@ -1973,13 +1999,22 @@ def test_can_generate(self) -> None:
assert not batch_manager.can_generate()

def test_dump(self) -> None:
built_batch = _Batch(
seq_no=0,
last_batch=False,
step_name="step3",
data=[[]],
data_hash="hash",
)

batch_manager = _BatchManager(
steps={
"step3": _BatchManagerStep(
step_name="step3",
accumulate=False,
input_batch_size=5,
data={"step1": [], "step2": []},
built_batches=[built_batch],
seq_no=1,
)
},
Expand Down Expand Up @@ -2008,6 +2043,23 @@ def test_dump(self) -> None:
"convergence_step_batches_consumed": {},
"input_batch_size": 5,
"data": {"step1": [], "step2": []},
"built_batches": [
{
"seq_no": 0,
"step_name": "step3",
"last_batch": False,
"data": [[]],
"data_hash": "hash",
"size": 0,
"accumulated": False,
"batch_routed_to": [],
"created_from": {},
"type_info": {
"module": "distilabel.pipeline.base",
"name": "_Batch",
},
}
],
"seq_no": 1,
"last_batch_received": [],
"next_expected_created_from_batch_seq_no": 0,
Expand Down Expand Up @@ -2268,6 +2320,31 @@ def test_cache(self) -> None:
}
],
},
"built_batches": [
{
"seq_no": 0,
"step_name": "step1",
"last_batch": False,
"data": [
[
{"a": 1},
{"a": 2},
{"a": 3},
{"a": 4},
{"a": 5},
]
],
"data_hash": "1234",
"size": 5,
"accumulated": False,
"batch_routed_to": [],
"created_from": {},
"type_info": {
"module": "distilabel.pipeline.base",
"name": "_Batch",
},
}
],
"seq_no": 0,
"last_batch_received": [],
"type_info": {
Expand Down Expand Up @@ -2310,6 +2387,31 @@ def test_cache(self) -> None:
}
],
},
"built_batches": [
{
"seq_no": 0,
"step_name": "step1",
"last_batch": False,
"data": [
[
{"a": 1},
{"a": 2},
{"a": 3},
{"a": 4},
{"a": 5},
]
],
"data_hash": "1234",
"size": 5,
"accumulated": False,
"batch_routed_to": [],
"created_from": {},
"type_info": {
"module": "distilabel.pipeline.base",
"name": "_Batch",
},
}
],
"seq_no": 0,
"last_batch_received": [],
"type_info": {
Expand Down Expand Up @@ -2408,6 +2510,16 @@ def test_cache(self) -> None:
and batch_manager_step_path.is_file()
)

built_batches_dir = batch_manager_step_dir / "built_batches"
assert built_batches_dir.exists()

for batch in step.built_batches:
batch_path = (
built_batches_dir
/ f"batch_{batch.seq_no}_{batch.data_hash}.json"
)
assert batch_path.exists() and batch_path.is_file()

for buffered_step_name in step.data:
buffered_step_dir = batch_manager_step_dir / buffered_step_name
assert buffered_step_dir.exists() and buffered_step_dir.is_dir()
Expand Down Expand Up @@ -2458,6 +2570,31 @@ def test_load_from_cache(self) -> None:
}
],
},
"built_batches": [
{
"seq_no": 0,
"step_name": "step1",
"last_batch": False,
"data": [
[
{"a": 1},
{"a": 2},
{"a": 3},
{"a": 4},
{"a": 5},
]
],
"data_hash": "1234",
"size": 5,
"accumulated": False,
"batch_routed_to": [],
"created_from": {},
"type_info": {
"module": "distilabel.pipeline.base",
"name": "_Batch",
},
}
],
"seq_no": 0,
"last_batch_received": [],
"type_info": {
Expand Down Expand Up @@ -2500,6 +2637,31 @@ def test_load_from_cache(self) -> None:
}
],
},
"built_batches": [
{
"seq_no": 0,
"step_name": "step1",
"last_batch": False,
"data": [
[
{"a": 1},
{"a": 2},
{"a": 3},
{"a": 4},
{"a": 5},
]
],
"data_hash": "1234",
"size": 5,
"accumulated": False,
"batch_routed_to": [],
"created_from": {},
"type_info": {
"module": "distilabel.pipeline.base",
"name": "_Batch",
},
}
],
"seq_no": 0,
"last_batch_received": [],
"type_info": {
Expand Down

0 comments on commit 8e26025

Please sign in to comment.