From 4307ec47f1d7203ed27c0a9e7bfc3e75a76e0c98 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20Mart=C3=ADn=20Bl=C3=A1zquez?= Date: Mon, 17 Jun 2024 10:21:03 +0200 Subject: [PATCH] Update unit tests --- .../{steps/tasks/utils.py => conftest.py} | 26 +++++--- tests/unit/llms/test_cohere.py | 2 +- tests/unit/llms/test_moa.py | 61 +++++++++++++++++++ .../steps/tasks/benchmarks/test_arena_hard.py | 2 +- tests/unit/steps/tasks/conftest.py | 55 ----------------- tests/unit/steps/tasks/test_base.py | 17 ++---- .../steps/tasks/test_complexity_scorer.py | 2 +- tests/unit/steps/tasks/test_genstruct.py | 2 +- .../unit/steps/tasks/test_prometheus_eval.py | 2 +- tests/unit/steps/tasks/test_quality_scorer.py | 2 +- tests/unit/steps/tasks/test_self_instruct.py | 2 +- .../steps/tasks/test_sentence_transformers.py | 2 +- .../unit/steps/tasks/test_text_generation.py | 2 +- 13 files changed, 94 insertions(+), 83 deletions(-) rename tests/unit/{steps/tasks/utils.py => conftest.py} (56%) create mode 100644 tests/unit/llms/test_moa.py delete mode 100644 tests/unit/steps/tasks/conftest.py diff --git a/tests/unit/steps/tasks/utils.py b/tests/unit/conftest.py similarity index 56% rename from tests/unit/steps/tasks/utils.py rename to tests/unit/conftest.py index 989fb3ad5b..bbe6ca1ed4 100644 --- a/tests/unit/steps/tasks/utils.py +++ b/tests/unit/conftest.py @@ -12,13 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Union +from typing import TYPE_CHECKING -from distilabel.llms.base import LLM -from distilabel.steps.tasks.typing import ChatType +import pytest +from distilabel.llms.base import AsyncLLM +if TYPE_CHECKING: + from distilabel.llms.typing import GenerateOutput + from distilabel.steps.tasks.typing import FormattedInput -class DummyLLM(LLM): + +# Defined here too, so that the serde still works +class DummyLLM(AsyncLLM): def load(self) -> None: pass @@ -26,7 +31,12 @@ def load(self) -> None: def model_name(self) -> str: return "test" - def generate( - self, inputs: List["ChatType"], num_generations: int = 1, **kwargs: Any - ) -> List[List[Union[str, None]]]: - return [["output" for _ in range(num_generations)] for _ in inputs] + async def agenerate( + self, input: "FormattedInput", num_generations: int = 1 + ) -> "GenerateOutput": + return ["output" for _ in range(num_generations)] + + +@pytest.fixture +def dummy_llm() -> AsyncLLM: + return DummyLLM() diff --git a/tests/unit/llms/test_cohere.py b/tests/unit/llms/test_cohere.py index 3cba9611d8..a16d904e11 100644 --- a/tests/unit/llms/test_cohere.py +++ b/tests/unit/llms/test_cohere.py @@ -98,7 +98,7 @@ async def test_agenerate_structured( }, ] ) - assert generation == sample_user.model_dump_json() + assert generation == [sample_user.model_dump_json()] @pytest.mark.asyncio async def test_generate(self, mock_async_client: mock.MagicMock) -> None: diff --git a/tests/unit/llms/test_moa.py b/tests/unit/llms/test_moa.py new file mode 100644 index 0000000000..dca065a861 --- /dev/null +++ b/tests/unit/llms/test_moa.py @@ -0,0 +1,61 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from distilabel.llms.moa import MOA_SYSTEM_PROMPT, MixtureOfAgents + +from tests.unit.conftest import DummyLLM + + +class TestMixtureOfAgents: + def test_model_name(self) -> None: + llm = MixtureOfAgents( + aggregator_llm=DummyLLM(), + proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()], + ) + + assert llm.model_name == "moa-test-test-test-test" + + def test_build_moa_system_prompt(self) -> None: + llm = MixtureOfAgents( + aggregator_llm=DummyLLM(), + proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()], + ) + + system_prompt = llm._build_moa_system_prompt( + prev_outputs=["output1", "output2", "output3"] + ) + + assert ( + system_prompt == f"{MOA_SYSTEM_PROMPT}\n1. output1\n2. output2\n3. output3" + ) + + def test_inject_moa_system_prompt(self) -> None: + llm = MixtureOfAgents( + aggregator_llm=DummyLLM(), + proposers_llms=[DummyLLM(), DummyLLM(), DummyLLM()], + ) + + results = llm._inject_moa_system_prompt( + input=[ + {"role": "system", "content": "I'm a system prompt."}, + ], + prev_outputs=["output1", "output2", "output3"], + ) + + assert results == [ + { + "role": "system", + "content": f"{MOA_SYSTEM_PROMPT}\n1. output1\n2. output2\n3. output3\n\nI'm a system prompt.", + } + ] diff --git a/tests/unit/steps/tasks/benchmarks/test_arena_hard.py b/tests/unit/steps/tasks/benchmarks/test_arena_hard.py index 50258db668..40666e4402 100644 --- a/tests/unit/steps/tasks/benchmarks/test_arena_hard.py +++ b/tests/unit/steps/tasks/benchmarks/test_arena_hard.py @@ -20,7 +20,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.benchmarks.arena_hard import ArenaHard, ArenaHardResults -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM class TestArenaHard: diff --git a/tests/unit/steps/tasks/conftest.py b/tests/unit/steps/tasks/conftest.py deleted file mode 100644 index da1493c9ce..0000000000 --- a/tests/unit/steps/tasks/conftest.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright 2023-present, Argilla, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import TYPE_CHECKING, List - -import pytest -from distilabel.llms.base import LLM - -if TYPE_CHECKING: - from distilabel.llms.typing import GenerateOutput - from distilabel.steps.tasks.typing import ChatType - - -@pytest.fixture -def dummy_llm() -> LLM: - class DummyLLM(LLM): - def load(self) -> None: - pass - - @property - def model_name(self) -> str: - return "test" - - def generate( # type: ignore - self, inputs: List["ChatType"], num_generations: int = 1 - ) -> List["GenerateOutput"]: - return [["output"] for _ in inputs] - - return DummyLLM() - - -# Defined here too, so that the serde still works -class DummyLLM(LLM): - def load(self) -> None: - pass - - @property - def model_name(self) -> str: - return "test" - - def generate( # type: ignore - self, inputs: List["ChatType"], num_generations: int = 1 - ) -> List["GenerateOutput"]: - return [["output"] for _ in inputs] diff --git a/tests/unit/steps/tasks/test_base.py b/tests/unit/steps/tasks/test_base.py index e24dd0fefb..c5ff87e1e0 100644 --- a/tests/unit/steps/tasks/test_base.py +++ b/tests/unit/steps/tasks/test_base.py @@ -22,7 +22,7 @@ from distilabel.steps.tasks.base import Task from pydantic import ValidationError -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM if TYPE_CHECKING: from distilabel.steps.tasks.typing import ChatType @@ -156,7 +156,7 @@ def test_process_with_runtime_parameters(self) -> None: assert task.llm.runtime_parameters_names == { "runtime_parameter": False, "runtime_parameter_optional": True, - "generation_kwargs": {"kwargs": False}, + "generation_kwargs": {}, } # 2. Runtime parameters in init @@ -171,7 +171,7 @@ def test_process_with_runtime_parameters(self) -> None: assert task.llm.runtime_parameters_names == { "runtime_parameter": False, "runtime_parameter_optional": True, - "generation_kwargs": {"kwargs": False}, + "generation_kwargs": {}, } # 3. Runtime parameters in init superseded by runtime parameters @@ -187,7 +187,7 @@ def test_process_with_runtime_parameters(self) -> None: assert task.llm.runtime_parameters_names == { "runtime_parameter": False, "runtime_parameter_optional": True, - "generation_kwargs": {"kwargs": False}, + "generation_kwargs": {}, } def test_serialization(self) -> None: @@ -203,7 +203,7 @@ def test_serialization(self) -> None: "llm": { "generation_kwargs": {}, "type_info": { - "module": "tests.unit.steps.tasks.utils", + "module": "tests.unit.conftest", "name": "DummyLLM", }, }, @@ -221,12 +221,7 @@ def test_serialization(self) -> None: { "description": "The kwargs to be propagated to either `generate` or " "`agenerate` methods within each `LLM`.", - "keys": [ - { - "name": "kwargs", - "optional": False, - }, - ], + "keys": [], "name": "generation_kwargs", }, ], diff --git a/tests/unit/steps/tasks/test_complexity_scorer.py b/tests/unit/steps/tasks/test_complexity_scorer.py index a47a16445d..ec0575d745 100644 --- a/tests/unit/steps/tasks/test_complexity_scorer.py +++ b/tests/unit/steps/tasks/test_complexity_scorer.py @@ -18,7 +18,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.complexity_scorer import ComplexityScorer -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM class TestComplexityScorer: diff --git a/tests/unit/steps/tasks/test_genstruct.py b/tests/unit/steps/tasks/test_genstruct.py index 8ecc9d2d58..12878b9f26 100644 --- a/tests/unit/steps/tasks/test_genstruct.py +++ b/tests/unit/steps/tasks/test_genstruct.py @@ -18,7 +18,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.genstruct import Genstruct -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM class TestGenstruct: diff --git a/tests/unit/steps/tasks/test_prometheus_eval.py b/tests/unit/steps/tasks/test_prometheus_eval.py index e5a4ad8590..31a437fdab 100644 --- a/tests/unit/steps/tasks/test_prometheus_eval.py +++ b/tests/unit/steps/tasks/test_prometheus_eval.py @@ -27,7 +27,7 @@ from jinja2 import Template from pydantic import ValidationError -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM def load_template(template: str) -> Template: diff --git a/tests/unit/steps/tasks/test_quality_scorer.py b/tests/unit/steps/tasks/test_quality_scorer.py index 0a3db8261b..608631e9a2 100644 --- a/tests/unit/steps/tasks/test_quality_scorer.py +++ b/tests/unit/steps/tasks/test_quality_scorer.py @@ -18,7 +18,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.quality_scorer import QualityScorer -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM class TestQualityScorer: diff --git a/tests/unit/steps/tasks/test_self_instruct.py b/tests/unit/steps/tasks/test_self_instruct.py index 8525b88e6b..e3378e7e93 100644 --- a/tests/unit/steps/tasks/test_self_instruct.py +++ b/tests/unit/steps/tasks/test_self_instruct.py @@ -15,7 +15,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.self_instruct import SelfInstruct -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM class TestSelfInstruct: diff --git a/tests/unit/steps/tasks/test_sentence_transformers.py b/tests/unit/steps/tasks/test_sentence_transformers.py index e63b14bc83..2f81240755 100644 --- a/tests/unit/steps/tasks/test_sentence_transformers.py +++ b/tests/unit/steps/tasks/test_sentence_transformers.py @@ -23,7 +23,7 @@ GenerationAction, ) -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM class TestGenerateSentencePair: diff --git a/tests/unit/steps/tasks/test_text_generation.py b/tests/unit/steps/tasks/test_text_generation.py index 545cf6a7b8..c98adb00e5 100644 --- a/tests/unit/steps/tasks/test_text_generation.py +++ b/tests/unit/steps/tasks/test_text_generation.py @@ -16,7 +16,7 @@ from distilabel.pipeline.local import Pipeline from distilabel.steps.tasks.text_generation import ChatGeneration, TextGeneration -from tests.unit.steps.tasks.utils import DummyLLM +from tests.unit.conftest import DummyLLM class TestTextGeneration: