Skip to content

Commit

Permalink
Update PreferenceToArgilla
Browse files Browse the repository at this point in the history
  • Loading branch information
alvarobartt committed Jun 6, 2024
1 parent 2353768 commit 055a9de
Showing 1 changed file with 40 additions and 35 deletions.
75 changes: 40 additions & 35 deletions src/distilabel/steps/argilla/preference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,20 @@
from typing_extensions import override

try:
import argilla as rg
import argilla_sdk as rg
except ImportError:
pass

from distilabel.steps.argilla.base import Argilla
from distilabel.steps.argilla.base import ArgillaBase
from distilabel.steps.base import StepInput

if TYPE_CHECKING:
from argilla import (
RatingQuestion,
SuggestionSchema,
TextField,
TextQuestion,
)
from argilla_sdk import RatingQuestion, Suggestion, TextField, TextQuestion

from distilabel.steps.typing import StepOutput


class PreferenceToArgilla(Argilla):
class PreferenceToArgilla(ArgillaBase):
"""Creates a preference dataset in Argilla.
Step that creates a dataset in Argilla during the load phase, and then pushes the input
Expand Down Expand Up @@ -97,16 +92,20 @@ def load(self) -> None:
self._ratings = self.input_mappings.get("ratings", "ratings")
self._rationales = self.input_mappings.get("rationales", "rationales")

if self._rg_dataset_exists():
_rg_dataset = rg.FeedbackDataset.from_argilla( # type: ignore
name=self.dataset_name,
workspace=self.dataset_workspace,
if self._dataset_exists_in_workspace:
_dataset = self._client.datasets( # type: ignore
name=self.dataset_name, # type: ignore
workspace=self._client.workspaces(name=self.dataset_workspace) # type: ignore
if self.dataset_workspace is not None
else None,
)

for field in _rg_dataset.fields:
for field in _dataset.fields:
if not isinstance(field, rg.TextField):
continue
if (
field.name
not in [self._id, self._instruction]
not in [self._id, self._instruction] # type: ignore
+ [
f"{self._generations}-{idx}"
for idx in range(self.num_generations)
Expand All @@ -116,23 +115,26 @@ def load(self) -> None:
raise ValueError(
f"The dataset {self.dataset_name} in the workspace {self.dataset_workspace} already exists,"
f" but contains at least a required field that is neither `{self._id}`, `{self._instruction}`,"
f" nor `{self._generations}`."
f" nor `{self._generations}` (one per generation starting from 0 up to {self.num_generations - 1})."
)

self._rg_dataset = _rg_dataset
self._dataset = _dataset
else:
_rg_dataset = rg.FeedbackDataset( # type: ignore
_settings = rg.Settings( # type: ignore
fields=[
rg.TextField(name=self._id, title=self._id), # type: ignore
rg.TextField(name=self._instruction, title=self._instruction), # type: ignore
*self._generation_fields(), # type: ignore
],
questions=self._rating_rationale_pairs(), # type: ignore
)
self._rg_dataset = _rg_dataset.push_to_argilla(
name=self.dataset_name, # type: ignore
_dataset = rg.Dataset( # type: ignore
name=self.dataset_name,
workspace=self.dataset_workspace,
settings=_settings,
client=self._client,
)
self._dataset = _dataset.create()

def _generation_fields(self) -> List["TextField"]:
"""Method to generate the fields for each of the generations."""
Expand Down Expand Up @@ -180,20 +182,23 @@ def inputs(self) -> List[str]:
provide the `ratings` and the `rationales` for the generations."""
return ["instruction", "generations"]

def _add_suggestions_if_any(
self, input: Dict[str, Any]
) -> List["SuggestionSchema"]:
"""Method to generate the suggestions for the `FeedbackRecord` based on the input."""
@property
def optional_inputs(self) -> List[str]:
    """Names of the inputs the step accepts but does not require.

    Returns:
        A list with `ratings` and `rationales`, the optional per-generation
        columns.
    """
    optional_columns = ["ratings", "rationales"]
    return optional_columns

def _add_suggestions_if_any(self, input: Dict[str, Any]) -> List["Suggestion"]:
"""Method to generate the suggestions for the `rg.Record` based on the input."""
# Since the `suggestions` (i.e. answers to the `questions`) are optional, this defaults to an empty list
suggestions = []
# If `ratings` is in `input`, then add those as suggestions
if self._ratings in input:
suggestions.extend(
[
{
"question_name": f"{self._generations}-{idx}-rating",
"value": rating,
}
rg.Suggestion( # type: ignore
value=rating,
question_name=f"{self._generations}-{idx}-rating",
)
for idx, rating in enumerate(input[self._ratings])
if rating is not None
and isinstance(rating, int)
Expand All @@ -204,10 +209,10 @@ def _add_suggestions_if_any(
if self._rationales in input:
suggestions.extend(
[
{
"question_name": f"{self._generations}-{idx}-rationale",
"value": rationale,
}
rg.Suggestion( # type: ignore
value=rationale,
question_name=f"{self._generations}-{idx}-rationale",
)
for idx, rationale in enumerate(input[self._rationales])
if rationale is not None and isinstance(rationale, str)
],
Expand All @@ -216,7 +221,7 @@ def _add_suggestions_if_any(

@override
def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
"""Creates and pushes the records as FeedbackRecords to the Argilla dataset.
"""Creates and pushes the records as `rg.Record`s to the Argilla dataset.
Args:
inputs: A list of Python dictionaries with the inputs of the task.
Expand All @@ -237,7 +242,7 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
}

records.append( # type: ignore
rg.FeedbackRecord( # type: ignore
rg.Record( # type: ignore
fields={
"id": instruction_id,
"instruction": input["instruction"], # type: ignore
Expand All @@ -246,5 +251,5 @@ def process(self, inputs: StepInput) -> "StepOutput": # type: ignore
suggestions=self._add_suggestions_if_any(input), # type: ignore
)
)
self._rg_dataset.add_records(records) # type: ignore
self._dataset.records.log(records) # type: ignore
yield inputs

0 comments on commit 055a9de

Please sign in to comment.