Skip to content

Commit

Permalink
Update argilla integration to use argilla_sdk v2 (#705)
Browse files Browse the repository at this point in the history
* Update `_Argilla` base and `TextGenerationToArgilla`

* Fix `_dataset.records.log` and rename to `ArgillaBase`

Co-authored-by: Ben Burtenshaw <[email protected]>

* Update `TextGenerationToArgilla` subclass inheritance

* Remove unused `logger.info` message

* Update `PreferenceToArgilla`

* Update `argilla` extra to install `argilla_sdk`

For the moment it's being installed as `pip install git+https://github.com/argilla-io/argilla-python.git@main`

* Add `ArgillaBase` and subclasses unit tests

* Install `argilla_sdk` from source and add `ipython`

* upgrade argilla dep to latest rc

* udate code with latest changes

* chore: remove unnecessary workspace definition

* fix: wrong argilla module import

* Update docstrings

* Fix lint

* Add check for `api_url` and `api_key`

* Fix unit tests

* Fix unit tests

* Update argilla dependency version

---------

Co-authored-by: Ben Burtenshaw <[email protected]>
Co-authored-by: Francisco Aranda <[email protected]>
Co-authored-by: Gabriel Martín Blázquez <[email protected]>
  • Loading branch information
4 people authored Jul 30, 2024
1 parent be61d20 commit 18dc02c
Show file tree
Hide file tree
Showing 8 changed files with 206 additions and 150 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ tests = [

# Optional LLMs, integrations, etc
anthropic = ["anthropic >= 0.20.0"]
argilla = ["argilla >= 1.29.0"]
argilla = ["argilla >= 2.0.0", "ipython"]
cohere = ["cohere >= 5.2.0"]
groq = ["groq >= 0.4.1"]
hf-inference-endpoints = ["huggingface_hub >= 0.22.0"]
Expand Down
72 changes: 43 additions & 29 deletions src/distilabel/steps/argilla/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib.util
import os
import warnings
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, List, Optional

Expand All @@ -28,15 +28,16 @@
from distilabel.steps.base import Step, StepInput

if TYPE_CHECKING:
from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset
from argilla import Argilla, Dataset

from distilabel.steps.typing import StepOutput


_ARGILLA_API_URL_ENV_VAR_NAME = "ARGILLA_API_URL"
_ARGILLA_API_KEY_ENV_VAR_NAME = "ARGILLA_API_KEY"


class Argilla(Step, ABC):
class ArgillaBase(Step, ABC):
"""Abstract step that provides a class to subclass from, that contains the boilerplate code
required to interact with Argilla, as well as some extra validations on top of it. It also defines
the abstract methods that need to be implemented in order to add a new dataset type as a step.
Expand Down Expand Up @@ -70,55 +71,61 @@ class Argilla(Step, ABC):
)
dataset_workspace: Optional[RuntimeParameter[str]] = Field(
default=None,
description="The workspace where the dataset will be created in Argilla. Defaults"
description="The workspace where the dataset will be created in Argilla. Defaults "
"to `None` which means it will be created in the default workspace.",
)

api_url: Optional[RuntimeParameter[str]] = Field(
default_factory=lambda: os.getenv("ARGILLA_API_URL"),
default_factory=lambda: os.getenv(_ARGILLA_API_URL_ENV_VAR_NAME),
description="The base URL to use for the Argilla API requests.",
)
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
default_factory=lambda: os.getenv(_ARGILLA_API_KEY_ENV_VAR_NAME),
description="The API key to authenticate the requests to the Argilla API.",
)

_rg_dataset: Optional["RemoteFeedbackDataset"] = PrivateAttr(...)
_client: Optional["Argilla"] = PrivateAttr(...)
_dataset: Optional["Dataset"] = PrivateAttr(...)

def model_post_init(self, __context: Any) -> None:
"""Checks that the Argilla Python SDK is installed, and then filters the Argilla warnings."""
super().model_post_init(__context)

try:
import argilla as rg # noqa
except ImportError as ie:
if importlib.util.find_spec("argilla") is None:
raise ImportError(
"Argilla is not installed. Please install it using `pip install argilla`."
) from ie

warnings.filterwarnings("ignore")
"Argilla is not installed. Please install it using `pip install argilla"
" --upgrade`."
)

def _rg_init(self) -> None:
def _client_init(self) -> None:
"""Initializes the Argilla API client with the provided `api_url` and `api_key`."""
try:
if "hf.space" in self.api_url and "HF_TOKEN" in os.environ:
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
else:
headers = None
rg.init(
self._client = rg.Argilla( # type: ignore
api_url=self.api_url,
api_key=self.api_key.get_secret_value(),
extra_headers=headers,
) # type: ignore
api_key=self.api_key.get_secret_value(), # type: ignore
headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
if isinstance(self.api_url, str)
and "hf.space" in self.api_url
and "HF_TOKEN" in os.environ
else {},
)
except Exception as e:
raise ValueError(f"Failed to initialize the Argilla API: {e}") from e

def _rg_dataset_exists(self) -> bool:
"""Checks if the dataset already exists in Argilla."""
return self.dataset_name in [
dataset.name
for dataset in rg.FeedbackDataset.list(workspace=self.dataset_workspace) # type: ignore
]
@property
def _dataset_exists_in_workspace(self) -> bool:
"""Checks if the dataset already exists in Argilla in the provided workspace if any.
Returns:
`True` if the dataset exists, `False` otherwise.
"""
return (
self._client.datasets( # type: ignore
name=self.dataset_name, # type: ignore
workspace=self.dataset_workspace,
)
is not None
)

@property
def outputs(self) -> List[str]:
Expand All @@ -133,7 +140,14 @@ def load(self) -> None:
"""
super().load()

self._rg_init()
if self.api_url is None or self.api_key is None:
raise ValueError(
"`Argilla` step requires the `api_url` and `api_key` to be provided. Please,"
" provide those at step instantiation, via environment variables `ARGILLA_API_URL`"
" and `ARGILLA_API_KEY`, or as `Step` runtime parameters via `pipeline.run(parameters={...})`."
)

self._client_init()

@property
@abstractmethod
Expand Down
93 changes: 55 additions & 38 deletions src/distilabel/steps/argilla/preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,16 @@
except ImportError:
pass

from distilabel.steps.argilla.base import Argilla
from distilabel.steps.argilla.base import ArgillaBase
from distilabel.steps.base import StepInput

if TYPE_CHECKING:
from argilla import (
RatingQuestion,
SuggestionSchema,
TextField,
TextQuestion,
)
from argilla import RatingQuestion, Suggestion, TextField, TextQuestion

from distilabel.steps.typing import StepOutput


class PreferenceToArgilla(Argilla):
class PreferenceToArgilla(ArgillaBase):
"""Creates a preference dataset in Argilla.
Step that creates a dataset in Argilla during the load phase, and then pushes the input
Expand Down Expand Up @@ -153,45 +148,55 @@ def load(self) -> None:
self._ratings = self.input_mappings.get("ratings", "ratings")
self._rationales = self.input_mappings.get("rationales", "rationales")

if self._rg_dataset_exists():
_rg_dataset = rg.FeedbackDataset.from_argilla( # type: ignore
name=self.dataset_name,
workspace=self.dataset_workspace,
if self._dataset_exists_in_workspace:
_dataset = self._client.datasets( # type: ignore
name=self.dataset_name, # type: ignore
workspace=self.dataset_workspace, # type: ignore
)

for field in _rg_dataset.fields:
for field in _dataset.fields:
if not isinstance(field, rg.TextField):
continue
if (
field.name
not in [self._id, self._instruction]
not in [self._id, self._instruction] # type: ignore
+ [
f"{self._generations}-{idx}"
for idx in range(self.num_generations)
]
and field.required
):
raise ValueError(
f"The dataset {self.dataset_name} in the workspace {self.dataset_workspace} already exists,"
f" but contains at least a required field that is neither `{self._id}`, `{self._instruction}`,"
f" nor `{self._generations}`."
f"The dataset '{self.dataset_name}' in the workspace '{self.dataset_workspace}'"
f" already exists, but contains at least a required field that is"
f" neither `{self._id}`, `{self._instruction}`, nor `{self._generations}`"
f" (one per generation starting from 0 up to {self.num_generations - 1})."
)

self._rg_dataset = _rg_dataset
self._dataset = _dataset
else:
_rg_dataset = rg.FeedbackDataset( # type: ignore
_settings = rg.Settings( # type: ignore
fields=[
rg.TextField(name=self._id, title=self._id), # type: ignore
rg.TextField(name=self._instruction, title=self._instruction), # type: ignore
*self._generation_fields(), # type: ignore
],
questions=self._rating_rationale_pairs(), # type: ignore
)
self._rg_dataset = _rg_dataset.push_to_argilla(
name=self.dataset_name, # type: ignore
_dataset = rg.Dataset( # type: ignore
name=self.dataset_name,
workspace=self.dataset_workspace,
settings=_settings,
client=self._client,
)
self._dataset = _dataset.create()

def _generation_fields(self) -> List["TextField"]:
"""Method to generate the fields for each of the generations."""
"""Method to generate the fields for each of the generations.
Returns:
A list containing `TextField`s for each text generation.
"""
return [
rg.TextField( # type: ignore
name=f"{self._generations}-{idx}",
Expand All @@ -204,7 +209,12 @@ def _generation_fields(self) -> List["TextField"]:
def _rating_rationale_pairs(
self,
) -> List[Union["RatingQuestion", "TextQuestion"]]:
"""Method to generate the rating and rationale questions for each of the generations."""
"""Method to generate the rating and rationale questions for each of the generations.
Returns:
A list of questions containing a `RatingQuestion` and `TextQuestion` pair for
each text generation.
"""
questions = []
for idx in range(self.num_generations):
questions.extend(
Expand Down Expand Up @@ -236,20 +246,27 @@ def inputs(self) -> List[str]:
provide the `ratings` and the `rationales` for the generations."""
return ["instruction", "generations"]

def _add_suggestions_if_any(
self, input: Dict[str, Any]
) -> List["SuggestionSchema"]:
"""Method to generate the suggestions for the `FeedbackRecord` based on the input."""
@property
def optional_inputs(self) -> List[str]:
"""The optional inputs for the step are the `ratings` and the `rationales` for the generations."""
return ["ratings", "rationales"]

def _add_suggestions_if_any(self, input: Dict[str, Any]) -> List["Suggestion"]:
"""Method to generate the suggestions for the `rg.Record` based on the input.
Returns:
A list of `Suggestion`s for the rating and rationales questions.
"""
# Since the `suggestions` i.e. answers to the `questions` are optional, will default to {}
suggestions = []
# If `ratings` is in `input`, then add those as suggestions
if self._ratings in input:
suggestions.extend(
[
{
"question_name": f"{self._generations}-{idx}-rating",
"value": rating,
}
rg.Suggestion( # type: ignore
value=rating,
question_name=f"{self._generations}-{idx}-rating",
)
for idx, rating in enumerate(input[self._ratings])
if rating is not None
and isinstance(rating, int)
Expand All @@ -260,10 +277,10 @@ def _add_suggestions_if_any(
if self._rationales in input:
suggestions.extend(
[
{
"question_name": f"{self._generations}-{idx}-rationale",
"value": rationale,
}
rg.Suggestion( # type: ignore
value=rationale,
question_name=f"{self._generations}-{idx}-rationale",
)
for idx, rationale in enumerate(input[self._rationales])
if rationale is not None and isinstance(rationale, str)
],
Expand All @@ -272,7 +289,7 @@ def _add_suggestions_if_any(

@override
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
"""Creates and pushes the records as FeedbackRecords to the Argilla dataset.
"""Creates and pushes the records as `rg.Record`s to the Argilla dataset.
Args:
inputs: A list of Python dictionaries with the inputs of the task.
Expand All @@ -293,7 +310,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
}

records.append( # type: ignore
rg.FeedbackRecord( # type: ignore
rg.Record( # type: ignore
fields={
"id": instruction_id,
"instruction": input["instruction"], # type: ignore
Expand All @@ -302,5 +319,5 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
suggestions=self._add_suggestions_if_any(input), # type: ignore
)
)
self._rg_dataset.add_records(records) # type: ignore
self._dataset.records.log(records) # type: ignore
yield inputs
Loading

0 comments on commit 18dc02c

Please sign in to comment.