diff --git a/src/distilabel/llms/chat_templates.py b/src/distilabel/llms/chat_templates.py
deleted file mode 100644
index 7edba0132c..0000000000
--- a/src/distilabel/llms/chat_templates.py
+++ /dev/null
@@ -1,15 +0,0 @@
-# Copyright 2023-present, Argilla, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-CHATML_TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
diff --git a/src/distilabel/llms/huggingface/transformers.py b/src/distilabel/llms/huggingface/transformers.py
index 6b8ad25e2f..455e6e898b 100644
--- a/src/distilabel/llms/huggingface/transformers.py
+++ b/src/distilabel/llms/huggingface/transformers.py
@@ -18,7 +18,6 @@
 from pydantic import Field, PrivateAttr, SecretStr, validate_call
 
 from distilabel.llms.base import LLM
-from distilabel.llms.chat_templates import CHATML_TEMPLATE
 from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
 from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
 from distilabel.llms.typing import GenerateOutput
@@ -145,11 +144,6 @@ def load(self) -> None:
 
         if self.chat_template is not None:
             self._pipeline.tokenizer.chat_template = self.chat_template  # type: ignore
-        elif (
-            self._pipeline.tokenizer.chat_template is None  # type: ignore
-            and self._pipeline.tokenizer.default_chat_template is None  # type: ignore
-        ):
-            self._pipeline.tokenizer.chat_template = CHATML_TEMPLATE  # type: ignore
 
         if self.structured_output:
             self._prefix_allowed_tokens_fn = self._prepare_structured_output(
@@ -178,6 +172,9 @@ def prepare_input(self, input: "StandardInput") -> str:
         Returns:
             The prompt to send to the LLM.
         """
+        if self._pipeline.tokenizer.chat_template is None:  # type: ignore
+            return input[0]["content"]
+
         prompt: str = (
             self._pipeline.tokenizer.apply_chat_template(  # type: ignore
                 input,  # type: ignore
diff --git a/src/distilabel/llms/vllm.py b/src/distilabel/llms/vllm.py
index c6351b16bb..4ff30c07f4 100644
--- a/src/distilabel/llms/vllm.py
+++ b/src/distilabel/llms/vllm.py
@@ -30,7 +30,6 @@
 from pydantic import Field, PrivateAttr, SecretStr, validate_call
 
 from distilabel.llms.base import LLM
-from distilabel.llms.chat_templates import CHATML_TEMPLATE
 from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
 from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
 from distilabel.llms.openai import OpenAILLM
@@ -204,11 +203,6 @@ def load(self) -> None:
         self._tokenizer = self._model.get_tokenizer()  # type: ignore
         if self.chat_template is not None:
             self._tokenizer.chat_template = self.chat_template  # type: ignore
-        elif (
-            self._tokenizer.chat_template is None  # type: ignore
-            and self._tokenizer.default_chat_template is None  # type: ignore
-        ):
-            self._tokenizer.chat_template = CHATML_TEMPLATE
 
         if self.structured_output:
             self._logits_processor = self._prepare_structured_output(
@@ -235,6 +229,9 @@ def prepare_input(self, input: "StandardInput") -> str:
         Returns:
             The prompt to send to the LLM.
         """
+        if self._tokenizer.chat_template is None:
+            return input[0]["content"]
+
         prompt: str = (
             self._tokenizer.apply_chat_template(
                 input,  # type: ignore