Skip to content

Commit

Permalink
Update tests with new keyword for structured generation
Browse files Browse the repository at this point in the history
  • Loading branch information
plaguss committed Jun 6, 2024
1 parent 7639f2f commit 1fbe52a
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
8 changes: 3 additions & 5 deletions src/distilabel/steps/tasks/structured_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,11 @@
# limitations under the License.

import warnings
from typing import Any, Dict, Final, List, Union
from typing import Any, Dict, List, Union

from distilabel.steps.tasks.base import Task
from distilabel.steps.tasks.typing import StructuredInput

STRUCTURED_OUTPUT_COLUMN: Final[str] = "structured_output"


class StructuredGeneration(Task):
"""Generate structured content for a given `instruction` using an `LLM`.
Expand Down Expand Up @@ -63,7 +61,7 @@ def inputs(self) -> List[str]:
"""The inputs for the task are the `instruction` and the `structured_output`.
Optionally, if the `use_system_prompt` flag is set to True, then the
`system_prompt` will be used too."""
columns = ["instruction", STRUCTURED_OUTPUT_COLUMN]
columns = ["instruction", "structured_output"]
if self.use_system_prompt:
columns = ["system_prompt"] + columns
return columns
Expand All @@ -89,7 +87,7 @@ def format_input(self, input: Dict[str, Any]) -> StructuredInput:
stacklevel=2,
)

return (messages, input.get(STRUCTURED_OUTPUT_COLUMN, None)) # type: ignore
return (messages, input.get("structured_output", None)) # type: ignore

@property
def outputs(self) -> List[str]:
Expand Down
16 changes: 8 additions & 8 deletions tests/unit/steps/tasks/test_structured_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,11 @@ def test_format_input(self) -> None:
{
"instruction": "test",
"system_prompt": "test",
"grammar": {"type": "regex", "value": r"[a-zA-Z]+"},
"structured_output": {"format": "regex", "schema": r"[a-zA-Z]+"},
}
) == (
[{"role": "user", "content": "test"}],
{"type": "regex", "value": r"[a-zA-Z]+"},
{"format": "regex", "schema": r"[a-zA-Z]+"},
)

# 2. Not including the `structured_output` field within the input
Expand Down Expand Up @@ -92,9 +92,9 @@ def test_process(self) -> None:
[
{
"instruction": "test",
"grammar": {
"type": "json",
"value": {
"structured_output": {
"format": "json",
"schema": {
"properties": {
"test": {"title": "Test", "type": "string"}
},
Expand All @@ -109,9 +109,9 @@ def test_process(self) -> None:
) == [
{
"instruction": "test",
"grammar": {
"type": "json",
"value": {
"structured_output": {
"format": "json",
"schema": {
"properties": {"test": {"title": "Test", "type": "string"}},
"required": ["test"],
"title": "Test",
Expand Down

0 comments on commit 1fbe52a

Please sign in to comment.