From d5c048459ada32ba7b3356f3f79c39fa0e61f706 Mon Sep 17 00:00:00 2001 From: David Berenstein Date: Mon, 7 Oct 2024 13:00:18 +0200 Subject: [PATCH] tests: validate passing questions and field within format_input too (#1017) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Gabriel Martín Blázquez --- src/distilabel/steps/tasks/argilla_labeller.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/distilabel/steps/tasks/argilla_labeller.py b/src/distilabel/steps/tasks/argilla_labeller.py index dd4522813..d0874ed3d 100644 --- a/src/distilabel/steps/tasks/argilla_labeller.py +++ b/src/distilabel/steps/tasks/argilla_labeller.py @@ -382,10 +382,13 @@ def format_input( """Format the input into a chat message. Args: - input (Dict[str, Union[Dict[str, Any], Record, TextField, MultiLabelQuestion, LabelQuestion, RatingQuestion, TextQuestion]]): The input to format. + input: The input to format. Returns: - ChatType: The formatted chat message. + The formatted chat message. + + Raises: + ValueError: If question or fields are not provided. """ input_keys = list(self.inputs.keys()) record = input[input_keys[0]] @@ -394,6 +397,11 @@ def format_input( examples = input.get(input_keys[3], self.example_records) guidelines = input.get(input_keys[4], self.guidelines) + if question is None: + raise ValueError("Question must be provided.") + if fields is None or any(field is None for field in fields): + raise ValueError("Fields must be provided.") + record = record.to_dict() if not isinstance(record, dict) else record question = question.serialize() if not isinstance(question, dict) else question fields = [ @@ -416,6 +424,7 @@ def format_input( if examples else False ) + prompt = self._template.render( fields=formatted_fields, question=formatted_question,