Skip to content

Commit

Permalink
Update unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Jun 17, 2024
1 parent af34be1 commit 4307ec4
Show file tree
Hide file tree
Showing 13 changed files with 94 additions and 83 deletions.
26 changes: 18 additions & 8 deletions tests/unit/steps/tasks/utils.py → tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, List, Union
from typing import TYPE_CHECKING

from distilabel.llms.base import LLM
from distilabel.steps.tasks.typing import ChatType
import pytest
from distilabel.llms.base import AsyncLLM

if TYPE_CHECKING:
from distilabel.llms.typing import GenerateOutput
from distilabel.steps.tasks.typing import FormattedInput

class DummyLLM(LLM):

# Defined here too, so that the serde still works
class DummyLLM(AsyncLLM):
def load(self) -> None:
pass

@property
def model_name(self) -> str:
return "test"

def generate(
self, inputs: List["ChatType"], num_generations: int = 1, **kwargs: Any
) -> List[List[Union[str, None]]]:
return [["output" for _ in range(num_generations)] for _ in inputs]
async def agenerate(
self, input: "FormattedInput", num_generations: int = 1
) -> "GenerateOutput":
return ["output" for _ in range(num_generations)]


@pytest.fixture
def dummy_llm() -> AsyncLLM:
return DummyLLM()
2 changes: 1 addition & 1 deletion tests/unit/llms/test_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def test_agenerate_structured(
},
]
)
assert generation == sample_user.model_dump_json()
assert generation == [sample_user.model_dump_json()]

@pytest.mark.asyncio
async def test_generate(self, mock_async_client: mock.MagicMock) -> None:
Expand Down
61 changes: 61 additions & 0 deletions tests/unit/llms/test_moa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from distilabel.llms.moa import MOA_SYSTEM_PROMPT, MixtureOfAgents

from tests.unit.conftest import DummyLLM


class TestMixtureOfAgents:
def test_model_name(self) -> None:
llm = MixtureOfAgents(
aggregator_llm=DummyLLM(),
proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()],
)

assert llm.model_name == "moa-test-test-test-test"

def test_build_moa_system_prompt(self) -> None:
llm = MixtureOfAgents(
aggregator_llm=DummyLLM(),
proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()],
)

system_prompt = llm._build_moa_system_prompt(
prev_outputs=["output1", "output2", "output3"]
)

assert (
system_prompt == f"{MOA_SYSTEM_PROMPT}\n1. output1\n2. output2\n3. output3"
)

def test_inject_moa_system_prompt(self) -> None:
llm = MixtureOfAgents(
aggregator_llm=DummyLLM(),
proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()],
)

results = llm._inject_moa_system_prompt(
input=[
{"role": "system", "content": "I'm a system prompt."},
],
prev_outputs=["output1", "output2", "output3"],
)

assert results == [
{
"role": "system",
"content": f"{MOA_SYSTEM_PROMPT}\n1. output1\n2. output2\n3. output3\n\nI'm a system prompt.",
}
]
2 changes: 1 addition & 1 deletion tests/unit/steps/tasks/benchmarks/test_arena_hard.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.benchmarks.arena_hard import ArenaHard, ArenaHardResults

from tests.unit.steps.tasks.utils import DummyLLM
from tests.unit.conftest import DummyLLM


class TestArenaHard:
Expand Down
55 changes: 0 additions & 55 deletions tests/unit/steps/tasks/conftest.py

This file was deleted.

17 changes: 6 additions & 11 deletions tests/unit/steps/tasks/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from distilabel.steps.tasks.base import Task
from pydantic import ValidationError

from tests.unit.steps.tasks.utils import DummyLLM
from tests.unit.conftest import DummyLLM

if TYPE_CHECKING:
from distilabel.steps.tasks.typing import ChatType
Expand Down Expand Up @@ -156,7 +156,7 @@ def test_process_with_runtime_parameters(self) -> None:
assert task.llm.runtime_parameters_names == {
"runtime_parameter": False,
"runtime_parameter_optional": True,
"generation_kwargs": {"kwargs": False},
"generation_kwargs": {},
}

# 2. Runtime parameters in init
Expand All @@ -171,7 +171,7 @@ def test_process_with_runtime_parameters(self) -> None:
assert task.llm.runtime_parameters_names == {
"runtime_parameter": False,
"runtime_parameter_optional": True,
"generation_kwargs": {"kwargs": False},
"generation_kwargs": {},
}

# 3. Runtime parameters in init superseded by runtime parameters
Expand All @@ -187,7 +187,7 @@ def test_process_with_runtime_parameters(self) -> None:
assert task.llm.runtime_parameters_names == {
"runtime_parameter": False,
"runtime_parameter_optional": True,
"generation_kwargs": {"kwargs": False},
"generation_kwargs": {},
}

def test_serialization(self) -> None:
Expand All @@ -203,7 +203,7 @@ def test_serialization(self) -> None:
"llm": {
"generation_kwargs": {},
"type_info": {
"module": "tests.unit.steps.tasks.utils",
"module": "tests.unit.conftest",
"name": "DummyLLM",
},
},
Expand All @@ -221,12 +221,7 @@ def test_serialization(self) -> None:
{
"description": "The kwargs to be propagated to either `generate` or "
"`agenerate` methods within each `LLM`.",
"keys": [
{
"name": "kwargs",
"optional": False,
},
],
"keys": [],
"name": "generation_kwargs",
},
],
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/steps/tasks/test_complexity_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.complexity_scorer import ComplexityScorer

from tests.unit.steps.tasks.utils import DummyLLM
from tests.unit.conftest import DummyLLM


class TestComplexityScorer:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/steps/tasks/test_genstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.genstruct import Genstruct

from tests.unit.steps.tasks.utils import DummyLLM
from tests.unit.conftest import DummyLLM


class TestGenstruct:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/steps/tasks/test_prometheus_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from jinja2 import Template
from pydantic import ValidationError

from tests.unit.steps.tasks.utils import DummyLLM
from tests.unit.conftest import DummyLLM


def load_template(template: str) -> Template:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/steps/tasks/test_quality_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.quality_scorer import QualityScorer

from tests.unit.steps.tasks.utils import DummyLLM
from tests.unit.conftest import DummyLLM


class TestQualityScorer:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/steps/tasks/test_self_instruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.self_instruct import SelfInstruct

from tests.unit.steps.tasks.utils import DummyLLM
from tests.unit.conftest import DummyLLM


class TestSelfInstruct:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/steps/tasks/test_sentence_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
GenerationAction,
)

from tests.unit.steps.tasks.utils import DummyLLM
from tests.unit.conftest import DummyLLM


class TestGenerateSentencePair:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/steps/tasks/test_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from distilabel.pipeline.local import Pipeline
from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration

from tests.unit.steps.tasks.utils import DummyLLM
from tests.unit.conftest import DummyLLM


class TestTextGeneration:
Expand Down

0 comments on commit 4307ec4

Please sign in to comment.