Fix default structured output (#892)
* Add a check for the dependencies required by structured outputs and change the default value of use_default_structured_output to False

* Update tests with serialized default structured output

---------

Co-authored-by: Gabriel Martín Blázquez <[email protected]>
plaguss and gabrielmbmb authored Aug 13, 2024
1 parent 8916ff2 commit 75baf64
Showing 7 changed files with 22 additions and 1 deletion.
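
In practice, the fix makes structured output opt-in: use_default_structured_output now defaults to False, so tasks no longer force structured generation (and its optional dependencies) unless asked to. A minimal usage sketch of opting back in; the TextGeneration and OpenAILLM names are only illustrative here, and not every task defines a default schema:

from distilabel.llms import OpenAILLM
from distilabel.steps.tasks import TextGeneration

# After this commit the flag defaults to False, so structured generation
# has to be requested explicitly when building the task.
task = TextGeneration(
    llm=OpenAILLM(model="gpt-4o-mini"),
    use_default_structured_output=True,  # opt in explicitly
)
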
13 changes: 12 additions & 1 deletion src/distilabel/steps/tasks/base.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Dict, List, Union

@@ -63,7 +64,7 @@ class _Task(_Step, ABC):
num_generations: RuntimeParameter[int] = Field(
default=1, description="The number of generations to be produced per input."
)
-    use_default_structured_output: bool = True
+    use_default_structured_output: bool = False

def load(self) -> None:
"""Loads the LLM via the `LLM.load()` method."""
@@ -173,14 +174,24 @@ def _set_default_structured_output(self) -> None:
from distilabel.llms import InferenceEndpointsLLM
from distilabel.llms.base import AsyncLLM

def check_dependency(module_name: str) -> None:
if not importlib.util.find_spec(module_name):
raise ImportError(
f"`{module_name}` is not installed and is needed for the structured generation with this LLM."
f" Please install it using `pip install {module_name}`."
)

dependency = "outlines"
structured_output = {"schema": schema}
# To determine instructor or outlines format
if not (
isinstance(self.llm, AsyncLLM)
and not isinstance(self.llm, InferenceEndpointsLLM)
):
dependency = "instructor"
structured_output.update({"format": "json"})

check_dependency(dependency)
self.llm.structured_output = structured_output

def get_structured_output(self) -> Union[Dict[str, Any], None]:
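
The new check_dependency helper relies on importlib.util.find_spec, which reports whether a module can be located without importing it. A standalone sketch of the same pattern, outside distilabel (the module name at the bottom is just an example):

import importlib.util


def check_dependency(module_name: str) -> None:
    # find_spec returns None when the module cannot be located, so a missing
    # optional dependency fails fast with an actionable install hint.
    if importlib.util.find_spec(module_name) is None:
        raise ImportError(
            f"`{module_name}` is not installed and is needed for structured generation."
            f" Please install it using `pip install {module_name}`."
        )


check_dependency("outlines")  # raises ImportError only if outlines is missing
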
2 changes: 2 additions & 0 deletions tests/unit/steps/tasks/evol_instruct/test_base.py
@@ -134,6 +134,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"input_batch_size": task.input_batch_size,
"llm": {
"generation_kwargs": {},
"structured_output": None,
"type_info": {
"module": task.llm.__module__,
"name": task.llm.__class__.__name__,
@@ -152,6 +153,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"INCREASED_REASONING_STEPS": "I want you act as a Prompt Rewriter.\n\nYour objective is to rewrite a given prompt into a more complex version to make those famous AI systems (e.g., chatgpt and GPT4) a bit harder to handle.\n\nBut the rewritten prompt must be reasonable and must be understood and responded by humans.\n\nYour rewriting cannot omit the non-text parts such as the table and code in #The Given Prompt#:. Also, please do not omit the input in #The Given Prompt#.\n\nYou SHOULD complicate the given prompt using the following method: \nIf #The Given Prompt# can be solved with just a few simple thinking processes, you can rewrite it to explicitly request multiple-step reasoning.\n\nYou should try your best not to make the #Rewritten Prompt# become verbose, #Rewritten Prompt# can only add 10 to 20 words into #The Given Prompt#.\n\n'#The Given Prompt#', '#Rewritten Prompt#', 'given prompt' and 'rewritten prompt' are not allowed to appear in #Rewritten Prompt#\n\n#The Given Prompt#:\n<PROMPT>\n#Rewritten Prompt#:\n\n",
"BREADTH": "I want you act as a Prompt Creator.\n\nYour goal is to draw inspiration from the #Given Prompt# to create a brand new prompt.\n\nThis new prompt should belong to the same domain as the #Given Prompt# but be even more rare.\n\nThe LENGTH and complexity of the #Created Prompt# should be similar to that of the #Given Prompt#.\n\nThe #Created Prompt# must be reasonable and must be understood and responded by humans.\n\n'#Given Prompt#', '#Created Prompt#', 'given prompt' and 'created prompt' are not allowed to appear in #Created Prompt#\n\n#Given Prompt#:\n<PROMPT>\n#Created Prompt#:\n\n",
},
"use_default_structured_output": False,
"seed": task.seed,
"runtime_parameters_info": [
{
2 changes: 2 additions & 0 deletions tests/unit/steps/tasks/evol_instruct/test_generator.py
@@ -117,6 +117,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "task",
"llm": {
"generation_kwargs": {},
"structured_output": None,
"type_info": {
"module": task.llm.__class__.__module__,
"name": task.llm.__class__.__name__,
@@ -148,6 +149,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"min_length": task.min_length,
"max_length": task.max_length,
"seed": task.seed,
"use_default_structured_output": False,
"runtime_parameters_info": [
{
"name": "resources",
2 changes: 2 additions & 0 deletions tests/unit/steps/tasks/evol_quality/test_base.py
@@ -105,6 +105,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"input_batch_size": task.input_batch_size,
"llm": {
"generation_kwargs": {},
"structured_output": None,
"type_info": {
"module": task.llm.__module__,
"name": task.llm.__class__.__name__,
@@ -117,6 +118,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"group_generations": task.group_generations,
"include_original_response": task.include_original_response,
"seed": task.seed,
"use_default_structured_output": False,
"runtime_parameters_info": [
{
"name": "resources",
1 change: 1 addition & 0 deletions tests/unit/steps/tasks/magpie/test_base.py
@@ -423,6 +423,7 @@ def test_serialization(self) -> None:
"group_generations": False,
"add_raw_output": True,
"num_generations": 1,
"use_default_structured_output": False,
"runtime_parameters_info": [
{
"name": "llm",
1 change: 1 addition & 0 deletions tests/unit/steps/tasks/magpie/test_generator.py
@@ -80,6 +80,7 @@ def test_serialization(self) -> None:
"add_raw_output": True,
"num_generations": 1,
"num_rows": None,
"use_default_structured_output": False,
"runtime_parameters_info": [
{
"name": "llm",
2 changes: 2 additions & 0 deletions tests/unit/steps/tasks/test_base.py
@@ -320,6 +320,7 @@ def test_serialization(self) -> None:
"input_batch_size": 50,
"llm": {
"generation_kwargs": {},
"structured_output": None,
"type_info": {
"module": "tests.unit.conftest",
"name": "DummyLLM",
@@ -389,6 +390,7 @@ def test_serialization(self) -> None:
"module": "tests.unit.steps.tasks.test_base",
"name": "DummyTask",
},
"use_default_structured_output": False,
}

with Pipeline(name="unit-test-pipeline") as pipeline:
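
The test updates all assert the same thing: the serialized form of a task now carries the new flag, and the LLM entry serializes its structured output configuration. A rough sketch of the resulting expectations, assuming the serialized dictionary comes from the task's dump() method as in these tests (excerpt only, not a full dump):

serialized = task.dump()

# The LLM entry now includes its structured output configuration (None by default).
assert serialized["llm"]["structured_output"] is None

# The task itself now serializes the flag, which defaults to False after this commit.
assert serialized["use_default_structured_output"] is False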
