Skip to content

Commit

Permalink
Fix StructuredGeneration examples and internal check (#912)
Browse files Browse the repository at this point in the history
* Fix error with instructor schema input

* Fix examples of structured generation

* Try inferring the format type in case the user forgets to specify it
  • Loading branch information
plaguss authored Aug 22, 2024
1 parent 46d55ed commit 6576d1a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
5 changes: 4 additions & 1 deletion src/distilabel/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,10 @@ def _prepare_kwargs(
# We can deal with json schema or BaseModel, but we need to convert it to a BaseModel
# for the Instructor client.
schema = structured_output.get("schema", {})
if not issubclass(schema, BaseModel):

# If there's already a pydantic model, we don't need to do anything,
# otherwise, try to obtain one.
if not (inspect.isclass(schema) and issubclass(schema, BaseModel)):
from distilabel.steps.tasks.structured_outputs.utils import (
json_schema_to_model,
)
Expand Down
10 changes: 5 additions & 5 deletions src/distilabel/steps/tasks/structured_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ class StructuredGeneration(Task):
{
"instruction": "Create an RPG character",
"structured_output": {
"type": "json",
"value": {
"format": "json",
"schema": {
"properties": {
"name": {
"title": "Name",
Expand Down Expand Up @@ -105,7 +105,7 @@ class StructuredGeneration(Task):
)
```
Generate structured output from a regex pattern:
Generate structured output from a regex pattern (only works with LLMs that support regex, i.e. the providers using outlines):
```python
from distilabel.steps.tasks import StructuredGeneration
Expand All @@ -126,8 +126,8 @@ class StructuredGeneration(Task):
{
"instruction": "What's the weather like today in Seattle in Celsius degrees?",
"structured_output": {
"type": "regex",
"value": r"(\\d{1,2})°C"
"format": "regex",
"schema": r"(\\d{1,2})°C"
},
}
Expand Down
8 changes: 8 additions & 0 deletions src/distilabel/steps/tasks/structured_outputs/outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import importlib
import importlib.util
import inspect
import json
from typing import (
Any,
Expand Down Expand Up @@ -102,6 +103,13 @@ def prepare_guided_output(
format = structured_output.get("format")
schema = structured_output.get("schema")

# If the format was not provided (it may have been forgotten), try inferring it from the schema's type
if not format:
if isinstance(schema, dict) or inspect.isclass(schema):
format = "json"
elif isinstance(schema, str):
format = "regex"

if format == "json":
return {
"processor": json_processor(
Expand Down

0 comments on commit 6576d1a

Please sign in to comment.