From 746a8f68fbb48e61f090c2c02cf04e6c6701b7b9 Mon Sep 17 00:00:00 2001 From: ylfeng Date: Wed, 18 Sep 2024 20:35:39 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=20MistralTool=20=E6=A0=BC?= =?UTF-8?q?=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/llamafactory/data/formatter.py | 10 +- src/llamafactory/data/template.py | 2 +- src/llamafactory/data/tool_utils.py | 4 +- tests/data/test_template.py | 194 +++++++++++++++++++++++++++- 4 files changed, 199 insertions(+), 11 deletions(-) diff --git a/src/llamafactory/data/formatter.py b/src/llamafactory/data/formatter.py index 0efd267148..3c6b72490e 100644 --- a/src/llamafactory/data/formatter.py +++ b/src/llamafactory/data/formatter.py @@ -147,7 +147,7 @@ def apply(self, **kwargs) -> SLOTS: elements = [] for name, arguments in functions: - elements.append(f""""{{"name":"{name}","arguments":{arguments}}}""") + elements.append(f"""{{"name": "{name}", "arguments": {arguments}}}""") elements = ["[TOOL_CALLS] [" + ", ".join(elements) + "]"] return elements @@ -163,14 +163,14 @@ def apply(self, **kwargs) -> SLOTS: content = kwargs.pop("content") tool_results: List[Tuple[str, str]] try: - tool_results = [json.dumps(result) for result in json.loads(content)] + tool_results = json.loads(content) except json.JSONDecodeError: tool_results = [] elements = [] - for content in tool_results: - elements.append(f"[TOOL_RESULTS] {{\"content\":{content}}}[/TOOL_RESULTS]") - return ["".join(elements)] + for result in tool_results: + elements.append(f"[TOOL_RESULTS] {{\"content\": {result}}}[/TOOL_RESULTS]") + return elements @dataclass diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 89d19be01a..c65d2972c9 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -723,7 +723,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: _register_template( name="mistral", - format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), + format_user=StringFormatter(slots=["[INST] {{content}}[/INST]"]), format_assistant=StringFormatter(slots=[" {{content}}"]), # mistral add space here format_prefix=EmptyFormatter(slots=[{"bos_token"}]), format_function=MistralFunctionFormatter(slots=[], tool_format="mistral"), diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py index 8ad284bd1b..ce32c63bce 100644 --- a/src/llamafactory/data/tool_utils.py +++ b/src/llamafactory/data/tool_utils.py @@ -38,7 +38,7 @@ "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}" ) -MISTRAL_TOOL_PROMPT = "[AVAILABLE_TOOLS] {tools} [/AVAILABLE_TOOLS]" +MISTRAL_TOOL_PROMPT = "[AVAILABLE_TOOLS] {tools}[/AVAILABLE_TOOLS]" FunctionCall = namedtuple("FunctionCall", ["name", "arguments"]) @@ -176,7 +176,7 @@ def get_function_slots() -> SLOTS: @override @staticmethod def tool_formatter(tools: List[Dict[str, Any]]) -> str: - tools = [{"type": "function", "function": tool} for tool in tools] + tools = json.dumps([{"type": "function", "function": tool} for tool in tools],ensure_ascii=False) return MISTRAL_TOOL_PROMPT.format(tools=tools) @override diff --git a/tests/data/test_template.py b/tests/data/test_template.py index a327df22b8..717b320d7e 100644 --- a/tests/data/test_template.py +++ b/tests/data/test_template.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import json import os from typing import TYPE_CHECKING, List, Sequence @@ -21,11 +21,9 @@ from llamafactory.data import get_template_and_fix_tokenizer from llamafactory.hparams import DataArguments - if TYPE_CHECKING: from transformers import PreTrainedTokenizer - HF_TOKEN = os.environ.get("HF_TOKEN", None) TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3") @@ -37,6 +35,81 @@ {"role": "assistant", "content": "很高兴认识你!"}, ] +TOOL_MESSAGES = { + "tools": [ + { + "type": "function", + "function": { + "name": "get_news", + "description": "获取最新新闻文章", + "parameters": { + "type": "object", + "properties": { + "category": {"type": "string", "description": "要检索的新闻文章类别"}, + "country": {"type": "string", "description": "获取新闻文章的国家"} + }, + "required": ["category"] + } + } + }, + { + "type": "function", + "function": { + "name": "search_books", + "description": "根据提供的标准搜索书籍", + "parameters": { + "type": "object", + "properties": { + "title": {"type": "string", "description": "这本书的标题"}, + "author": {"type": "string", "description": "这本书的作者"}, + "genre": {"type": "string", "description": "这本书的类型"} + } + } + } + } + ], + "messages": [ + { + "role": "user", + "content": "你能帮我找到最新的美国体育新闻吗?" + }, + { + "role": "tool_calls", + "content": [ + { + "type": "function", + "function": {"name": "get_news", "arguments": {"category": "运动", "country": "美国"}} + } + ] + }, + { + "role": "tool", + "content": json.dumps( + {"title": "NBA总决赛:湖人队对阵热火队", "link": "NBA官方网站"}, + ensure_ascii=False + ), + }, + { + "role": "tool", + "content": json.dumps( + {"title": "NFL:爱国者队击败酋长队", "link": "https://www.nfl.com/新闻"}, + ensure_ascii=False + ), + }, + { + "role": "tool", + "content": json.dumps( + {"title": "MLB:道奇队赢得世界系列赛", "link": "https://www.mlb.com/新闻"}, + ensure_ascii=False + ) + }, + { + "role": "assistant", + "content": "1. NBA总决赛:湖人队对阵热火队\n2. NFL:爱国者队击败酋长队\n3. MLB:道奇队赢得世界系列赛" + } + ], +} + def _check_tokenization( tokenizer: "PreTrainedTokenizer", batch_input_ids: Sequence[Sequence[int]], batch_text: Sequence[str] @@ -168,3 +241,118 @@ def test_yi_template(): ) answer_str = "很高兴认识你!<|im_end|>" _check_template("01-ai/Yi-1.5-6B-Chat", "yi", prompt_str, answer_str) + + +@pytest.mark.xfail(reason="The fast tokenizer of mistral model is corrupted.") +def test_mistral_template(): + TEMPLATE = r""" +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} +{%- set user_messages = messages | selectattr("role", "equalto", "user") | list %} + +{%- for message in lmessages | rejectattr("role", "equalto", "tool") | rejectattr("role", "equalto", "tool_results") | selectattr("tool_calls", "undefined") %} + {%- if (message["role"] == "user") != (loop.index0 % 2 == 0) %} + {{- raise_exception("Conversation roles must alternate user/assistant/user/assistant/...") }} + {%- endif %} +{%- endfor %} + +{{- bos_token }} +{%- for message in messages %} + {%- if message["role"] == "user" %} + {%- if tools is not none and (message == user_messages[-1]) %} + {{- "[AVAILABLE_TOOLS] [" }} + {%- for tool in tools %} + {%- set tool = tool.function %} + {{- '{"type": "function", "function": {' }} + {%- for key, val in tool.items() if key != "return" %} + {%- if val is string %} + {{- '"' + key + '": "' + val + '"' }} + {%- else %} + {{- '"' + key + '": ' + val|tojson }} + {%- endif %} + {%- if not loop.last %} + {{- ", " }} + {%- endif %} + {%- endfor %} + {{- "}}" }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "[/AVAILABLE_TOOLS]" }} + {%- endif %} + {{- "[INST] " + message["content"] + "[/INST]" }} + {%- elif message["role"] == "tool_calls" or message.tool_calls is defined %} + {%- if message.tool_calls is defined %} + {%- set tool_calls = message.tool_calls %} + {%- else %} + {%- set tool_calls = message.content %} + {%- endif %} + {{- "[TOOL_CALLS] [" }} + {%- for tool_call in tool_calls %} + {%- set out = tool_call.function|tojson %} + {{- out }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {%- elif message["role"] == "assistant" %} + {{- " " + message["content"] }} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- '[TOOL_RESULTS] {"content": ' + content|string + "}[/TOOL_RESULTS]" }} + {%- else %} + {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} +""" + tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained("/home/share/models/Mistral-7B-v0.3") + template = get_template_and_fix_tokenizer(tokenizer, DataArguments(template="mistral")) + + content_str = tokenizer.apply_chat_template( + conversation=TOOL_MESSAGES['messages'], + tools=TOOL_MESSAGES['tools'], + chat_template=TEMPLATE, + tokenize=False + ) + content_ids = tokenizer.apply_chat_template( + conversation=TOOL_MESSAGES['messages'], + tools=TOOL_MESSAGES['tools'], + chat_template=TEMPLATE, + tokenize=True + ) + encoded_pairs = template.encode_multiturn( + tokenizer, + [ + TOOL_MESSAGES['messages'][0], + { + "role": "function", + "content": json.dumps([function['function'] for function in TOOL_MESSAGES['messages'][1]['content']]) + }, + { + "role": "observation", + "content": json.dumps([item['content'] for item in TOOL_MESSAGES['messages'][2:-1]]) + }, + TOOL_MESSAGES['messages'][-1], + ], + tools=json.dumps([tool['function'] for tool in TOOL_MESSAGES['tools']]) + ) + + final_ids = [] + for prompt, response in encoded_pairs: + final_ids.extend(prompt) + final_ids.extend(response) + + final_str = tokenizer.decode(final_ids) + + assert content_str == final_str + assert content_ids == final_ids