Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update argilla integration to use argilla_sdk v2 #705

Merged
merged 22 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
9662146
Update `_Argilla` base and `TextGenerationToArgilla`
alvarobartt Jun 6, 2024
ef189f8
Fix `_dataset.records.log` and rename to `ArgillaBase`
alvarobartt Jun 6, 2024
c83de4c
Update `TextGenerationToArgilla` subclass inheritance
alvarobartt Jun 6, 2024
2353768
Remove unused `logger.info` message
alvarobartt Jun 6, 2024
055a9de
Update `PreferenceToArgilla`
alvarobartt Jun 6, 2024
18761fb
Update `argilla` extra to install `argilla_sdk`
alvarobartt Jun 6, 2024
7d0f07d
Add `ArgillaBase` and subclasses unit tests
alvarobartt Jun 7, 2024
a97e310
Merge branch 'develop' into argilla-2.0
alvarobartt Jun 7, 2024
d77dd11
Install `argilla_sdk` from source and add `ipython`
alvarobartt Jun 10, 2024
d6f7131
Merge branch 'develop' into argilla-2.0
alvarobartt Jun 12, 2024
7d55576
upgrade argilla dep to latest rc
frascuchon Jul 17, 2024
78ca5f7
update code with latest changes
frascuchon Jul 17, 2024
c9fc2a5
chore: remove unnecessary workspace definition
frascuchon Jul 17, 2024
06a3610
fix: wrong argilla module import
frascuchon Jul 17, 2024
58a2e8c
Merge branch 'develop' into argilla-2.0
gabrielmbmb Jul 30, 2024
5c1ce95
Update docstrings
gabrielmbmb Jul 30, 2024
1e16e38
Fix lint
gabrielmbmb Jul 30, 2024
20b92ab
Add check for `api_url` and `api_key`
gabrielmbmb Jul 30, 2024
ba13431
Fix unit tests
gabrielmbmb Jul 30, 2024
d088510
Fix unit tests
gabrielmbmb Jul 30, 2024
6e13f3a
Merge branch 'argilla-2.0' of https://github.com/argilla-io/rlxf into…
gabrielmbmb Jul 30, 2024
b0a6b71
Update argilla dependency version
gabrielmbmb Jul 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 41 additions & 23 deletions src/distilabel/steps/argilla/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,24 @@
from pydantic import Field, PrivateAttr, SecretStr

try:
import argilla as rg
import argilla_sdk as rg
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
except ImportError:
pass

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.steps.base import Step, StepInput

if TYPE_CHECKING:
from argilla.client.feedback.dataset.remote.dataset import RemoteFeedbackDataset
from argilla_sdk import Argilla, Dataset

from distilabel.steps.typing import StepOutput


_ARGILLA_API_URL_ENV_VAR_NAME = "ARGILLA_API_URL"
_ARGILLA_API_KEY_ENV_VAR_NAME = "ARGILLA_API_KEY"


class Argilla(Step, ABC):
class _Argilla(Step, ABC):
alvarobartt marked this conversation as resolved.
Show resolved Hide resolved
"""Abstract step that provides a class to subclass from, that contains the boilerplate code
required to interact with Argilla, as well as some extra validations on top of it. It also defines
the abstract methods that need to be implemented in order to add a new dataset type as a step.
Expand Down Expand Up @@ -75,50 +76,67 @@ class Argilla(Step, ABC):
)

api_url: Optional[RuntimeParameter[str]] = Field(
default_factory=lambda: os.getenv("ARGILLA_API_URL"),
default_factory=lambda: os.getenv(_ARGILLA_API_URL_ENV_VAR_NAME),
description="The base URL to use for the Argilla API requests.",
)
api_key: Optional[RuntimeParameter[SecretStr]] = Field(
default_factory=lambda: os.getenv(_ARGILLA_API_KEY_ENV_VAR_NAME),
description="The API key to authenticate the requests to the Argilla API.",
)

_rg_dataset: Optional["RemoteFeedbackDataset"] = PrivateAttr(...)
_client: Optional["Argilla"] = PrivateAttr(...)
_dataset: Optional["Dataset"] = PrivateAttr(...)

def model_post_init(self, __context: Any) -> None:
"""Checks that the Argilla Python SDK is installed, and then filters the Argilla warnings."""
super().model_post_init(__context)

try:
import argilla as rg # noqa
import argilla_sdk as rg # noqa
except ImportError as ie:
raise ImportError(
"Argilla is not installed. Please install it using `pip install argilla`."
"Argilla is not installed. Please install it using `pip install argilla_sdk --upgrade`."
) from ie

warnings.filterwarnings("ignore")

def _rg_init(self) -> None:
def _client_init(self) -> None:
"""Initializes the Argilla API client with the provided `api_url` and `api_key`."""
try:
if "hf.space" in self.api_url and "HF_TOKEN" in os.environ:
headers = {"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
else:
headers = None
rg.init(
self._client = rg.Argilla( # type: ignore
api_url=self.api_url,
api_key=self.api_key.get_secret_value(),
extra_headers=headers,
) # type: ignore
api_key=self.api_key.get_secret_value(), # type: ignore
headers={"Authorization": f"Bearer {os.environ['HF_TOKEN']}"}
if isinstance(self.api_url, str)
and "hf.space" in self.api_url
and "HF_TOKEN" in os.environ
else {},
)
except Exception as e:
raise ValueError(f"Failed to initialize the Argilla API: {e}") from e

def _rg_dataset_exists(self) -> bool:
"""Checks if the dataset already exists in Argilla."""
return self.dataset_name in [
dataset.name
for dataset in rg.FeedbackDataset.list(workspace=self.dataset_workspace) # type: ignore
]
@property
def _dataset_exists_in_workspace(self) -> bool:
"""Checks if the dataset already exists in Argilla in the provided workspace if any."""
return (
True
if self._client.datasets( # type: ignore
name=self.dataset_name, # type: ignore
workspace=self._client.workspaces(name=self.dataset_workspace) # type: ignore
if self.dataset_workspace is not None
else None,
).exists()
is not None
else False
)
alvarobartt marked this conversation as resolved.
Show resolved Hide resolved
# Alternative way but the above should work
# return self.dataset_name in (
# [
# dataset.name for dataset in self._client.workspaces(self.dataset_workspace) # type: ignore
# ]
# if self.dataset_workspace is not None
# else [dataset.name for dataset in self._client.datasets.list()] # type: ignore
# )

@property
def outputs(self) -> List[str]:
Expand All @@ -133,7 +151,7 @@ def load(self) -> None:
"""
super().load()

self._rg_init()
self._client_init()

@property
@abstractmethod
Expand Down
42 changes: 25 additions & 17 deletions src/distilabel/steps/argilla/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,18 @@
from typing_extensions import override

try:
import argilla as rg
import argilla_sdk as rg
except ImportError:
pass

from distilabel.steps.argilla.base import Argilla
from distilabel.steps.argilla.base import _Argilla
from distilabel.steps.base import StepInput

if TYPE_CHECKING:
from distilabel.steps.typing import StepOutput


class TextGenerationToArgilla(Argilla):
class TextGenerationToArgilla(_Argilla):
"""Creates a text generation dataset in Argilla.

`Step` that creates a dataset in Argilla during the load phase, and then pushes the input
Expand Down Expand Up @@ -74,26 +74,30 @@ def load(self) -> None:
self._instruction = self.input_mappings.get("instruction", "instruction")
self._generation = self.input_mappings.get("generation", "generation")

if self._rg_dataset_exists():
_rg_dataset = rg.FeedbackDataset.from_argilla( # type: ignore
name=self.dataset_name,
workspace=self.dataset_workspace,
if self._dataset_exists_in_workspace:
_dataset = self._client.datasets( # type: ignore
name=self.dataset_name, # type: ignore
workspace=self._client.workspaces(name=self.dataset_workspace) # type: ignore
if self.dataset_workspace is not None
else None,
)

for field in _rg_dataset.fields:
for field in _dataset.fields:
if not isinstance(field, rg.TextField): # type: ignore
continue
if (
field.name not in [self._id, self._instruction, self._generation]
and field.required
):
raise ValueError(
f"The dataset {self.dataset_name} in the workspace {self.dataset_workspace} already exists,"
f" but contains at least a required field that is neither `{self._id}`, `{self._instruction}`"
f", nor `{self._generation}`."
f", nor `{self._generation}`, so it cannot be reused for this dataset."
)

self._rg_dataset = _rg_dataset
self._dataset = _dataset
else:
_rg_dataset = rg.FeedbackDataset( # type: ignore
_settings = rg.Settings( # type: ignore
fields=[
rg.TextField(name=self._id, title=self._id), # type: ignore
rg.TextField(name=self._instruction, title=self._instruction), # type: ignore
Expand All @@ -103,14 +107,18 @@ def load(self) -> None:
rg.LabelQuestion( # type: ignore
name="quality",
title=f"What's the quality of the {self._generation} for the given {self._instruction}?",
labels={"bad": "👎", "good": "👍"},
labels={"bad": "👎", "good": "👍"}, # type: ignore
)
],
)
self._rg_dataset = _rg_dataset.push_to_argilla(
name=self.dataset_name, # type: ignore
_dataset = rg.Dataset( # type: ignore
name=self.dataset_name,
workspace=self.dataset_workspace,
settings=_settings,
client=self._client,
)
self._logger.info(f"Creating the dataset {self.dataset_name} in Argilla.")
self._dataset = _dataset.create()

@property
def inputs(self) -> List[str]:
Expand Down Expand Up @@ -151,13 +159,13 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
generations_set.add(generation)

records.append(
rg.FeedbackRecord( # type: ignore
rg.Record( # type: ignore
fields={
self._id: instruction_id,
self._instruction: input["instruction"],
self._generation: generation,
},
)
),
)
self._rg_dataset.add_records(records) # type: ignore
self._dataset.log(records) # type: ignore
alvarobartt marked this conversation as resolved.
Show resolved Hide resolved
yield inputs
Loading