Skip to content

Commit

Permalink
Add cache at Step level (#766)
Browse files Browse the repository at this point in the history
* Add signature method for Serializable objects

* Update signature to only keep track of the step names and not its internal info

* Refactor hash generation

* Add dummy batch manager from dag

* Update batch manager cache tests to start batch manager from a DAG

* Draft of integration tests for new caching

* Checkpoint draft

* Add cache directory location

* Add use_cache argument to Step for future use

* Change output names to keep track of them while debugging

* Make use of use_cache at the step level

* Add docstrings for internal batch manager arguments

* Remove path from add_batch method

* Move step caching to get_batch method in batch manager step

* Read batches from cached dir

* Set every step cache to False if the pipeline has the cache as False

* Comment for the batch manager

* Move back to caching from add_step

* Checkpoint current status

* Add use_cache on step

* If there's previous data saved, concatenate the content of the parquet files

* Only read the distiset from cache if all the steps are the same, otherwise overwrite

* Add changes to make loading a new and modified step feasible

* Set use cache to True by default

* Move logic of registering the batches to BasePipeline._register_batch to do it before calling _manage_batch_flows

* Avoid reading parquet file from cache when any of the steps has use_cache=False

* Add is_convergence method to DAG and cleanup batch_manager

* Add integration tests for the new caching mechanism

* Update unit tests related to register_batch

* Fix signature serialization case of void list

* Add use_cache to argilla tests

* Fix tests related to use_cache

* Fix tests

* Remove undefined object input

* Add `_invalidate_steps_cache_if_required` method

* Initial work for loading batches from `batch_manager_data` directory

* Draft cache updates

* Update pipeline signature

* Add signature mixin from other PR

* Moved pipeline cache to executions folder with different data per pipeline

* Testing new updates to read from cache

* Checkpoint with loading working while adding new steps

* Point of control

* Fix not all the batches were being saved

* Sort batches after loaded

* Fix `load_from_cache` to load batches from `steps_data` directory
correctly

* Update test

* Add `step_has_finished` method

* Update invalidate cache function

* Update integration caching tests

* Refactor to extract logic to methods

* Refactor to remove `cached_data_dir`

* Update stages message

* Refactor `invalidate_cache_for` method

* Fix `_BatchManager` unit tests

* Update to not serialize `exclude_from_signature` attribute

* Fix pipeline unit tests

* Remove write buffer data if `use_cache=False`

* Fix offline batch generation attributes not being ignored by the
signature

* Fix print test

* Fix routing batch function

---------

Co-authored-by: Gabriel Martín Blázquez <[email protected]>
  • Loading branch information
plaguss and gabrielmbmb authored Oct 7, 2024
1 parent e027f99 commit ebab004
Show file tree
Hide file tree
Showing 28 changed files with 1,509 additions and 643 deletions.
83 changes: 83 additions & 0 deletions src/distilabel/mixins/signature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import hashlib
from typing import TYPE_CHECKING, Any, List, Set

from pydantic import BaseModel, Field

from distilabel.utils.serialization import TYPE_INFO_KEY

if TYPE_CHECKING:
pass

# Add here the name of the attributes that shouldn't be used to generate the signature.
# Attributes from a `BaseModel` that is an attribute from the root class must be prefixed
# with the name of the attribute followed by an underscore. For example, if the attribute
# `jobs_ids` is an attribute from the `llm` attribute of the root class it should be added
# as `llm_jobs_ids`.
# NOTE(review): these appear to be runtime/resource settings that shouldn't affect a
# step's cached outputs — confirm that assumption before adding new entries here.
_EXCLUDE_FROM_SIGNATURE_DEFAULTS = {
    TYPE_INFO_KEY,
    "disable_cuda_device_placement",
    "input_batch_size",
    "gpu_memory_utilization",
    "resources",
    "exclude_from_signature",
    "llm_jobs_ids",
    "llm_offline_batch_generation_block_until_done",
}


class SignatureMixin(BaseModel):
    """Mixin for creating a signature (for cache) of the class.

    The signature is a SHA-1 hash built from the flattened dump of the class
    attributes, so two instances with the same (non-excluded) attribute values
    produce the same signature.

    Attributes:
        exclude_from_signature: set of flattened attribute names to exclude from
            the signature. Note the exclusion is only applied to scalar leaves of
            the dump, not to whole nested dicts or lists.
    """

    exclude_from_signature: Set[str] = Field(
        default=_EXCLUDE_FROM_SIGNATURE_DEFAULTS, exclude=True
    )

    @property
    def signature(self) -> str:
        """Makes a signature (hash) of the class, using its attributes.

        Returns:
            signature of the class.
        """

        def flatten_dump(d: Any, parent_key: str = "", sep: str = "_") -> List:
            # Flattens the nested dump into `(flat_key, value)` pairs, where nested
            # keys are joined with `sep` (e.g. `llm_jobs_ids`).
            items = []
            for k, v in d.items():
                new_key = parent_key + sep + k if parent_key else k
                if isinstance(v, dict):
                    items.extend(flatten_dump(v, new_key, sep=sep))
                elif isinstance(v, list):
                    if len(v) == 0:
                        # Keep a stable entry for empty lists so the key still
                        # participates in the signature.
                        items.append((new_key, ""))
                    elif isinstance(v[0], str):
                        items.append((new_key, "-".join(v)))
                    else:
                        for i, x in enumerate(v):
                            if isinstance(x, dict):
                                items.extend(
                                    flatten_dump(x, f"{new_key}-{i}", sep=sep)
                                )
                            else:
                                # Fix: non-dict, non-str list elements (e.g. a list
                                # of ints) used to be passed to `flatten_dump`,
                                # crashing with `AttributeError` on `.items()`.
                                items.append((f"{new_key}-{i}", x))
                elif new_key not in self.exclude_from_signature:
                    items.append((new_key, v))
            return items

        # NOTE(review): `dump()` is assumed to be provided by the `Serializable`
        # machinery of the inheriting class — confirm against its definition.
        info = []
        for name, value in flatten_dump(self.dump()):
            info.append(f"{name}-{str(value)}")

        return hashlib.sha1("-".join(info).encode()).hexdigest()
16 changes: 16 additions & 0 deletions src/distilabel/pipeline/_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from distilabel.constants import (
CONVERGENCE_STEP_ATTR_NAME,
RECEIVES_ROUTED_BATCHES_ATTR_NAME,
ROUTING_BATCH_FUNCTION_ATTR_NAME,
STEP_ATTR_NAME,
)
Expand Down Expand Up @@ -253,6 +254,21 @@ def is_step_in_trophic_level(self, step_name: str, trophic_level: int) -> bool:
"""
return self.get_step_trophic_level(step_name) == trophic_level

def is_convergence_step(self, step_name: str) -> bool:
    """Checks if a given step is a convergence step.

    A step is considered a convergence step when every one of its predecessors
    receives routed batches.

    Args:
        step_name: Name of the step to check if a convergence step.

    Returns:
        True if it is, False otherwise.
    """
    # NOTE(review): a step with no predecessors also returns True here (the
    # original's `all` over an empty iterable is vacuously true) — confirm
    # callers never invoke this on root/generator steps.
    for predecessor_name in self.get_step_predecessors(step_name):
        receives_routed = self.get_step(predecessor_name).get(
            RECEIVES_ROUTED_BATCHES_ATTR_NAME, False
        )
        if not receives_routed:
            return False
    return True

def step_in_last_trophic_level(self, step_name: str) -> bool:
"""Checks if a step is in the last trophic level.
Expand Down
Loading

0 comments on commit ebab004

Please sign in to comment.