Add special_tokens_in_strings Arg
abuelnasr0 committed Mar 13, 2024
1 parent 0f4c93e commit d927a86
Showing 8 changed files with 60 additions and 31 deletions.
5 changes: 5 additions & 0 deletions keras_nlp/models/bert/bert_tokenizer.py
@@ -49,6 +49,9 @@ class BertTokenizer(WordPieceTokenizer):
plain text file containing a single word piece token per line.
lowercase: If `True`, the input text will be first lowered before
tokenization.
special_tokens_in_strings: bool. Whether special tokens that appear in
the input strings should be recognized as single tokens and mapped to
their ids. Defaults to `False`.
Examples:
```python
@@ -76,6 +79,7 @@ def __init__(
self,
vocabulary=None,
lowercase=False,
special_tokens_in_strings=False,
**kwargs,
):
self.cls_token = "[CLS]"
@@ -91,6 +95,7 @@ def __init__(
self.pad_token,
self.mask_token,
],
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

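For reference, a minimal sketch of how the new argument documented above is used with `BertTokenizer`, assuming a hypothetical toy vocabulary (a real vocabulary would come from a preset or a vocab file, and the exact ids depend on vocabulary order):

```python
from keras_nlp.models.bert.bert_tokenizer import BertTokenizer

# Hypothetical toy vocabulary that includes the special tokens.
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "the", "qu", "##ick", "fox"]

tokenizer = BertTokenizer(
    vocabulary=vocab,
    special_tokens_in_strings=True,  # the argument added by this commit
)

# "[CLS]" and "[SEP]" are kept intact and mapped to their own ids rather
# than being split apart on the "[" and "]" punctuation.
print(tokenizer(["[CLS] the quick fox [SEP]"]))
```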
4 changes: 3 additions & 1 deletion keras_nlp/models/bert/bert_tokenizer_test.py
@@ -41,7 +41,9 @@ def test_lowercase(self):

def test_tokenizer_special_tokens(self):
input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"]
tokenizer = BertTokenizer(**self.init_kwargs)
tokenizer = BertTokenizer(
**self.init_kwargs, special_tokens_in_strings=True
)
output_data = tokenizer(input_data)
expected_output = [[2, 5, 4, 8, 3, 0]]

5 changes: 5 additions & 0 deletions keras_nlp/models/distil_bert/distil_bert_tokenizer.py
@@ -46,6 +46,9 @@ class DistilBertTokenizer(WordPieceTokenizer):
plain text file containing a single word piece token per line.
lowercase: If `True`, the input text will be first lowered before
tokenization.
special_tokens_in_strings: bool. Whether special tokens that appear in
the input strings should be recognized as single tokens and mapped to
their ids. Defaults to `False`.
Examples:
@@ -74,6 +77,7 @@ def __init__(
self,
vocabulary,
lowercase=False,
special_tokens_in_strings=False,
**kwargs,
):
self.cls_token = "[CLS]"
@@ -89,6 +93,7 @@ def __init__(
self.pad_token,
self.mask_token,
],
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

4 changes: 3 additions & 1 deletion keras_nlp/models/distil_bert/distil_bert_tokenizer_test.py
@@ -43,7 +43,9 @@ def test_lowercase(self):

def test_tokenizer_special_tokens(self):
input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"]
tokenizer = DistilBertTokenizer(**self.init_kwargs)
tokenizer = DistilBertTokenizer(
**self.init_kwargs, special_tokens_in_strings=True
)
output_data = tokenizer(input_data)
expected_output = [[2, 5, 4, 8, 3, 0]]

5 changes: 5 additions & 0 deletions keras_nlp/models/electra/electra_tokenizer.py
@@ -36,6 +36,9 @@ class ElectraTokenizer(WordPieceTokenizer):
plain text file containing a single word piece token per line.
lowercase: If `True`, the input text will be first lowered before
tokenization.
special_tokens_in_strings: bool. Whether special tokens that appear in
the input strings should be recognized as single tokens and mapped to
their ids. Defaults to `False`.
Examples:
```python
@@ -61,6 +64,7 @@ def __init__(
self,
vocabulary,
lowercase=False,
special_tokens_in_strings=False,
**kwargs,
):
self.cls_token = "[CLS]"
@@ -76,6 +80,7 @@ def __init__(
self.pad_token,
self.mask_token,
],
special_tokens_in_strings=special_tokens_in_strings,
**kwargs,
)

4 changes: 3 additions & 1 deletion keras_nlp/models/electra/electra_tokenizer_test.py
@@ -39,7 +39,9 @@ def test_lowercase(self):

def test_tokenizer_special_tokens(self):
input_data = ["[CLS] THE [MASK] FOX [SEP] [PAD]"]
tokenizer = ElectraTokenizer(**self.init_kwargs)
tokenizer = ElectraTokenizer(
**self.init_kwargs, special_tokens_in_strings=True
)
output_data = tokenizer(input_data)
expected_output = [[2, 5, 4, 8, 3, 0]]

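`DistilBertTokenizer` and `ElectraTokenizer` receive the identical change. With the argument left at its default of `False`, special tokens appearing in raw strings get no special treatment, which is why the updated tests above now pass `special_tokens_in_strings=True` explicitly. A hedged sketch of the contrast, again with a hypothetical toy vocabulary:

```python
from keras_nlp.models.electra.electra_tokenizer import ElectraTokenizer

# Hypothetical toy vocabulary; ids follow list order.
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "the", "fox"]

default_tokenizer = ElectraTokenizer(vocabulary=vocab)
aware_tokenizer = ElectraTokenizer(
    vocabulary=vocab, special_tokens_in_strings=True
)

text = ["[CLS] the fox [SEP]"]

# Default: "[CLS]" is split on punctuation like any other text, so its
# pieces ("[", "CLS", "]") typically fall back to the "[UNK]" id.
print(default_tokenizer(text))

# With special_tokens_in_strings=True, "[CLS]" and "[SEP]" survive as
# single tokens and map to ids 2 and 3 in this toy vocabulary.
print(aware_tokenizer(text))
```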
35 changes: 20 additions & 15 deletions keras_nlp/tokenizers/word_piece_tokenizer.py
@@ -262,12 +262,15 @@ class WordPieceTokenizer(tokenizer.Tokenizer):
oov_token: str. The string value to substitute for
an unknown token. It must be included in the vocab.
Defaults to `"[UNK]"`.
special_tokens: list. A list of strings that will never be split during
the word-level splitting applied before the word-piece encoding.
This can be used to ensure special tokens map to unique indices in
the vocabulary, even if these special tokens contain splittable
characters such as punctuation. Special tokens must still be
included in `vocabulary`. Defaults to `None`.
special_tokens: list. A list of special tokens. When
`special_tokens_in_strings` is set to `True`, the tokenizer will map
every special token in the input strings to its id, even if the
special tokens contain characters that would otherwise be split off,
such as punctuation. `special_tokens` must be included in
`vocabulary`.
special_tokens_in_strings: bool. Whether special tokens that appear in
the input strings should be recognized as single tokens and mapped to
their ids. Defaults to `False`.
References:
- [Schuster and Nakajima, 2012](https://research.google/pubs/pub37842/)
@@ -347,6 +350,7 @@ def __init__(
suffix_indicator: str = "##",
oov_token: str = "[UNK]",
special_tokens: List[str] = None,
special_tokens_in_strings: bool = False,
dtype="int32",
**kwargs,
) -> None:
@@ -371,8 +375,8 @@ def __init__(
self.oov_token = oov_token
self.special_tokens = special_tokens
self._special_tokens_pattern = None
if self.split:
# Get the pattern of special tokens.
if self.split and special_tokens_in_strings:
# Use the special tokens pattern to avoid splitting special tokens.
# The idea here is to pass the special tokens regex to the
# split function as the delimiter regex pattern, so the input will
# be split on them, but also the function will treat each on
@@ -424,13 +428,14 @@ def set_vocabulary(self, vocabulary):
)

# Check for special tokens in the vocabulary
for token in self.special_tokens:
if token not in self.vocabulary:
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
f"`vocabulary`. Please provide `'{token}'` in your "
"`vocabulary` or use a pretrained `vocabulary` name."
)
if isinstance(self.special_tokens, Iterable):
for token in self.special_tokens:
if token not in self.vocabulary:
raise ValueError(
f"Cannot find token `'{token}'` in the provided "
f"`vocabulary`. Please provide `'{token}'` in your "
"`vocabulary` or use a pretrained `vocabulary` name."
)

self._fast_word_piece = tf_text.FastWordpieceTokenizer(
vocab=self.vocabulary,
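The comment in the hunk above describes the mechanism: when `special_tokens_in_strings=True`, a regex built from the special tokens is handed to the splitter as both the delimiter pattern and the keep-delimiter pattern, so the input is split around special tokens while the tokens themselves are never broken apart. A simplified illustration of that idea using `tensorflow_text.regex_split` (this is not the tokenizer's actual pattern, which also covers whitespace, punctuation, and CJK characters):

```python
import re

import tensorflow_text as tf_text

# Hypothetical special tokens; the real pattern is built from `special_tokens`.
special_tokens = ["[CLS]", "[SEP]"]
special_pattern = "|".join(re.escape(t) for t in special_tokens)

# Splitting on the pattern and also keeping it means each special token
# comes out as one unsplit piece, while plain whitespace is dropped.
pieces = tf_text.regex_split(
    ["[CLS] the quick fox. [SEP]"],
    delim_regex_pattern=special_pattern + r"|\s",
    keep_delim_regex_pattern=special_pattern,
)
print(pieces)  # roughly [[b'[CLS]', b'the', b'quick', b'fox.', b'[SEP]']]
```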
29 changes: 16 additions & 13 deletions keras_nlp/tokenizers/word_piece_tokenizer_test.py
@@ -78,7 +78,7 @@ def test_error_id_out_of_vocabulary(self):
with self.assertRaises(ValueError):
tokenizer.id_to_token(-1)

def test_special_tokens(self):
def test_special_tokens_string_dtype(self):
input_data = ["quick brown whale @MASK@"]
vocab_data = ["@UNK@", "qu", "@@ick", "br", "@@own", "fox", "@MASK@"]
special_tokens = ["@UNK@", "@MASK@"]
@@ -88,13 +88,28 @@ def test_special_tokens(self):
suffix_indicator="@@",
dtype="string",
special_tokens=special_tokens,
special_tokens_in_strings=True,
)
call_output = tokenizer(input_data)
self.assertAllEqual(
call_output,
[["qu", "@@ick", "br", "@@own", "@UNK@", "@MASK@"]],
)

def test_special_tokens_int_dtype(self):
input_data = ["[UNK] [MASK] [SEP] [PAD] [CLS] the quick brown fox."]
special_tokens = ["[UNK]", "[MASK]", "[SEP]", "[PAD]", "[CLS]"]
vocab_data = ["the", "qu", "##ick", "br", "##own", "fox", "."]
vocab_data = [*special_tokens, *vocab_data]

tokenizer = WordPieceTokenizer(
vocabulary=vocab_data,
special_tokens=special_tokens,
special_tokens_in_strings=True,
)
output = tokenizer(input_data)
self.assertAllEqual(output, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])

def test_cjk_tokens(self):
input_data = ["ah半推zz"]
vocab_data = ["[UNK]", "推", "敐", "乐", "半", "偷", "匕", "ah", "zz"]
@@ -220,18 +235,6 @@ def test_no_oov_token_in_vocabulary(self):
with self.assertRaises(ValueError):
WordPieceTokenizer(vocabulary=vocab_data, oov_token=None)

def test_tokenize_special_tokens(self):
input_data = ["[UNK] [MASK] [SEP] [PAD] [CLS] the quick brown fox."]
special_tokens = ["[UNK]", "[MASK]", "[SEP]", "[PAD]", "[CLS]"]
vocab_data = ["the", "qu", "##ick", "br", "##own", "fox", "."]
vocab_data = [*special_tokens, *vocab_data]

tokenizer = WordPieceTokenizer(
vocabulary=vocab_data, special_tokens=special_tokens
)
output = tokenizer(input_data)
self.assertAllEqual(output, [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])

def test_no_splitting_with_special_tokens(self):
# When `split` is `False`, no special tokens tokenization will be done.
input_data = [
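Tying the new tests together, a short sketch of the end-to-end `WordPieceTokenizer` behavior, including the vocabulary check that `set_vocabulary` now guards with the `Iterable` test (import path assumed to be the public `keras_nlp.tokenizers` export):

```python
from keras_nlp.tokenizers import WordPieceTokenizer

special_tokens = ["[UNK]", "[MASK]", "[SEP]", "[PAD]", "[CLS]"]
vocab = [*special_tokens, "the", "qu", "##ick", "br", "##own", "fox", "."]

tokenizer = WordPieceTokenizer(
    vocabulary=vocab,
    special_tokens=special_tokens,
    special_tokens_in_strings=True,
)
# As in test_special_tokens_int_dtype: each special token in the string maps
# to its own id (0-4 here) and the remaining text is word-piece encoded.
print(tokenizer(["[CLS] the quick brown fox. [SEP]"]))

# Every special token must also appear in `vocabulary`; otherwise
# set_vocabulary raises a ValueError naming the missing token.
try:
    WordPieceTokenizer(vocabulary=["[UNK]", "the"], special_tokens=["[CLS]"])
except ValueError as e:
    print(e)
```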
