Complexity scorer default structured output (#870)
* Add default structured output for GenerateSentencePair task

* Move default behavior to base class

* Add docstrings to the methods and move json schemas to the class method

* Add tests for default structured outputs in sentence transformers task

* Add control for parsing errors on JSON data

* Add default structured output for ComplexityScorer task

* Refactor code per code review, to simplify just creating the default schemas

* Add extra check to avoid setting the structured output if the method wasn't overridden

* Refactor get_structured_output to return just the schema

* Add reference for the JSON schema
plaguss authored Aug 9, 2024
1 parent 5e5e7c3 commit 7702e24
Showing 2 changed files with 97 additions and 5 deletions.
78 changes: 78 additions & 0 deletions src/distilabel/steps/tasks/complexity_scorer.py
@@ -22,8 +22,10 @@

from typing import TYPE_CHECKING, Any, Dict, List, Union

+import orjson
from jinja2 import Template
from pydantic import PrivateAttr
+from typing_extensions import override

from distilabel.steps.tasks.base import Task

@@ -86,6 +88,31 @@ class ComplexityScorer(Task):
        # [{'instructions': ['plain instruction', 'highly complex instruction'], 'model_name': 'test', 'scores': [1, 5], 'distilabel_metadata': {'raw_output_complexity_scorer_0': 'output'}}]
        ```
+        Generate structured output with default schema:
+
+        ```python
+        from distilabel.steps.tasks import ComplexityScorer
+        from distilabel.llms.huggingface import InferenceEndpointsLLM
+
+        # Consider this as a placeholder for your actual LLM.
+        scorer = ComplexityScorer(
+            llm=InferenceEndpointsLLM(
+                model_id="mistralai/Mistral-7B-Instruct-v0.2",
+            ),
+            use_default_structured_output=True,
+        )
+
+        scorer.load()
+
+        result = next(
+            scorer.process(
+                [{"instructions": ["plain instruction", "highly complex instruction"]}]
+            )
+        )
+        # result
+        # [{'instructions': ['plain instruction', 'highly complex instruction'], 'model_name': 'test', 'scores': [1, 2], 'distilabel_metadata': {'raw_output_complexity_scorer_0': '{ \n "scores": [\n 1, \n 2\n ]\n}'}}]
+        ```
Citations:
```
@@ -153,6 +180,9 @@ def format_output(
        if output is None:
            return {"scores": [None] * len(input["instructions"])}

+       if self.use_default_structured_output:
+           return self._format_structured_output(output, input)
+
        scores = []
        score_lines = output.split("\n")
        for i, line in enumerate(score_lines):
@@ -162,3 +192,51 @@
            if i == len(input["instructions"]) - 1:
                break
        return {"scores": scores}

+    @override
+    def get_structured_output(self) -> Dict[str, Any]:
+        """Creates the JSON schema to be passed to the LLM, to enforce the
+        response as a dictionary that can be parsed directly as a Python
+        dictionary. The schema corresponds to the following:
+
+        ```python
+        from pydantic import BaseModel
+        from typing import List
+
+        class SchemaComplexityScorer(BaseModel):
+            scores: List[int]
+        ```
+
+        Returns:
+            The JSON schema of the response to enforce.
+        """
+        return {
+            "properties": {
+                "scores": {
+                    "items": {"type": "integer"},
+                    "title": "Scores",
+                    "type": "array",
+                }
+            },
+            "required": ["scores"],
+            "title": "SchemaComplexityScorer",
+            "type": "object",
+        }

+    def _format_structured_output(
+        self, output: str, input: Dict[str, Any]
+    ) -> Dict[str, Any]:
+        """Parses the structured response, which should correspond to a dictionary
+        with a `scores` key containing the list of scores.
+
+        Args:
+            output: The output from the `LLM`.
+            input: The input passed to the task, used to obtain the number of
+                instructions when the output cannot be parsed.
+
+        Returns:
+            Formatted output.
+        """
+        try:
+            return orjson.loads(output)
+        except orjson.JSONDecodeError:
+            return {"scores": [None] * len(input["instructions"])}
24 changes: 19 additions & 5 deletions tests/unit/steps/tasks/test_complexity_scorer.py
@@ -42,29 +42,43 @@ def test_format_input(self) -> None:
        ]

    @pytest.mark.parametrize(
-       "output, expected",
+       "output, use_default_structured_output, expected",
        [
            (
                "[1] Score: 1\n[2] Score: 2\n[3] Score: 3\n",
+               False,
                {"scores": [1.0, 2.0, 3.0]},
            ),
            (
                "[1] Score: 1\n[2] Score: 2\n[3] Score: 3\njfjfjfjjfjfjf this is noise from the llm\nlallalalala more noise\nand more noise",
+               False,
                {"scores": [1.0, 2.0, 3.0]},
            ),
            (
                None,
+               False,
                {"scores": [None, None, None]},
            ),
+           (
+               '{"scores":[1,2,3]}',
+               True,
+               {"scores": [1.0, 2.0, 3.0]},
+           ),
+           (
+               "wrong",
+               True,
+               {"scores": [None, None, None]},
+           ),
        ],
    )
)
    def test_format_output(
-       self, output: Union[str, None], expected: Dict[str, Any]
+       self,
+       output: Union[str, None],
+       use_default_structured_output: bool,
+       expected: Dict[str, Any],
    ) -> None:
        task = ComplexityScorer(
            name="complexity_scorer",
-           llm=DummyLLM(),
-           pipeline=Pipeline(name="unit-test-pipeline"),
+           llm=DummyLLM(), use_default_structured_output=use_default_structured_output
        )
        task.load()

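The remainder of the test is collapsed in the diff; it presumably formats the output and compares it against `expected`, along the lines of this hypothetical continuation (names and input assumed, not taken from the commit):

```python
# Hypothetical sketch of the hidden assertion: three instructions to match
# the three scores in each parametrized case above.
result = task.format_output(
    output=output, input={"instructions": ["instruction"] * 3}
)
assert result == expected
```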
