From 055a9dee31a61c3ac6a4a40f8f4275d066f60fad Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com> Date: Thu, 6 Jun 2024 18:01:09 +0200 Subject: [PATCH] Update `PreferenceToArgilla` --- src/distilabel/steps/argilla/preference.py | 75 ++++++++++++---------- 1 file changed, 40 insertions(+), 35 deletions(-) diff --git a/src/distilabel/steps/argilla/preference.py b/src/distilabel/steps/argilla/preference.py index 7a1a5f7a15..1bba793813 100644 --- a/src/distilabel/steps/argilla/preference.py +++ b/src/distilabel/steps/argilla/preference.py @@ -19,25 +19,20 @@ from typing_extensions import override try: - import argilla as rg + import argilla_sdk as rg except ImportError: pass -from distilabel.steps.argilla.base import Argilla +from distilabel.steps.argilla.base import ArgillaBase from distilabel.steps.base import StepInput if TYPE_CHECKING: - from argilla import ( - RatingQuestion, - SuggestionSchema, - TextField, - TextQuestion, - ) + from argilla_sdk import RatingQuestion, Suggestion, TextField, TextQuestion from distilabel.steps.typing import StepOutput -class PreferenceToArgilla(Argilla): +class PreferenceToArgilla(ArgillaBase): """Creates a preference dataset in Argilla. Step that creates a dataset in Argilla during the load phase, and then pushes the input @@ -97,16 +92,20 @@ def load(self) -> None: self._ratings = self.input_mappings.get("ratings", "ratings") self._rationales = self.input_mappings.get("rationales", "rationales") - if self._rg_dataset_exists(): - _rg_dataset = rg.FeedbackDataset.from_argilla( # type: ignore - name=self.dataset_name, - workspace=self.dataset_workspace, + if self._dataset_exists_in_workspace: + _dataset = self._client.datasets( # type: ignore + name=self.dataset_name, # type: ignore + workspace=self._client.workspaces(name=self.dataset_workspace) # type: ignore + if self.dataset_workspace is not None + else None, ) - for field in _rg_dataset.fields: + for field in _dataset.fields: + if not isinstance(field, rg.TextField): + continue if ( field.name - not in [self._id, self._instruction] + not in [self._id, self._instruction] # type: ignore + [ f"{self._generations}-{idx}" for idx in range(self.num_generations) @@ -116,12 +115,12 @@ def load(self) -> None: raise ValueError( f"The dataset {self.dataset_name} in the workspace {self.dataset_workspace} already exists," f" but contains at least a required field that is neither `{self._id}`, `{self._instruction}`," - f" nor `{self._generations}`." + f" nor `{self._generations}` (one per generation starting from 0 up to {self.num_generations - 1})." ) - self._rg_dataset = _rg_dataset + self._dataset = _dataset else: - _rg_dataset = rg.FeedbackDataset( # type: ignore + _settings = rg.Settings( # type: ignore fields=[ rg.TextField(name=self._id, title=self._id), # type: ignore rg.TextField(name=self._instruction, title=self._instruction), # type: ignore @@ -129,10 +128,13 @@ def load(self) -> None: ], questions=self._rating_rationale_pairs(), # type: ignore ) - self._rg_dataset = _rg_dataset.push_to_argilla( - name=self.dataset_name, # type: ignore + _dataset = rg.Dataset( # type: ignore + name=self.dataset_name, workspace=self.dataset_workspace, + settings=_settings, + client=self._client, ) + self._dataset = _dataset.create() def _generation_fields(self) -> List["TextField"]: """Method to generate the fields for each of the generations.""" @@ -180,20 +182,23 @@ def inputs(self) -> List[str]: provide the `ratings` and the `rationales` for the generations.""" return ["instruction", "generations"] - def _add_suggestions_if_any( - self, input: Dict[str, Any] - ) -> List["SuggestionSchema"]: - """Method to generate the suggestions for the `FeedbackRecord` based on the input.""" + @property + def optional_inputs(self) -> List[str]: + """The optional inputs for the step are the `ratings` and the `rationales` for the generations.""" + return ["ratings", "rationales"] + + def _add_suggestions_if_any(self, input: Dict[str, Any]) -> List["Suggestion"]: + """Method to generate the suggestions for the `rg.Record` based on the input.""" # Since the `suggestions` i.e. answers to the `questions` are optional, will default to {} suggestions = [] # If `ratings` is in `input`, then add those as suggestions if self._ratings in input: suggestions.extend( [ - { - "question_name": f"{self._generations}-{idx}-rating", - "value": rating, - } + rg.Suggestion( # type: ignore + value=rating, + question_name=f"{self._generations}-{idx}-rating", + ) for idx, rating in enumerate(input[self._ratings]) if rating is not None and isinstance(rating, int) @@ -204,10 +209,10 @@ def _add_suggestions_if_any( if self._rationales in input: suggestions.extend( [ - { - "question_name": f"{self._generations}-{idx}-rationale", - "value": rationale, - } + rg.Suggestion( # type: ignore + value=rationale, + question_name=f"{self._generations}-{idx}-rationale", + ) for idx, rationale in enumerate(input[self._rationales]) if rationale is not None and isinstance(rationale, str) ], @@ -216,7 +221,7 @@ def _add_suggestions_if_any( @override def process(self, inputs: StepInput) -> "StepOutput": # type: ignore - """Creates and pushes the records as FeedbackRecords to the Argilla dataset. + """Creates and pushes the records as `rg.Record`s to the Argilla dataset. Args: inputs: A list of Python dictionaries with the inputs of the task. @@ -237,7 +242,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore } records.append( # type: ignore - rg.FeedbackRecord( # type: ignore + rg.Record( # type: ignore fields={ "id": instruction_id, "instruction": input["instruction"], # type: ignore @@ -246,5 +251,5 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore suggestions=self._add_suggestions_if_any(input), # type: ignore ) ) - self._rg_dataset.add_records(records) # type: ignore + self._dataset.records.log(records) # type: ignore yield inputs