Skip to content

Commit

Permalink
Fix tests after refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
plaguss committed Jun 6, 2024
1 parent 76be9a7 commit ed319dc
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 12 deletions.
1 change: 0 additions & 1 deletion tests/unit/llms/test_vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ def test_serialization(self, _: MagicMock) -> None:
_dump = {
"model": "gemini-1.0-pro",
"generation_kwargs": {},
"structured_output": None,
"type_info": {
"module": "distilabel.llms.vertexai",
"name": "VertexAILLM",
Expand Down
1 change: 0 additions & 1 deletion tests/unit/steps/tasks/evol_instruct/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,6 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"input_batch_size": task.input_batch_size,
"llm": {
"generation_kwargs": {},
"structured_output": None,
"type_info": {
"module": task.llm.__module__,
"name": task.llm.__class__.__name__,
Expand Down
1 change: 0 additions & 1 deletion tests/unit/steps/tasks/evol_instruct/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "task",
"llm": {
"generation_kwargs": {},
"structured_output": None,
"type_info": {
"module": task.llm.__class__.__module__,
"name": task.llm.__class__.__name__,
Expand Down
1 change: 0 additions & 1 deletion tests/unit/steps/tasks/evol_quality/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"input_batch_size": task.input_batch_size,
"llm": {
"generation_kwargs": {},
"structured_output": None,
"type_info": {
"module": task.llm.__module__,
"name": task.llm.__class__.__name__,
Expand Down
17 changes: 9 additions & 8 deletions tests/unit/steps/tasks/structured_outputs/test_outlines.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@
import pytest
from distilabel.llms.huggingface.transformers import TransformersLLM
from distilabel.steps.tasks.structured_outputs.outlines import (
StructuredOutputType,
# StructuredOutputType,
model_to_schema,
)
from distilabel.steps.tasks.typing import OutlinesStructuredOutputType
from pydantic import BaseModel


Expand Down Expand Up @@ -88,18 +89,14 @@ class DummyUserTest(BaseModel):


class TestOutlinesIntegration:
# @pytest.mark.skipif(
# not DISTILABEL_RUN_SLOW_TESTS,
# reason="Slow tests, run locally when needed.",
# )
@pytest.mark.parametrize(
"format, schema, prompt",
[
(
"json",
DummyUserTest,
"Create a user profile with the fields name, last_name and id",
), #
),
(
"json",
model_to_schema(DummyUserTest),
Expand All @@ -117,7 +114,9 @@ def test_generation(
) -> None:
llm = TransformersLLM(
model="openaccess-ai-collective/tiny-mistral",
structured_output=StructuredOutputType(format=format, schema=schema),
structured_output=OutlinesStructuredOutputType(
format=format, schema=schema
),
)
llm.load()

Expand Down Expand Up @@ -154,7 +153,9 @@ def test_serialization(
) -> None:
llm = TransformersLLM(
model="openaccess-ai-collective/tiny-mistral",
structured_output=StructuredOutputType(format=format, schema=schema),
structured_output=OutlinesStructuredOutputType(
format=format, schema=schema
),
)
llm.load()
assert llm.dump() == dump
Expand Down

0 comments on commit ed319dc

Please sign in to comment.