Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compare jcommonsense qa prompts with question first vs last #113

Open
wants to merge 6 commits into
base: jp-stable
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lm_eval/models/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ def __init__(
low_cpu_mem_usage=None,
torch_dtype=None,
device_map=None,
offload_folder=None,
subfolder=None,
tokenizer=None,
batch_size=1,
load_in_8bit: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
use_fast: Optional[bool] = True,
additional_special_tokens: Optional[str] = None,
):
super().__init__()

Expand Down Expand Up @@ -49,6 +51,7 @@ def __init__(
low_cpu_mem_usage=low_cpu_mem_usage,
torch_dtype=torch_dtype,
device_map=device_map,
offload_folder=offload_folder,
revision=revision,
trust_remote_code=trust_remote_code,
).eval()
Expand All @@ -64,6 +67,7 @@ def __init__(
revision=revision,
trust_remote_code=trust_remote_code,
use_fast=use_fast,
additional_special_tokens=additional_special_tokens,
)
self.vocab_size = self.tokenizer.vocab_size

Expand Down
128 changes: 128 additions & 0 deletions lm_eval/tasks/ja/jcommonsenseqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,24 @@ def doc_to_text(self, doc):
return input_text


class JCommonsenseQAWithFintanPromptV22(JCommonsenseQAWithFintanPromptV21):
PROMPT_VERSION = "0.2.2"

def doc_to_text(self, doc):
"""
与えられた選択肢の中から、最適な答えを選んでください。

選択肢:
- {choice0}
- {choice4}
質問:{question}
回答:
"""
choices = "\n".join([f"- {choice}" for choice in doc["choices"]])
input_text = f"選択肢:\n{choices}\n質問:{doc['goal']}\n回答:" # question last
return input_text


class JCommonsenseQAWithJAAlpacaPrompt(JCommonsenseQA):
"""
This prompt format was inspired by the below data in fujiki/japanese_alpaca_data.
Expand Down Expand Up @@ -205,6 +223,42 @@ def doc_to_text(self, doc):
return f"### 指示:\n{instruction_text}\n\n### 入力:\n{input_text}\n\n### 応答:\n"


class JCommonsenseQAWithJAAlpacaPromptV32(JCommonsenseQAWithJAAlpacaPrompt):
"""
This prompt format was inspired by the below data in fujiki/japanese_alpaca_data.
```
{
'instruction': 'この課題では、以下の選択肢から文の出典を特定する必要があります。\n\n出力は以下から選択してください:\n- 新聞\n- 教科書\n- オンライン記事\n- 百科事典',
'input': '彼はローマの政治家であり哲学者であり、史上最も偉大な軍事指導者の一人と考えられています。',
'output': '百科事典'
}
```
Reference:
- data: https://huggingface.co/datasets/fujiki/japanese_alpaca_data
- code: https://github.com/Stability-AI/gpt-neox/blob/c130a4edc1120dccec8f02a34eb60d3e8f484cd3/finetune/finetune_base_ja.py#LL118C23-L127C11
"""

PROMPT_VERSION = "0.3.2"

def doc_to_text(self, doc):
"""
以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。

### 指示:
{instruction}

### 入力:
{input}

### 応答:
{response}
"""
instruction_text = self.INSTRUCTION + f"\n質問:{doc['goal']}"
choices = "\n".join([f"- {choice}" for choice in doc["choices"]])
input_text = f"出力は以下から選択してください:\n{choices}"
return f"### 指示:\n{instruction_text}\n\n### 入力:\n{input_text}\n\n### 応答:\n" # question first


class JCommonsenseQAWithRinnaInstructionSFT(JCommonsenseQA):
"""
Reference:
Expand All @@ -223,6 +277,22 @@ def doc_to_text(self, doc):
return f"ユーザー: {input_text}{self.SEP}システム: "


class JCommonsenseQAWithRinnaInstructionSFTV42(JCommonsenseQAWithRinnaInstructionSFT):
"""
Reference:
- HF Hub: https://huggingface.co/rinna/japanese-gpt-neox-3.6b-instruction-sft
"""

PROMPT_VERSION = "0.4.2"

def doc_to_text(self, doc):
choices = self.SEP.join([f"- {choice}" for choice in doc["choices"]])
input_text = (
f"選択肢:{self.SEP}{choices}" + f"質問:{doc['goal']}{self.SEP}"
) # question last
return f"ユーザー: {input_text}{self.SEP}システム: "


class JCommonsenseQAWithRinnaBilingualInstructionSFT(
JCommonsenseQAWithRinnaInstructionSFT
):
Expand All @@ -237,6 +307,24 @@ class JCommonsenseQAWithRinnaBilingualInstructionSFT(
FEWSHOT_SEP = "\n"


class JCommonsenseQAWithRinnaBilingualInstructionSFTV52(
JCommonsenseQAWithRinnaBilingualInstructionSFT
):
"""
Reference:
- HF Hub: https://huggingface.co/rinna/bilingual-gpt-neox-4b-instruction-sft
"""

PROMPT_VERSION = "0.5.2"

def doc_to_text(self, doc):
choices = self.SEP.join([f"- {choice}" for choice in doc["choices"]])
input_text = (
f"選択肢:{self.SEP}{choices}" + f"質問:{doc['goal']}{self.SEP}"
) # question last
return f"ユーザー: {input_text}{self.SEP}システム: "


class JCommonsenseQAWithLlama2(JCommonsenseQA):
"""
This prompt version follows the Llama2-chat's prompt format:
Expand All @@ -262,6 +350,7 @@ def doc_to_text(self, doc):
Insert the following prompt into `{{ user_msg }}`, which is based on prompt version 0.3
```
与えられた選択肢の中から、最適な答えを選んでください。出力は以下から選択してください:

- choice0
...
- choice4
Expand All @@ -275,14 +364,53 @@ def doc_to_text(self, doc):
return f"{instruction_text}\n\n{input_text} [/INST] "


class JCommonsenseQAWithLlama2V62(JCommonsenseQAWithLlama2):
"""
This prompt version follows the Llama2-chat's prompt format:
```
<s>[INST] <<SYS>>
{{ system_prompt }}
<</SYS>>

{{ user_msg_1 }} [/INST] {{ model_answer_1 }} </s><s>[INST] {{ user_msg_2 }} [/INST]
```
reference: https://huggingface.co/blog/llama2#how-to-prompt-llama-2
"""

PROMPT_VERSION = "0.6.2"

def doc_to_text(self, doc):
"""
Insert the following prompt into `{{ user_msg }}`, which is based on prompt version 0.3
```
与えられた選択肢の中から、最適な答えを選んでください。質問:...

出力は以下から選択してください:
- choice0
...
- choice4 [/INST]
```
"""
choices = "\n".join([f"- {choice}" for choice in doc["choices"]])
input_text = f"質問:{doc['goal']}"
instruction_text = self.INSTRUCTION + input_text
choices = f"出力は以下から選択してください:\n{choices}"
return f"{instruction_text}\n\n{choices} [/INST] " # question first


VERSIONS = [
JCommonsenseQA,
JCommonsenseQAWithFintanPrompt,
JCommonsenseQAWithFintanPromptV21,
JCommonsenseQAWithFintanPromptV22,
JCommonsenseQAWithJAAlpacaPrompt,
JCommonsenseQAWithJAAlpacaPromptV32,
JCommonsenseQAWithRinnaInstructionSFT,
JCommonsenseQAWithRinnaInstructionSFTV42,
JCommonsenseQAWithRinnaBilingualInstructionSFT,
JCommonsenseQAWithRinnaBilingualInstructionSFTV52,
JCommonsenseQAWithLlama2,
JCommonsenseQAWithLlama2V62,
]


Expand Down
Loading