Remove use of default_chat_template (#888)
gabrielmbmb authored Aug 12, 2024
1 parent c006ddc commit bbe04fd
Showing 3 changed files with 6 additions and 27 deletions.
15 changes: 0 additions & 15 deletions src/distilabel/llms/chat_templates.py

This file was deleted.

9 changes: 3 additions & 6 deletions src/distilabel/llms/huggingface/transformers.py
@@ -18,7 +18,6 @@
 from pydantic import Field, PrivateAttr, SecretStr, validate_call

 from distilabel.llms.base import LLM
-from distilabel.llms.chat_templates import CHATML_TEMPLATE
 from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
 from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
 from distilabel.llms.typing import GenerateOutput
@@ -145,11 +144,6 @@ def load(self) -> None:

         if self.chat_template is not None:
             self._pipeline.tokenizer.chat_template = self.chat_template  # type: ignore
-        elif (
-            self._pipeline.tokenizer.chat_template is None  # type: ignore
-            and self._pipeline.tokenizer.default_chat_template is None  # type: ignore
-        ):
-            self._pipeline.tokenizer.chat_template = CHATML_TEMPLATE  # type: ignore

         if self.structured_output:
             self._prefix_allowed_tokens_fn = self._prepare_structured_output(
@@ -178,6 +172,9 @@ def prepare_input(self, input: "StandardInput") -> str:
         Returns:
             The prompt to send to the LLM.
         """
+        if self._pipeline.tokenizer.chat_template is None:  # type: ignore
+            return input[0]["content"]
+
         prompt: str = (
             self._pipeline.tokenizer.apply_chat_template(  # type: ignore
                 input,  # type: ignore
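
With the ChatML fallback gone, prepare_input only applies a chat template when the tokenizer actually has one; otherwise the raw content of the first message is returned as the prompt. The following is a minimal sketch of that behaviour, not the distilabel source: the method is flattened into a standalone function, and the tokenizer is any object exposing chat_template and apply_chat_template (the Hugging Face tokenizer API).

# Minimal sketch of the prepare_input logic after this commit (assumed
# simplification, not the actual TransformersLLM method).
from typing import Dict, List

StandardInput = List[Dict[str, str]]  # e.g. [{"role": "user", "content": "Hi"}]


def prepare_input(tokenizer, input: StandardInput) -> str:
    # No chat template on the tokenizer: nothing to apply, so use the raw
    # content of the first message (the old behaviour injected CHATML_TEMPLATE
    # here instead).
    if tokenizer.chat_template is None:
        return input[0]["content"]
    # Otherwise let the tokenizer's own template format the conversation and
    # append the generation prompt.
    return tokenizer.apply_chat_template(
        input,
        tokenize=False,
        add_generation_prompt=True,
    )
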
9 changes: 3 additions & 6 deletions src/distilabel/llms/vllm.py
@@ -30,7 +30,6 @@
 from pydantic import Field, PrivateAttr, SecretStr, validate_call

 from distilabel.llms.base import LLM
-from distilabel.llms.chat_templates import CHATML_TEMPLATE
 from distilabel.llms.mixins.cuda_device_placement import CudaDevicePlacementMixin
 from distilabel.llms.mixins.magpie import MagpieChatTemplateMixin
 from distilabel.llms.openai import OpenAILLM
@@ -204,11 +203,6 @@ def load(self) -> None:
         self._tokenizer = self._model.get_tokenizer()  # type: ignore
         if self.chat_template is not None:
             self._tokenizer.chat_template = self.chat_template  # type: ignore
-        elif (
-            self._tokenizer.chat_template is None  # type: ignore
-            and self._tokenizer.default_chat_template is None  # type: ignore
-        ):
-            self._tokenizer.chat_template = CHATML_TEMPLATE

         if self.structured_output:
             self._logits_processor = self._prepare_structured_output(
@@ -235,6 +229,9 @@ def prepare_input(self, input: "StandardInput") -> str:
         Returns:
             The prompt to send to the LLM.
         """
+        if self._tokenizer.chat_template is None:
+            return input[0]["content"]
+
         prompt: str = (
             self._tokenizer.apply_chat_template(
                 input,  # type: ignore
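
Since tokenizers that ship without a chat template no longer receive the ChatML fallback automatically, callers that relied on it can still pass a template explicitly through the existing chat_template field. A hedged usage sketch follows: the vLLM class name and import path are assumed from this repository's module layout, the model id is hypothetical, and the Jinja string below is a standard ChatML template rather than the exact contents of the deleted chat_templates.py.

# Usage sketch under the assumptions stated above.
from distilabel.llms import vLLM

# A standard ChatML Jinja template, supplied explicitly now that it is no
# longer applied automatically.
CHATML_TEMPLATE = (
    "{% for message in messages %}"
    "{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)

llm = vLLM(
    model="my-org/model-without-chat-template",  # hypothetical model id
    chat_template=CHATML_TEMPLATE,
)
llm.load()
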
