Skip to content

Commit

Permalink
tests: validate passing questions and field within format_input too (#…
Browse files Browse the repository at this point in the history
…1017)

Co-authored-by: Gabriel Martín Blázquez <[email protected]>
  • Loading branch information
davidberenstein1957 and gabrielmbmb authored Oct 7, 2024
1 parent 4848dd2 commit d5c0484
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/distilabel/steps/tasks/argilla_labeller.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,10 +382,13 @@ def format_input(
"""Format the input into a chat message.
Args:
input (Dict[str, Union[Dict[str, Any], Record, TextField, MultiLabelQuestion, LabelQuestion, RatingQuestion, TextQuestion]]): The input to format.
input: The input to format.
Returns:
ChatType: The formatted chat message.
The formatted chat message.
Raises:
ValueError: If question or fields are not provided.
"""
input_keys = list(self.inputs.keys())
record = input[input_keys[0]]
Expand All @@ -394,6 +397,11 @@ def format_input(
examples = input.get(input_keys[3], self.example_records)
guidelines = input.get(input_keys[4], self.guidelines)

if question is None:
raise ValueError("Question must be provided.")
if fields is None or any(field is None for field in fields):
raise ValueError("Fields must be provided.")

record = record.to_dict() if not isinstance(record, dict) else record
question = question.serialize() if not isinstance(question, dict) else question
fields = [
Expand All @@ -416,6 +424,7 @@ def format_input(
if examples
else False
)

prompt = self._template.render(
fields=formatted_fields,
question=formatted_question,
Expand Down

0 comments on commit d5c0484

Please sign in to comment.