Skip to content

Commit

Permalink
Fix unit tests after updating add_raw_output
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed Jun 4, 2024
1 parent e950db8 commit d1e00be
Show file tree
Hide file tree
Showing 9 changed files with 54 additions and 6 deletions.
8 changes: 7 additions & 1 deletion tests/unit/steps/argilla/test_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,13 @@ def test_process(self) -> None:
step._rg_dataset = MockFeedbackDataset # type: ignore

assert list(step.process([{"instruction": "test", "generation": "test"}])) == [
[{"instruction": "test", "generation": "test"}]
[
{
"instruction": "test",
"generation": "test",
"distilabel_metadata": {"raw_output_task": "output"},
}
]
]
assert len(step._rg_dataset.records) == 1

Expand Down
7 changes: 6 additions & 1 deletion tests/unit/steps/tasks/evol_instruct/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
task.load()
assert task.dump() == {
"name": "task",
"add_raw_output": False,
"add_raw_output": True,
"input_mappings": task.input_mappings,
"output_mappings": task.output_mappings,
"input_batch_size": task.input_batch_size,
Expand Down Expand Up @@ -163,6 +163,11 @@ def test_serialization(self, dummy_llm: LLM) -> None:
}
],
},
{
"description": "Whether to include the raw output of the LLM in the output.",
"name": "add_raw_output",
"optional": True,
},
{
"name": "num_generations",
"optional": True,
Expand Down
7 changes: 6 additions & 1 deletion tests/unit/steps/tasks/evol_instruct/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": task.llm.__class__.__name__,
},
},
"add_raw_output": False,
"add_raw_output": True,
"input_mappings": task.input_mappings,
"output_mappings": task.output_mappings,
"batch_size": task.batch_size,
Expand Down Expand Up @@ -158,6 +158,11 @@ def test_serialization(self, dummy_llm: LLM) -> None:
},
],
},
{
"description": "Whether to include the raw output of the LLM in the output.",
"name": "add_raw_output",
"optional": True,
},
{
"name": "num_generations",
"optional": True,
Expand Down
9 changes: 7 additions & 2 deletions tests/unit/steps/tasks/evol_quality/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def test_serialization(self, dummy_llm: LLM) -> None:
task.load()
assert task.dump() == {
"name": "task",
"add_raw_output": False,
"add_raw_output": True,
"input_mappings": task.input_mappings,
"output_mappings": task.output_mappings,
"input_batch_size": task.input_batch_size,
Expand Down Expand Up @@ -112,9 +112,14 @@ def test_serialization(self, dummy_llm: LLM) -> None:
"name": "generation_kwargs",
"description": "The kwargs to be propagated to either `generate` or `agenerate` methods within each `LLM`.",
"keys": [],
}
},
],
},
{
"description": "Whether to include the raw output of the LLM in the output.",
"name": "add_raw_output",
"optional": True,
},
{
"name": "num_generations",
"optional": True,
Expand Down
15 changes: 14 additions & 1 deletion tests/unit/steps/tasks/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,19 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
"instruction": "test",
"output": "output",
"model_name": "test",
"distilabel_metadata": {"raw_output_task": "output"},
},
{
"instruction": "test",
"output": "output",
"model_name": "test",
"distilabel_metadata": {"raw_output_task": "output"},
},
{
"instruction": "test",
"output": "output",
"model_name": "test",
"distilabel_metadata": {"raw_output_task": "output"},
},
],
),
Expand All @@ -114,6 +117,11 @@ def test_with_errors(self, caplog: pytest.LogCaptureFixture) -> None:
"instruction": "test",
"output": ["output", "output", "output"],
"model_name": "test",
"distilabel_metadata": [
{"raw_output_task": "output"},
{"raw_output_task": "output"},
{"raw_output_task": "output"},
],
},
],
),
Expand Down Expand Up @@ -188,7 +196,7 @@ def test_serialization(self) -> None:
task = DummyTask(name="task", llm=llm, pipeline=pipeline)
assert task.dump() == {
"name": "task",
"add_raw_output": False,
"add_raw_output": True,
"input_mappings": {},
"output_mappings": {},
"input_batch_size": 50,
Expand Down Expand Up @@ -224,6 +232,11 @@ def test_serialization(self) -> None:
},
],
},
{
"description": "Whether to include the raw output of the LLM in the output.",
"name": "add_raw_output",
"optional": True,
},
{
"name": "num_generations",
"description": "The number of generations to be produced per input.",
Expand Down
3 changes: 3 additions & 0 deletions tests/unit/steps/tasks/test_instruction_backtranslation.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,5 +86,8 @@ def test_process(self) -> None:
"score": 1,
"reason": "This is the reason.",
"model_name": "instruction-backtranslation-model",
"distilabel_metadata": {
"raw_output_instruction-backtranslation": "This is the reason. Score: 1"
},
}
]
1 change: 1 addition & 0 deletions tests/unit/steps/tasks/test_structured_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,5 +120,6 @@ def test_process(self) -> None:
},
"generation": '{"test": "output"}',
"model_name": "test",
"distilabel_metadata": {"raw_output_task": '{"test": "output"}'},
}
]
4 changes: 4 additions & 0 deletions tests/unit/steps/tasks/test_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def test_process(self) -> None:
"instruction": "test",
"generation": "output",
"model_name": "test",
"distilabel_metadata": {
"raw_output_task": "output",
},
}
]

Expand Down Expand Up @@ -139,5 +142,6 @@ def test_process(self) -> None:
"messages": [{"role": "user", "content": "Tell me a joke."}],
"generation": "output",
"model_name": "test",
"distilabel_metadata": {"raw_output_task": "output"},
}
]
6 changes: 6 additions & 0 deletions tests/unit/steps/tasks/test_ultrafeedback.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def test_process_with_simple_aspect(self) -> None:
"ratings": [1, 2],
"rationales": ["text", "text"],
"model_name": "ultrafeedback-model",
"distilabel_metadata": {
"raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text"
},
}
]

Expand All @@ -89,5 +92,8 @@ def test_process_with_complex_aspect(self) -> None:
"ratings": [1, 2],
"rationales-for-ratings": ["text", "text"],
"model_name": "ultrafeedback-model",
"distilabel_metadata": {
"raw_output_ultrafeedback": "Type: 1\nRationale: text\nRating: 1\nRationale: text\n\nType: 2\nRationale: text\nRating: 2\nRationale: text"
},
}
]

0 comments on commit d1e00be

Please sign in to comment.