* Redirect import of CLAIR
* Add jinja2 template for CLAIR
* Add CLAIR task
* Add tests for CLAIR task
* Update example in docstrings
* Add tutorial to reproduce CLAIR
* Show new tutorial in the gallery and fix rendering issue in docstrings
Showing 8 changed files with 382 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Contrastive Learning From AI Revisions (CLAIR) | ||
|
||
["Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment"](https://huggingface.co/papers/2408.06266) introduces both Contrastive | ||
Learning from AI Revisions (CLAIR), a data-creation method which leads to more contrastive preference pairs, and Anchored Preference Optimization (APO), a controllable and more stable alignment objective. While APO can be found in [TRL](https://huggingface.co/docs/trl/dpo_trainer#loss-functions), we have implemented a task for CLAIR in `distilabel`. | ||
|
||
CLAIR is a method for creating preference pairs which minimally revises one output to express a preference, resulting in a more precise learning signal as opposed to conventional methods which use a judge to select a preferred response. | ||
|
||
![CLAIR overview](../../../assets/pipelines/clair.png) | ||
|
||
The authors of the original paper shared a [collection of datasets from CLAIR and APO](https://huggingface.co/collections/ContextualAI/clair-and-apo-66b52868672bb1c984d1f3d5), where [ContextualAI/ultrafeedback_clair_32k](https://huggingface.co/datasets/ContextualAI/ultrafeedback_clair_32k) corresponds to the CLAIR implementation. | ||
|
||
### Replication | ||
|
||
!!! NOTE | ||
    The section is named `Replication`, but in this case we are showing how to use the [`CLAIR`][distilabel.steps.tasks.clair.CLAIR] task to create revisions for your generations using `distilabel`. | ||
|
||
To showcase CLAIR we will use the [`CLAIR`][distilabel.steps.tasks.clair.CLAIR] task implemented in `distilabel`, reusing a small sample of the dataset already generated by ContextualAI, [`ContextualAI/ultrafeedback_clair_32k`](https://huggingface.co/datasets/ContextualAI/ultrafeedback_clair_32k), for testing. | ||
|
||
#### Installation | ||
|
||
To run the code below, install `distilabel` as follows: | ||
|
||
```bash | ||
pip install "distilabel>=1.4.0" | ||
``` | ||
|
||
The requirements may vary depending on the LLM provider you want to use, so check its dependencies. This example uses the free Inference Endpoints from Hugging Face, which will not be sufficient for a bigger dataset. | ||
|
||
#### Building blocks | ||
|
||
In this case where we already have instructions and their generations, we will just need to load the data and the corresponding CLAIR task for the revisions: | ||
|
||
- [`CLAIR`](https://distilabel.argilla.io/dev/components-gallery/tasks/clair/) to generate the revisions. | ||
|
||
#### Code | ||
|
||
Let's see the full pipeline applied to `ContextualAI/ultrafeedback_clair_32k` in `distilabel`: | ||
|
||
```python | ||
from typing import Any, Dict | ||
|
||
from datasets import load_dataset | ||
|
||
from distilabel.pipeline import Pipeline | ||
from distilabel.steps.tasks import CLAIR | ||
from distilabel.llms import InferenceEndpointsLLM | ||
|
||
|
||
def transform_ultrafeedback(example: Dict[str, Any]) -> Dict[str, Any]: | ||
    return { | ||
        "task": example["prompt"], | ||
        "student_solution": example["rejected"][1]["content"], | ||
    } | ||
|
||
dataset = ( | ||
    load_dataset("ContextualAI/ultrafeedback_clair_32k", split="train") | ||
    .select(range(10))  # We collect just 10 examples | ||
    .map(transform_ultrafeedback)  # Apply the transformation to get just the text | ||
) | ||
|
||
with Pipeline(name="CLAIR UltraFeedback sample") as pipeline: | ||
    clair = CLAIR(  # (1) | ||
        llm=InferenceEndpointsLLM( | ||
            model_id="meta-llama/Meta-Llama-3.1-70B-Instruct", | ||
            generation_kwargs={ | ||
                "temperature": 0.7, | ||
                "max_new_tokens": 4096, | ||
            }, | ||
        ) | ||
    ) | ||
|
||
|
||
if __name__ == "__main__": | ||
    distiset = pipeline.run(dataset=dataset)  # (2) | ||
    distiset.push_to_hub(repo_id="username/clair-test", include_script=True)  # (3) | ||
``` | ||
|
||
1. This pipeline uses only CLAIR because we already have the generations, but one could add a first task to create generations from the instructions and then produce the revisions with CLAIR. | ||
|
||
2. Include the dataset directly in the run method for simplicity. | ||
|
||
3. Push the distiset to the hub with the script for reproducibility. | ||
|
||
An example dataset can be found at: [distilabel-internal-testing/clair-test](https://huggingface.co/datasets/distilabel-internal-testing/clair-test). |
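
Since CLAIR expresses the preference by revising an existing answer, each output row can be turned into a preference pair in which the revision is preferred over the original answer. The following is a minimal sketch of that step, not part of the tutorial pipeline. The `task`, `student_solution` and `revision` columns come from the task; the `prompt`/`chosen`/`rejected` key names and the helper `to_preference_pair` are assumptions chosen for illustration.

```python
from typing import Any, Dict, Optional


def to_preference_pair(row: Dict[str, Any]) -> Optional[Dict[str, str]]:
    """Turn one CLAIR output row into a prompt/chosen/rejected record (hypothetical helper)."""
    revision = row.get("revision")
    # The task returns `None` for the revision when the LLM output could not be parsed.
    if not revision:
        return None
    return {
        "prompt": row["task"],
        "chosen": revision,  # the minimally revised answer
        "rejected": row["student_solution"],  # the original answer
    }


# Toy rows standing in for the output of the pipeline.
rows = [
    {"task": "What is 2 + 2?", "student_solution": "It is 5.", "revision": "It is 4."},
    {"task": "Name a prime.", "student_solution": "4", "revision": None},
]
pairs = [pair for pair in map(to_preference_pair, rows) if pair is not None]
print(pairs)
```

Rows without a usable revision are dropped rather than paired, so a failed generation never ends up as a "chosen" response.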
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,199 @@ | ||
# Copyright 2023-present, Argilla, Inc. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import importlib.resources as importlib_resources | ||
from typing import TYPE_CHECKING, Any, Dict, Final, Union | ||
|
||
from jinja2 import Template | ||
from pydantic import PrivateAttr | ||
|
||
from distilabel.steps.tasks.base import Task | ||
|
||
if TYPE_CHECKING: | ||
    from distilabel.steps.tasks.typing import ChatType | ||
    from distilabel.steps.typing import StepColumns | ||
|
||
|
||
SYSTEM_PROMPT: Final[str] = ( | ||
    "You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution." | ||
) | ||
|
||
|
||
class CLAIR(Task): | ||
    r"""Contrastive Learning from AI Revisions (CLAIR). | ||
    CLAIR uses an AI system to minimally revise a solution A→A´ such that the resulting | ||
    preference pair (A´ preferred over A) is much more contrastive and precise. | ||
    Input columns: | ||
        - task (`str`): The task or instruction. | ||
        - student_solution (`str`): An answer to the task that is to be revised. | ||
    Output columns: | ||
        - revision (`str`): The revised text. | ||
        - rational (`str`): The rationale for the provided revision. | ||
        - model_name (`str`): The name of the model used to generate the revision and rationale. | ||
    Categories: | ||
        - preference | ||
        - text-generation | ||
    References: | ||
        - [`Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment`](https://arxiv.org/abs/2408.06266v1) | ||
        - [`APO and CLAIR - GitHub Repository`](https://github.com/ContextualAI/CLAIR_and_APO) | ||
    Examples: | ||
        Create contrastive preference pairs: | ||
```python | ||
from distilabel.steps.tasks import CLAIR | ||
from distilabel.llms.huggingface import InferenceEndpointsLLM | ||
llm = InferenceEndpointsLLM( | ||
    model_id="meta-llama/Meta-Llama-3.1-70B-Instruct", | ||
    tokenizer_id="meta-llama/Meta-Llama-3.1-70B-Instruct", | ||
    generation_kwargs={ | ||
        "temperature": 0.7, | ||
        "max_new_tokens": 4096, | ||
    }, | ||
) | ||
clair_task = CLAIR(llm=llm) | ||
clair_task.load() | ||
result = next( | ||
    clair_task.process( | ||
        [ | ||
            { | ||
                "task": "How many gaps are there between the earth and the moon?", | ||
                "student_solution": 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.' | ||
            } | ||
        ] | ||
    ) | ||
) | ||
# result | ||
# [{'task': 'How many gaps are there between the earth and the moon?', | ||
# 'student_solution': 'There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.', | ||
# 'revision': 'There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.', | ||
# 'rational': 'The student\'s solution provides a clear and concise answer to the question. However, there are a few areas where it can be improved. Firstly, the term "gaps" can be misleading in this context. The student should clarify what they mean by "gaps." Secondly, the student provides some additional information about the Moon\'s orbit, which is correct but could be more clearly connected to the main point. Lastly, the student\'s conclusion could be more concise.', | ||
# 'distilabel_metadata': {'raw_output_c_l_a_i_r_0': '{teacher_reasoning}: The student\'s solution provides a clear and concise answer to the question. However, there are a few areas where it can be improved. Firstly, the term "gaps" can be misleading in this context. The student should clarify what they mean by "gaps." Secondly, the student provides some additional information about the Moon\'s orbit, which is correct but could be more clearly connected to the main point. Lastly, the student\'s conclusion could be more concise.\n\n{corrected_student_solution}: There are no physical gaps or empty spaces between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a significant separation or gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range. This variation in distance is a result of the Moon\'s orbital path, not the presence of any gaps.\n\nIn summary, the Moon\'s orbit is continuous, with no intervening gaps, and its distance from the Earth varies due to the elliptical shape of its orbit.', | ||
# 'raw_input_c_l_a_i_r_0': [{'role': 'system', | ||
# 'content': "You are a teacher and your task is to minimally improve a student's answer. I will give you a {task} and a {student_solution}. Your job is to revise the {student_solution} such that it is clearer, more correct, and more engaging. Copy all non-corrected parts of the student's answer. Do not allude to the {corrected_student_solution} being a revision or a correction in your final solution."}, | ||
# {'role': 'user', | ||
# 'content': '{task}: How many gaps are there between the earth and the moon?\n\n{student_solution}: There are no gaps between the Earth and the Moon. The Moon is actually in a close orbit around the Earth, and it is held in place by gravity. The average distance between the Earth and the Moon is about 384,400 kilometers (238,900 miles), and this distance is known as the "lunar distance" or "lunar mean distance."\n\nThe Moon does not have a gap between it and the Earth because it is a natural satellite that is gravitationally bound to our planet. The Moon\'s orbit is elliptical, which means that its distance from the Earth varies slightly over the course of a month, but it always remains within a certain range.\n\nSo, to summarize, there are no gaps between the Earth and the Moon. The Moon is simply a satellite that orbits the Earth, and its distance from our planet varies slightly due to the elliptical shape of its orbit.\n\n-----------------\n\nLet\'s first think step by step with a {teacher_reasoning} to decide how to improve the {student_solution}, then give the {corrected_student_solution}. Mention the {teacher_reasoning} and {corrected_student_solution} identifiers to structure your answer.'}]}, | ||
# 'model_name': 'meta-llama/Meta-Llama-3.1-70B-Instruct'}] | ||
``` | ||
Citations: | ||
``` | ||
@misc{doosterlinck2024anchoredpreferenceoptimizationcontrastive, | ||
title={Anchored Preference Optimization and Contrastive Revisions: Addressing Underspecification in Alignment}, | ||
author={Karel D'Oosterlinck and Winnie Xu and Chris Develder and Thomas Demeester and Amanpreet Singh and Christopher Potts and Douwe Kiela and Shikib Mehri}, | ||
year={2024}, | ||
eprint={2408.06266}, | ||
archivePrefix={arXiv}, | ||
primaryClass={cs.LG}, | ||
url={https://arxiv.org/abs/2408.06266}, | ||
} | ||
``` | ||
""" | ||
|
||
    system_prompt: str = SYSTEM_PROMPT | ||
    _template: Union[Template, None] = PrivateAttr(...) | ||
|
||
    def load(self) -> None: | ||
        super().load() | ||
        _path = str( | ||
            importlib_resources.files("distilabel") | ||
            / "steps" | ||
            / "tasks" | ||
            / "templates" | ||
            / "clair.jinja2" | ||
        ) | ||
        with open(_path, "r") as f: | ||
            self._template = Template(f.read()) | ||
|
||
    @property | ||
    def inputs(self) -> "StepColumns": | ||
        return ["task", "student_solution"] | ||
|
||
    @property | ||
    def outputs(self) -> "StepColumns": | ||
        return ["revision", "rational", "model_name"] | ||
|
||
    def format_input(self, input: Dict[str, Any]) -> "ChatType": | ||
        """The input is formatted as a `ChatType` assuming that the instruction | ||
        is the first interaction from the user within a conversation.""" | ||
        return [ | ||
            {"role": "system", "content": self.system_prompt}, | ||
            { | ||
                "role": "user", | ||
                "content": self._template.render( | ||
                    task=input["task"], student_solution=input["student_solution"] | ||
                ), | ||
            }, | ||
        ] | ||
|
||
    def format_output( | ||
        self, output: Union[str, None], input: Dict[str, Any] | ||
    ) -> Dict[str, Any]: | ||
        """The output is formatted as a dict with the revision and its rationale. | ||
        Args: | ||
            output: the raw output of the LLM. | ||
            input: the input to the task. Not used by this task. | ||
        Returns: | ||
            A dict with the keys `revision` and `rational`; both are `None` when the | ||
            output is missing or does not follow the expected format. | ||
        """ | ||
        if output is None: | ||
            return self._default_error() | ||
|
||
        return self._format_output(output) | ||
|
||
    def _format_output(self, output: Union[str, None]) -> Dict[str, Any]: | ||
        if "**Corrected Student Solution:**" in output: | ||
            splits = output.split("**Corrected Student Solution:**") | ||
        elif "{corrected_student_solution}:" in output: | ||
            splits = output.split("{corrected_student_solution}:") | ||
        elif "{corrected_student_solution}" in output: | ||
            splits = output.split("{corrected_student_solution}") | ||
        elif "**Worsened Student Solution:**" in output: | ||
            splits = output.split("**Worsened Student Solution:**") | ||
        elif "{worsened_student_solution}:" in output: | ||
            splits = output.split("{worsened_student_solution}:") | ||
        elif "{worsened_student_solution}" in output: | ||
            splits = output.split("{worsened_student_solution}") | ||
        else: | ||
            splits = None | ||
|
||
        # Safety check when the output doesn't follow the expected format | ||
        if not splits: | ||
            return self._default_error() | ||
|
||
        if len(splits) >= 2: | ||
            revision = splits[1] | ||
            revision = revision.strip("\n\n").strip()  # noqa: B005 | ||
|
||
            rational = splits[0] | ||
            if "{teacher_reasoning}" in rational: | ||
                rational = rational.split("{teacher_reasoning}")[1].strip(":").strip() | ||
            rational = rational.strip("\n\n").strip()  # noqa: B005 | ||
        else: | ||
            return self._default_error() | ||
        return {"revision": revision, "rational": rational} | ||
|
||
    def _default_error(self) -> Dict[str, None]: | ||
        return {"revision": None, "rational": None} |
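
The parsing done by `_format_output` can be tried without an LLM or a `distilabel` installation. The following is a minimal standalone sketch of the same logic: it tries the same markers in the same order and returns `None` values when none of them is found. The names `MARKERS` and `parse_clair_output` are hypothetical and are not part of the commit.

```python
from typing import Dict, Optional

# Markers tried in order, mirroring the branches of `_format_output`.
MARKERS = (
    "**Corrected Student Solution:**",
    "{corrected_student_solution}:",
    "{corrected_student_solution}",
    "**Worsened Student Solution:**",
    "{worsened_student_solution}:",
    "{worsened_student_solution}",
)


def parse_clair_output(output: Optional[str]) -> Dict[str, Optional[str]]:
    """Split a raw completion into the revision and its rationale (hypothetical helper)."""
    empty: Dict[str, Optional[str]] = {"revision": None, "rational": None}
    if output is None:
        return empty
    # Pick the first marker that appears in the completion, if any.
    marker = next((m for m in MARKERS if m in output), None)
    if marker is None:
        return empty
    splits = output.split(marker)
    rational, revision = splits[0], splits[1]
    # Drop the reasoning identifier and the colon that follows it.
    if "{teacher_reasoning}" in rational:
        rational = rational.split("{teacher_reasoning}")[1].strip(":")
    return {"revision": revision.strip(), "rational": rational.strip()}


raw = "{teacher_reasoning}: Fix the sum.\n\n{corrected_student_solution}: 2 + 2 = 4."
print(parse_clair_output(raw))
```

Checking the marker with the trailing colon before the bare identifier matters: otherwise the colon would be left at the start of the revision.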
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
{task}: {{ task }} | ||
|
||
{student_solution}: {{ student_solution }} | ||
|
||
----------------- | ||
|
||
Let's first think step by step with a {teacher_reasoning} to decide how to improve the {student_solution}, then give the {corrected_student_solution}. Mention the {teacher_reasoning} and {corrected_student_solution} identifiers to structure your answer. |
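
To see what the rendered prompt looks like, the template can be mirrored with the standard library alone. The following is a sketch under stated assumptions: it uses `str.format` as a stand-in for Jinja2 rendering (so the brace escaping is inverted compared to the template above), the system prompt is shortened, and `USER_TEMPLATE` and `build_messages` are hypothetical names.

```python
from typing import Dict, List

# Shortened stand-in for the task's system prompt.
SYSTEM_PROMPT = "You are a teacher and your task is to minimally improve a student's answer."

# With str.format, doubled braces are literal: "{{task}}" renders as "{task}",
# while "{task}" is replaced by the value. Jinja2 works the other way around.
USER_TEMPLATE = (
    "{{task}}: {task}\n\n"
    "{{student_solution}}: {student_solution}\n\n"
    "-----------------\n\n"
    "Let's first think step by step with a {{teacher_reasoning}} to decide how to "
    "improve the {{student_solution}}, then give the {{corrected_student_solution}}. "
    "Mention the {{teacher_reasoning}} and {{corrected_student_solution}} identifiers "
    "to structure your answer."
)


def build_messages(task: str, student_solution: str) -> List[Dict[str, str]]:
    """Build the system and user messages for one example (hypothetical helper)."""
    user_content = USER_TEMPLATE.format(task=task, student_solution=student_solution)
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_content},
    ]


messages = build_messages("What is 2 + 2?", "It is 5.")
print(messages[1]["content"])
```

The literal `{teacher_reasoning}` and `{corrected_student_solution}` identifiers stay in the prompt on purpose: the model is asked to repeat them, and the task later splits its answer on those identifiers.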