Fix to automatically infer add_special_tokens for tokenizer (#370)
* fix(*): fix to infer add_special_tokens for tokenizer

* fix(set_tokenizer): check if the first token is bos to set add_special_tokens
soheeyang authored Aug 20, 2023
1 parent e956cba commit 54d548d
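
The heart of the fix: instead of hard-coding which model families prepend special tokens (previously a name check for `facebook/opt`), the tokenizer itself is probed at runtime. Tokenizing the empty string reveals whether a BOS token is prepended by default. A minimal standalone sketch of the same probe; the helper name and example models are illustrative, not part of the diff:

```python
from transformers import AutoTokenizer

def infer_add_special_tokens(tokenizer) -> bool:
    # Tokenize the empty string: any token that appears must have been
    # added automatically. If the first one is BOS, the tokenizer prepends
    # it by default, so TransformerLens (which adds BOS manually) should
    # pass add_special_tokens=False.
    ids = tokenizer("")["input_ids"]
    prepends_bos = len(ids) > 0 and ids[0] == tokenizer.bos_token_id
    return not prepends_bos

gpt2_tok = AutoTokenizer.from_pretrained("gpt2")  # adds no special tokens
opt_tok = AutoTokenizer.from_pretrained("facebook/opt-125m")  # prepends BOS
print(infer_add_special_tokens(gpt2_tok))  # True
print(infer_add_special_tokens(opt_tok))   # False
```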
Showing 3 changed files with 13 additions and 6 deletions.
tests/unit/test_prepend_bos.py (1 addition, 3 deletions)
@@ -13,9 +13,7 @@ def get_num_tokens_in_prompt(self, model, prompt, intended_prepend_bos):
         # copied from HookedTransformer.to_tokens()
         tokens = tokenizer(
             prompt,
-            add_special_tokens=False
-            if model.tokenizer.name_or_path.startswith("facebook/opt")
-            else True,  # As we manually add the BOS token
+            add_special_tokens=model.cfg.add_special_tokens,
         )["input_ids"]
 
         return len(tokens) + int(intended_prepend_bos)
transformer_lens/HookedTransformer.py (11 additions, 3 deletions)
@@ -467,6 +467,13 @@ def set_tokenizer(self, tokenizer):
         if self.cfg.d_vocab_out == -1:
             self.cfg.d_vocab_out = self.cfg.d_vocab
 
+        # If the tokenizer prepends the BOS token to the input by default, turn it off.
+        # We manually control whether or not to prepend BOS tokens.
+        self.cfg.add_special_tokens = not (
+            len(self.tokenizer("")["input_ids"]) > 0
+            and self.tokenizer("")["input_ids"][0] == self.tokenizer.bos_token_id
+        )
+
     def to_tokens(
         self,
         input: Union[str, List[str]],
@@ -497,6 +504,9 @@ def to_tokens(
         capitalized. It's easy to shoot yourself in the foot here if you're not careful!
         """
         assert self.tokenizer is not None, "Cannot use to_tokens without a tokenizer"
+        assert (
+            self.cfg.add_special_tokens is not None
+        ), "Set the tokenizer for the model by calling set_tokenizer"
 
         # Use the provided prepend_bos as an override if it's not None;
         # otherwise use self.cfg.default_prepend_bos (defaults to True unless specified otherwise)
@@ -515,9 +525,7 @@
             padding=True,
             truncation=truncate,
             max_length=self.cfg.n_ctx if truncate else None,
-            add_special_tokens=False
-            if self.tokenizer.name_or_path.startswith("facebook/opt")
-            else True,  # As we manually add the BOS token
+            add_special_tokens=self.cfg.add_special_tokens,
         )["input_ids"]
         if move_to_device:
             tokens = tokens.to(self.cfg.device)
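
After this commit, `set_tokenizer` populates `cfg.add_special_tokens` automatically, and `to_tokens` reads it instead of the hard-coded `facebook/opt` check. A hedged usage sketch; the model name is an arbitrary example, and `from_pretrained` is assumed to call `set_tokenizer` during construction:

```python
from transformer_lens import HookedTransformer

# Loading a model sets the tokenizer, which now infers add_special_tokens
# from the tokenizer's behavior on the empty string.
model = HookedTransformer.from_pretrained("gpt2")
print(model.cfg.add_special_tokens)  # True: GPT-2's tokenizer adds nothing itself

# to_tokens forwards cfg.add_special_tokens to the tokenizer, so BOS is
# only ever prepended by TransformerLens (default_prepend_bos=True).
tokens = model.to_tokens("Hello, world!")
print(tokens[0, 0].item() == model.tokenizer.bos_token_id)  # True
```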
transformer_lens/HookedTransformerConfig.py (1 addition, 0 deletions)
@@ -182,6 +182,7 @@ class HookedTransformerConfig:
     gated_mlp: bool = False
     default_prepend_bos: bool = True
     dtype: torch.dtype = torch.float32
+    add_special_tokens: Optional[bool] = None  # will be set by set_tokenizer
 
     def __post_init__(self):
         if self.n_heads == -1:
