Add step to deduplicate records based on embeddings (#946)
* Redirect import

* Add train_size argument to allow training Faiss indices

* Fix error when retrieving info from a dataset fails while creating a step from the make_generator_step helper

* Add embedding dedup step

* Add unit and integration tests for embedding dedup

* Apply comments from code review
plaguss authored Sep 6, 2024
1 parent 973e0fa commit de2bed0
Showing 7 changed files with 450 additions and 7 deletions.
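Before the per-file diff, a minimal sketch of how the two new pieces introduced here are meant to combine: the `train_size` argument on `FaissNearestNeighbour` and the new `EmbeddingDedup` step. This is not part of the commit; the repository ID, index factory string, `k`, and threshold values are illustrative assumptions.

```python
import faiss

from distilabel.pipeline import Pipeline
from distilabel.steps import EmbeddingDedup, FaissNearestNeighbour, LoadDataFromHub

with Pipeline(name="embedding-dedup-sketch") as pipeline:
    # Assumes the dataset already contains an "embedding" column with normalized vectors.
    data = LoadDataFromHub(repo_id="my-org/personas-with-embeddings", split="train")
    nn = FaissNearestNeighbour(
        k=20,
        metric_type=faiss.METRIC_INNER_PRODUCT,
        string_factory="IVF64,Flat",  # an index type that requires a training step
        train_size=1000,              # new argument introduced in this commit
    )
    dedup = EmbeddingDedup(threshold=0.9)  # new step introduced in this commit
    data >> nn >> dedup

if __name__ == "__main__":
    distiset = pipeline.run(use_cache=False)
```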
2 changes: 2 additions & 0 deletions src/distilabel/steps/__init__.py
@@ -30,6 +30,7 @@
from distilabel.steps.deita import DeitaFiltering
from distilabel.steps.embeddings.embedding_generation import EmbeddingGeneration
from distilabel.steps.embeddings.nearest_neighbour import FaissNearestNeighbour
from distilabel.steps.filtering.embedding import EmbeddingDedup
from distilabel.steps.filtering.minhash import MinHashDedup
from distilabel.steps.formatting.conversation import ConversationTemplate
from distilabel.steps.formatting.dpo import (
@@ -79,6 +80,7 @@
"LoadDataFromDisk",
"LoadDataFromFileSystem",
"LoadDataFromHub",
"EmbeddingDedup",
"MinHashDedup",
"make_generator_step",
"PushToHub",
11 changes: 11 additions & 0 deletions src/distilabel/steps/embeddings/nearest_neighbour.py
@@ -46,6 +46,8 @@ class FaissNearestNeighbour(GlobalStep):
search_batch_size: the number of rows to include in a search batch. The value can
be adjusted to maximize the resources usage or to avoid OOM issues. Defaults
to `50`.
train_size: If the index needs a training step, specifies how many vectors will be
used to train the index.
Runtime parameters:
- `device`: the CUDA device ID or a list of IDs to be used. If negative integer,
@@ -60,6 +62,8 @@ class FaissNearestNeighbour(GlobalStep):
- `search_batch_size`: the number of rows to include in a search batch. The value
can be adjusted to maximize the resources usage or to avoid OOM issues. Defaults
to `50`.
- `train_size`: If the index needs a training step, specifies how many vectors will
be used to train the index.
Input columns:
- embedding (`List[Union[float, int]]`): a sentence embedding.
@@ -148,6 +152,10 @@ class FaissNearestNeighbour(GlobalStep):
description="The number of rows to include in a search batch. The value can be adjusted"
" to maximize the resources usage or to avoid OOM issues.",
)
train_size: Optional[RuntimeParameter[int]] = Field(
default=None,
description="If the index needs a training step, specifies how many vectors will be used to train the index.",
)

def load(self) -> None:
super().load()
@@ -176,11 +184,14 @@ def _build_index(self, inputs: List[Dict[str, Any]]) -> Dataset:
            The built `datasets.Dataset` with its `faiss` index.
        """
        dataset = Dataset.from_list(inputs)
        if self.train_size is not None and self.string_factory:
            self._logger.info("🏋️‍♀️ Starting Faiss index training...")
        dataset.add_faiss_index(
            column="embedding",
            device=self.device,  # type: ignore
            string_factory=self.string_factory,
            metric_type=self.metric_type,
            train_size=self.train_size,
        )
        return dataset

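For context, `train_size` is forwarded to `datasets.Dataset.add_faiss_index`, and the training log above is only emitted when a `string_factory` is also set (a flat index needs no training). A standalone sketch of the equivalent direct `datasets` call, with randomly generated stand-in embeddings and an illustrative `IVF8,Flat` factory:

```python
import numpy as np
from datasets import Dataset

# 512 random 8-dimensional vectors stand in for real sentence embeddings.
rows = [{"embedding": vec.tolist()} for vec in np.random.rand(512, 8).astype("float32")]
dataset = Dataset.from_list(rows)
dataset.add_faiss_index(
    column="embedding",
    string_factory="IVF8,Flat",  # an IVF index requires training
    train_size=512,              # number of vectors used to train the index
)
query = np.asarray(rows[0]["embedding"], dtype="float32")
scores, retrieved = dataset.get_nearest_examples("embedding", query, k=5)
```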
192 changes: 192 additions & 0 deletions src/distilabel/steps/filtering/embedding.py
@@ -0,0 +1,192 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Optional

import numpy as np
from pydantic import Field
from rich.progress import track
from typing_extensions import override

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.base import GlobalStep, StepInput

if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput


class EmbeddingDedup(GlobalStep):
"""Deduplicates text using embeddings.
`EmbeddingDedup` is a Step that detects near-duplicates in datasets, using
embeddings to compare the similarity between the texts. The typical workflow with this step
would include having a dataset with embeddings precomputed, and then (possibly using the
`FaissNearestNeighbour`) using the `nn_indices` and `nn_scores`, determine the texts that
are duplicate.
Attributes:
threshold: the threshold to consider 2 examples as duplicates.
It's dependent on the type of index that was used to generate the embeddings.
For example, if the embeddings were generated using cosine similarity, a threshold
of `0.9` would make all the texts with a cosine similarity above the value
duplicates. Higher values detect less duplicates in such an index, but that should
be taken into account when building it. Defaults to `0.9`.
Runtime Parameters:
- `threshold`: the threshold to consider 2 examples as duplicates.
Input columns:
- nn_indices (`List[int]`): a list containing the indices of the `k` nearest neighbours
in the inputs for the row.
- nn_scores (`List[float]`): a list containing the score or distance to each `k`
nearest neighbour in the inputs.
Output columns:
- keep_row_after_embedding_filtering (`bool`): boolean indicating if the piece `text` is
not a duplicate i.e. this text should be kept.
Categories:
- filtering
Examples:
Deduplicate a list of texts using embedding information:
```python
from distilabel.pipeline import Pipeline
from distilabel.steps import EmbeddingDedup
from distilabel.steps import LoadDataFromDicts
batch_size = 2  # batch size used by the steps below

with Pipeline() as pipeline:
data = LoadDataFromDicts(
data=[
{
"persona": "A chemistry student or academic researcher interested in inorganic or physical chemistry, likely at an advanced undergraduate or graduate level, studying acid-base interactions and chemical bonding.",
"embedding": [
0.018477669046149742,
-0.03748236608841726,
0.001919870620352492,
0.024918478063770535,
0.02348063521315178,
0.0038251285566308375,
-0.01723884983037716,
0.02881971942372201,
],
"nn_indices": [0, 1],
"nn_scores": [
0.9164746999740601,
0.782106876373291,
],
},
{
"persona": "A music teacher or instructor focused on theoretical and practical piano lessons.",
"embedding": [
-0.0023464179614082125,
-0.07325472251663565,
-0.06058678419516501,
-0.02100326928586996,
-0.013462744792362657,
0.027368447064244242,
-0.003916070100455717,
0.01243614518480423,
],
"nn_indices": [0, 2],
"nn_scores": [
0.7552462220191956,
0.7261884808540344,
],
},
{
"persona": "A classical guitar teacher or instructor, likely with experience teaching beginners, who focuses on breaking down complex music notation into understandable steps for their students.",
"embedding": [
-0.01630817942328242,
-0.023760151552345232,
-0.014249650090627883,
-0.005713686451446624,
-0.016033059279131567,
0.0071440908501058786,
-0.05691099643425161,
0.01597412704817784,
],
"nn_indices": [1, 2],
"nn_scores": [
0.8107735514640808,
0.7172299027442932,
],
},
],
batch_size=batch_size,
)
# In general you should do something like this before the deduplication step, to obtain the
# `nn_indices` and `nn_scores`. In this case the embeddings are already normalized, so there's
# no need for it.
# nn = FaissNearestNeighbour(
# k=30,
# metric_type=faiss.METRIC_INNER_PRODUCT,
# search_batch_size=50,
# train_size=len(dataset), # The number of embeddings to use for training
# string_factory="IVF300_HNSW32,Flat" # To use an index (optional, maybe required for big datasets)
# )
# Read more about the `string_factory` here:
# https://github.com/facebookresearch/faiss/wiki/Guidelines-to-choose-an-index
embedding_dedup = EmbeddingDedup(
threshold=0.8,
input_batch_size=batch_size,
)
data >> embedding_dedup
if __name__ == "__main__":
distiset = pipeline.run(use_cache=False)
ds = distiset["default"]["train"]
# Filter out the duplicates
ds_dedup = ds.filter(lambda x: x["keep_row_after_embedding_filtering"])
```
"""

    threshold: Optional[RuntimeParameter[float]] = Field(
        default=0.9,
        description="The threshold to consider 2 examples as duplicates. It's dependent "
        "on the type of index that was used to generate the embeddings. For example, if "
        "the embeddings were generated using cosine similarity, a threshold of `0.9` "
        "would make all the texts with a cosine similarity above the value duplicates. "
        "Higher values detect fewer duplicates in such an index, but that should be "
        "taken into account when building it.",
    )

    @property
    def inputs(self) -> List[str]:
        return ["nn_scores", "nn_indices"]

    @property
    def outputs(self) -> List[str]:
        return ["keep_row_after_embedding_filtering"]

    @override
    def process(self, inputs: StepInput) -> "StepOutput":  # type: ignore
        rows_to_remove = set()

        for input in track(inputs, description="Running Embedding deduplication..."):
            input["keep_row_after_embedding_filtering"] = True
            indices_scores = np.array(input["nn_scores"]) > self.threshold
            indices = np.array(input["nn_indices"])[indices_scores]
            if len(indices) > 0:  # If there are any rows found over the threshold
                rows_to_remove.update(list(indices))

        # Remove duplicates and get the list of rows to remove
        for idx in rows_to_remove:
            inputs[idx]["keep_row_after_embedding_filtering"] = False

        yield inputs
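The filtering logic in `process` can be illustrated with a small standalone sketch (toy data, not from the commit): any neighbour whose score exceeds `threshold` is collected into `rows_to_remove`, and those rows are then flagged with `keep_row_after_embedding_filtering = False`.

```python
import numpy as np

threshold = 0.9
rows = [
    {"nn_indices": [1], "nn_scores": [0.95]},  # row 0 reports row 1 as a near-duplicate
    {"nn_indices": [2], "nn_scores": [0.40]},
    {"nn_indices": [1], "nn_scores": [0.40]},
]

rows_to_remove = set()
for row in rows:
    row["keep_row_after_embedding_filtering"] = True
    over_threshold = np.array(row["nn_scores"]) > threshold
    rows_to_remove.update(np.array(row["nn_indices"])[over_threshold].tolist())

for idx in rows_to_remove:
    rows[idx]["keep_row_after_embedding_filtering"] = False

print([row["keep_row_after_embedding_filtering"] for row in rows])  # [True, False, True]
```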
10 changes: 4 additions & 6 deletions src/distilabel/steps/generators/huggingface.py
@@ -243,20 +243,18 @@ def _dataset_info(self) -> Dict[str, DatasetInfo]:
Returns:
The dataset information.
"""
-        repo_id = self.repo_id
-        config = self.config

         try:
-            return get_dataset_infos(repo_id)
+            return get_dataset_infos(self.repo_id)
         except Exception as e:
             # The previous call could fail in case of internet connection issues.
             # Assume the dataset is already loaded so we can get the info from it; otherwise it will fail anyway.
             self._logger.warning(
                 f"Failed to get dataset info from Hugging Face Hub, trying to get it loading the dataset. Error: {e}"
             )
-            ds = load_dataset(repo_id, config=self.config, split=self.split)
-            if config:
-                return ds[config].info
+            ds = load_dataset(self.repo_id, config=self.config, split=self.split)
+            if self.config:
+                return ds[self.config].info
             return ds.info


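The change above boils down to a simple fallback pattern: ask the Hub for the dataset metadata first and, if that fails (for example, without an internet connection), read the info from the loaded dataset itself. A hedged sketch of the pattern with a placeholder repository ID:

```python
from datasets import get_dataset_infos, load_dataset

repo_id, split = "my-org/my-dataset", "train"  # placeholder values
try:
    info = get_dataset_infos(repo_id)  # config name -> DatasetInfo, fetched from the Hub
except Exception:
    ds = load_dataset(repo_id, split=split)  # fall back to the (possibly cached) dataset
    info = {"default": ds.info}
```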
8 changes: 7 additions & 1 deletion src/distilabel/steps/generators/utils.py
@@ -32,6 +32,7 @@ def make_generator_step(
input_mappings: Optional[Dict[str, str]] = None,
output_mappings: Optional[Dict[str, str]] = None,
resources: StepResources = StepResources(),
repo_id: str = "placeholder",
) -> "GeneratorStep":
"""Helper method to create a `GeneratorStep` from a dataset, to simplify
@@ -42,6 +43,10 @@
input_mappings: Applies the same as any other step. Defaults to `None`.
output_mappings: Applies the same as any other step. Defaults to `None`.
resources: Applies the same as any other step. Defaults to `StepResources()`.
repo_id: The repository ID to use in the `LoadDataFromHub` step.
This shouldn't be necessary in general, but if an error occurs the dataset will be
loaded internally using `load_dataset`, and in that case this `repo_id` will be used.
Defaults to `"placeholder"`.
Raises:
ValueError: If the format is different from the ones supported.
@@ -74,12 +79,13 @@

     loader = LoadDataFromHub(
         pipeline=pipeline,
-        repo_id="placeholder_name",
+        repo_id=repo_id,
         batch_size=batch_size,
         input_mappings=input_mappings or {},
         output_mappings=output_mappings or {},
         resources=resources,
     )
+    super(loader.__class__, loader).load()  # Ensure the logger is loaded
     loader._dataset = dataset
     loader.num_examples = len(dataset)
     loader._dataset_info = {"default": dataset.info}
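A short usage sketch of the helper with the new `repo_id` argument. The `dataset` and `pipeline` parameters are assumed from the helper's wider signature (not shown in this hunk), and the repository name is a placeholder; `repo_id` only matters if the dataset info later has to be re-fetched via `load_dataset`. In normal use you rarely call this yourself: it is the helper used internally when a dataset is passed directly to `Pipeline.run`.

```python
from datasets import Dataset
from distilabel.pipeline import Pipeline
from distilabel.steps import make_generator_step

dataset = Dataset.from_list(
    [{"instruction": "What is 2 + 2?"}, {"instruction": "Name a prime number."}]
)

with Pipeline(name="generator-step-sketch") as pipeline:
    loader = make_generator_step(
        dataset,
        pipeline=pipeline,
        batch_size=1,
        repo_id="my-org/my-dataset",  # used only if the dataset info must be re-fetched
    )
```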