From 983576045fa6c0da4e71fc997a57b956c62f0b56 Mon Sep 17 00:00:00 2001 From: Alvaro Bartolome Date: Wed, 27 Dec 2023 14:03:10 +0100 Subject: [PATCH] Fix `UltraCMTask` scoring range and align `argilla` imports (#201) * Override `to_argilla_dataset` in `UltraCMTask` to use 1-10 scores * Align `argilla` imports across codebase --- src/distilabel/tasks/base.py | 3 +-- src/distilabel/tasks/critique/ultracm.py | 21 ++++++++++++++++++- src/distilabel/tasks/text_generation/base.py | 3 +-- .../tasks/text_generation/self_instruct.py | 3 +-- 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/src/distilabel/tasks/base.py b/src/distilabel/tasks/base.py index 02b6f4f11c..c334ab91dc 100644 --- a/src/distilabel/tasks/base.py +++ b/src/distilabel/tasks/base.py @@ -27,8 +27,7 @@ from distilabel.tasks.prompt import Prompt if TYPE_CHECKING: - from argilla import FeedbackDataset - from argilla.client.feedback.schemas.records import FeedbackRecord + from argilla import FeedbackDataset, FeedbackRecord def get_template(template_name: str) -> str: diff --git a/src/distilabel/tasks/critique/ultracm.py b/src/distilabel/tasks/critique/ultracm.py index 098d51e888..a4fd751e14 100644 --- a/src/distilabel/tasks/critique/ultracm.py +++ b/src/distilabel/tasks/critique/ultracm.py @@ -14,12 +14,15 @@ import re from dataclasses import dataclass -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional from distilabel.tasks.base import get_template from distilabel.tasks.critique.base import CritiqueTask, CritiqueTaskOutput from distilabel.tasks.prompt import Prompt +if TYPE_CHECKING: + from argilla import FeedbackDataset + _ULTRACM_TEMPLATE = get_template("ultracm.jinja2") @@ -52,3 +55,19 @@ def parse_output(self, output: str) -> CritiqueTaskOutput: # type: ignore score=float(match.group(1)), critique=match.group(2).strip(), ) + + def to_argilla_dataset( + self, + dataset_row: Dict[str, Any], + generations_column: str = "generations", + score_column: str = "score", + critique_column: str = "critique", + score_values: Optional[List[int]] = None, + ) -> "FeedbackDataset": + return super().to_argilla_dataset( + dataset_row=dataset_row, + generations_column=generations_column, + score_column=score_column, + critique_column=critique_column, + score_values=score_values or [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], + ) diff --git a/src/distilabel/tasks/text_generation/base.py b/src/distilabel/tasks/text_generation/base.py index 8eb55ffeef..3c94aa7724 100644 --- a/src/distilabel/tasks/text_generation/base.py +++ b/src/distilabel/tasks/text_generation/base.py @@ -30,8 +30,7 @@ import argilla as rg if TYPE_CHECKING: - from argilla import FeedbackDataset - from argilla.client.feedback.schemas.records import FeedbackRecord + from argilla import FeedbackDataset, FeedbackRecord @dataclass diff --git a/src/distilabel/tasks/text_generation/self_instruct.py b/src/distilabel/tasks/text_generation/self_instruct.py index e1fccf3598..7f20c191ac 100644 --- a/src/distilabel/tasks/text_generation/self_instruct.py +++ b/src/distilabel/tasks/text_generation/self_instruct.py @@ -30,8 +30,7 @@ import argilla as rg if TYPE_CHECKING: - from argilla import FeedbackDataset - from argilla.client.feedback.schemas.records import FeedbackRecord + from argilla import FeedbackDataset, FeedbackRecord _SELF_INSTRUCT_TEMPLATE = get_template("self-instruct.jinja2")