Add special_tokens_in_strings to byte_pair_tokenizer
abuelnasr0 committed Apr 2, 2024
1 parent 29873a9 commit d0ff826
Showing 2 changed files with 62 additions and 40 deletions.
84 changes: 48 additions & 36 deletions keras_nlp/tokenizers/byte_pair_tokenizer.py
@@ -59,10 +59,10 @@
SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""


-def get_unsplittable_tokens_pattern(unsplittable_tokens):
-    if unsplittable_tokens is None or len(unsplittable_tokens) == 0:
+def get_special_tokens_pattern(special_tokens):
+    if special_tokens is None or len(special_tokens) == 0:
        return None
-    return r"|".join([re.escape(token) for token in unsplittable_tokens])
+    return r"|".join([re.escape(token) for token in special_tokens])


def bytes_to_unicode():
@@ -97,7 +97,7 @@ def remove_strings_from_inputs(tensor, string_to_remove):
return result


-def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
+def split_strings_for_bpe(inputs, special_tokens_pattern=None):
# We need to recreate the exact behavior of token presplitting in the
# original gpt2 tokenizer which uses a lookahead. As re2 does not
# support lookahead match, we are using an alternative insert a special
@@ -110,26 +110,23 @@ def split_strings_for_bpe(inputs, unsplittable_tokens_pattern=None):
inputs, rf"(\s{SPECIAL_WHITESPACES})$", r"\1६"
)

-    if unsplittable_tokens_pattern is not None:
-        # First split the unsplittable tokens from the input.
+    if special_tokens_pattern is not None:
+        # First split the special tokens from the input.
raw_tokens = tf_text.regex_split(
-            inputs, unsplittable_tokens_pattern, unsplittable_tokens_pattern
+            inputs, special_tokens_pattern, special_tokens_pattern
)
-        split_pattern_1_with_unsplittable_tokens = r"|".join(
-            [unsplittable_tokens_pattern, SPLIT_PATTERN_1]
-        )
-        # Then split using both `unsplittable_tokens_pattern` and
+        # Then split using both `special_tokens_pattern` and
# `SPLIT_PATTERN_1` to split inputs like original gpt2, while not
-        # affecting the unsplittable tokens.
-        # We split unsplittable tokens first then apply this split instead of
+        # affecting the special tokens.
+        # We split special tokens first then apply this split instead of
# applying this split directly, because otherwise we will not split
-        # unsplittable tokens from inputs properly, because of this pattern
+        # special tokens from inputs properly, because of this pattern
# ` ?[^\s\p{L}\p{N}{special_spaces}]+`.
# e.g., [" </s>"] will be [" </", "s", ">"] instead of [" ", "</s>"]
raw_tokens = tf_text.regex_split(
raw_tokens,
-            split_pattern_1_with_unsplittable_tokens,
-            split_pattern_1_with_unsplittable_tokens,
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
+            r"|".join([special_tokens_pattern, SPLIT_PATTERN_1]),
)
raw_tokens = raw_tokens.merge_dims(-2, -1)
else:
@@ -241,12 +238,17 @@ class BytePairTokenizer(tokenizer.Tokenizer):
a prefix space to the first word will cause it to be tokenized
equivalently to all subsequent words in the sequence.
Defaults to `False`.
-        unsplittable_tokens: list. A list of strings that will
-            never be split during the word-level splitting applied before the
-            byte-pair encoding. This can be used to ensure special tokens map to
-            unique indices in the vocabulary, even if these special tokens
-            contain splittable characters such as punctuation. Special tokens
-            must still be included in `vocabulary`. Defaults to `None`.
+        special_tokens: list. A list of special tokens. When
+            `special_tokens_in_strings` is set to `True`, special
+            tokens will never be split during the word-level splitting applied
+            before the byte-pair encoding. This can be used to ensure special
+            tokens map to unique indices in the vocabulary, even if these
+            special tokens contain splittable characters such as
+            punctuation. Special tokens must still be included in
+            `vocabulary`. Defaults to `None`.
+        special_tokens_in_strings: bool. Whether the tokenizer should expect
+            special tokens in input strings, so that they are tokenized and
+            mapped correctly to their ids. Defaults to `False`.
Examples:
Expand Down Expand Up @@ -285,7 +287,8 @@ def __init__(
merges=None,
sequence_length=None,
add_prefix_space=False,
-        unsplittable_tokens=None,
+        special_tokens=None,
+        special_tokens_in_strings=False,
dtype="int32",
**kwargs,
) -> None:
@@ -300,10 +303,12 @@
super().__init__(dtype=dtype, **kwargs)
self.sequence_length = sequence_length
self.add_prefix_space = add_prefix_space
-        self.unsplittable_tokens = unsplittable_tokens
-        self._unsplittable_tokens_pattern = get_unsplittable_tokens_pattern(
-            unsplittable_tokens
-        )
+        self.special_tokens = special_tokens
+        self._special_tokens_pattern = None
+        if special_tokens_in_strings:
+            self._special_tokens_pattern = get_special_tokens_pattern(
+                special_tokens
+            )

# Create byte <=> unicode mapping. This is useful for handling
# whitespace tokens.
@@ -355,6 +360,17 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
"token to int ids. Received: "
f"`type(vocabulary)={type(vocabulary)}`."
)

+        # Check for special tokens in vocabulary.
+        if self.special_tokens is not None:
+            for token in self.special_tokens:
+                if token not in self.get_vocabulary():
+                    raise ValueError(
+                        f"Cannot find token `'{token}'` in the provided "
+                        f"`vocabulary`. Please provide `'{token}'` in your "
+                        "`vocabulary` or use a pretrained `vocabulary` name."
+                    )

if isinstance(merges, str):
with open(merges, encoding="utf-8") as f:
self.merges = [bp.rstrip() for bp in f]
@@ -367,12 +383,10 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
)

self.cache = BytePairTokenizerCache()
-        if self.unsplittable_tokens:
+        if self.special_tokens and self._special_tokens_pattern is not None:
# Put special tokens into cache, so it won't be further split and
# merged.
-            self.cache.insert(
-                self.unsplittable_tokens, self.unsplittable_tokens
-            )
+            self.cache.insert(self.special_tokens, self.special_tokens)

# Create mapping between string tokens to int ids, and vice versa.
byte_pairs = [x[0] for x in self.vocabulary.items()]
@@ -550,9 +564,7 @@ def tokenize(self, inputs):
if scalar_input:
inputs = tf.expand_dims(inputs, 0)

-        raw_tokens = split_strings_for_bpe(
-            inputs, self._unsplittable_tokens_pattern
-        )
+        raw_tokens = split_strings_for_bpe(inputs, self._special_tokens_pattern)
token_row_splits = raw_tokens.row_splits
flat_tokens = raw_tokens.flat_values

@@ -646,7 +658,7 @@ def get_config(self):
{
"sequence_length": self.sequence_length,
"add_prefix_space": self.add_prefix_space,
"unsplittable_tokens": self.unsplittable_tokens,
"special_tokens": self.special_tokens,
}
)
        return config
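Taken together, the new arguments are opt-in. A minimal usage sketch based on the updated test below (toy vocabulary and merges, not a pretrained setup; the public import path is assumed):

from keras_nlp.tokenizers import BytePairTokenizer

# Toy values mirroring the updated test case in byte_pair_tokenizer_test.py.
vocab = {"<s>": 0, "</s>": 1, "a": 2, "Ġquick": 3, "Ġfox": 4}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick", "Ġ f", "o x", "Ġf ox"]

tokenizer = BytePairTokenizer(
    vocabulary=vocab,
    merges=merges,
    special_tokens=["<s>", "</s>"],
    # Opt in to recognizing special tokens inside raw input strings.
    special_tokens_in_strings=True,
)
print(tokenizer("<s>a quick fox</s>"))  # expected ids: [0, 2, 3, 4, 1]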
18 changes: 14 additions & 4 deletions keras_nlp/tokenizers/byte_pair_tokenizer_test.py
@@ -67,30 +67,40 @@ def test_tokenize_with_special_tokens(self):
tokenizer = BytePairTokenizer(
vocabulary=vocab,
merges=merges,
unsplittable_tokens=["s", "p"],
special_tokens=["s", "p"],
special_tokens_in_strings=True,
)
output = tokenizer("sp")
self.assertAllEqual(output, [1, 2])

-        # If not setting special tokens, "sp" is one token.
+        # If `special_tokens_in_strings` is not `True`, "sp" is one token.
tokenizer = BytePairTokenizer(
vocabulary=vocab,
merges=merges,
special_tokens=["s", "p"],
)
output = tokenizer("sp")
self.assertAllEqual(output, [0])

# Test real-world special tokens, e.g. <s> and </s>.
vocab = {"<s>": 0, "</s>": 1, "a": 2, "Ġquick": 3, "Ġfox": 4}
merges = ["Ġ q", "u i", "c k", "ui ck", "Ġq uick"]
merges += ["Ġ f", "o x", "Ġf ox"]
tokenizer = BytePairTokenizer(
vocabulary=vocab,
merges=merges,
unsplittable_tokens=["<s>", "</s>"],
special_tokens=["<s>", "</s>"],
special_tokens_in_strings=True,
)
output = tokenizer("<s>a quick fox</s>")
self.assertAllEqual(output, [0, 2, 3, 4, 1])

+    def test_errors_missing_special_tokens(self):
+        with self.assertRaises(ValueError):
+            BytePairTokenizer(
+                vocabulary=["a", "b", "c"], merges=[], special_tokens=["d"]
+            )

def test_tokenize_prefix_space(self):
input_data = ["brown.", "black."]
tokenizer = BytePairTokenizer(
@@ -181,4 +191,4 @@ def test_config(self):
self.assertAllEqual(
self.tokenizer(input_data),
cloned_tokenizer(input_data),
        )
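As a rough sketch of how the new vocabulary check in set_vocabulary_and_merges (and the new test_errors_missing_special_tokens above) surfaces to users — toy values, with the message wording taken from the diff above:

from keras_nlp.tokenizers import BytePairTokenizer

try:
    # "d" is declared as a special token but is missing from the toy vocabulary.
    BytePairTokenizer(
        vocabulary={"a": 0, "b": 1, "c": 2},
        merges=[],
        special_tokens=["d"],
    )
except ValueError as err:
    print(err)  # Cannot find token `'d'` in the provided `vocabulary`. ...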
