Add XLNetTokenizer
#1206
base: master
Changes from 5 commits
b55691e
9ef8875
f1fe2d8
983d93d
e03e9bb
e431196
@@ -0,0 +1,205 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""XLNET tokenizer."""

import tensorflow as tf

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer

try:
    import unicodedata
except ImportError:
    unicodedata = None


@keras_nlp_export("keras_nlp.models.XLNetTokenizer")
class XLNetTokenizer(SentencePieceTokenizer):
    """XLNET tokenizer layer based on SentencePiece.

    This tokenizer class will tokenize raw strings into integer sequences and
    is based on `keras_nlp.tokenizers.SentencePieceTokenizer`. Unlike the
    underlying tokenizer, it will check for all special tokens needed by
    XLNET models and provides a `from_preset()` method to automatically
    download a matching vocabulary for a XLNET preset.

    This tokenizer does not provide truncation or padding of inputs. It can be
    combined with a `keras_nlp.models.XLNetPreprocessor` layer for input
    packing.

    If input is a batch of strings (rank > 0), the layer will output a
    `tf.RaggedTensor` where the last dimension of the output is ragged.

    If input is a scalar string (rank == 0), the layer will output a dense
    `tf.Tensor` with static shape `[None]`.

    Args:
        proto: Either a `string` path to a SentencePiece proto file, or a
            `bytes` object with a serialized SentencePiece proto. See the
            [SentencePiece repository](https://github.com/google/sentencepiece)
            for more details on the format.

    Examples:

    ```python
    # Unbatched input.
    tokenizer = keras_nlp.models.XLNetTokenizer(proto="<path to SentencePiece proto file>")
    tokenizer("The quick brown fox jumped.")

    # Batched input.
    tokenizer(["The quick brown fox jumped.", "The fox slept."])

    # Detokenization.
    tokenizer.detokenize(tokenizer("The quick brown fox jumped."))

    # Custom vocabulary.
    bytes_io = io.BytesIO()
    vocab_data = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."])
    sentencepiece.SentencePieceTrainer.train(
        sentence_iterator=vocab_data.as_numpy_iterator(),
        model_writer=bytes_io,
        vocab_size=14,
        model_type="WORD",
        pad_id=0,
        bos_id=1,
        eos_id=2,
        unk_id=3,
        pad_piece="<pad>",
        bos_piece="<s>",
        eos_piece="</s>",
        unk_piece="<unk>",
        user_defined_symbols=["<mask>", "<cls>", "<sep>"],
    )
    tokenizer = keras_nlp.models.XLNetTokenizer(
        proto=bytes_io.getvalue(),
    )
    tokenizer("The quick brown fox jumped.")
    ```
    """

    def __init__(self, proto, **kwargs):
        super().__init__(proto=proto, **kwargs)

        # Check for necessary special tokens.
        self.cls_token = "<cls>"
        self.sep_token = "<sep>"
        self.pad_token = "<pad>"
        self.mask_token = "<mask>"
        self.bos_token = "<s>"
        self.eos_token = "</s>"
        self.unk_token = "<unk>"

        for token in [
            self.cls_token,
            self.sep_token,
            self.pad_token,
            self.mask_token,
            self.bos_token,
            self.eos_token,
            self.unk_token,
        ]:
            if token not in self.get_vocabulary():
                raise ValueError(
                    f"Cannot find token `'{token}'` in the provided "
                    f"`vocabulary`. Please provide `'{token}'` in your "
                    "`vocabulary` or use a pretrained `vocabulary` name."
                )

        self.cls_token_id = self.token_to_id(self.cls_token)
        self.sep_token_id = self.token_to_id(self.sep_token)
        self.pad_token_id = self.token_to_id(self.pad_token)
        self.mask_token_id = self.token_to_id(self.mask_token)
        self.bos_token_id = self.token_to_id(self.bos_token)
        self.eos_token_id = self.token_to_id(self.eos_token)
        self.unk_token_id = self.token_to_id(self.unk_token)
    def preprocess_text(self, inputs):
        """Preprocesses the text. Normalizes whitespace and removes accents."""

        # normalize whitespace and quotes
        outputs = " ".join(inputs.strip().split())
        outputs = outputs.replace("``", '"').replace("''", '"')

        # remove accents
        outputs = unicodedata.normalize("NFKD", outputs)
        outputs = "".join([c for c in outputs if not unicodedata.combining(c)])

        return outputs
    def tokenize(self, text):
        """Tokenize a string."""

        # check if there are multiple examples present or not
        is_batched = isinstance(text, list)
        if not is_batched:
            text = [text]

        tokenized_text = []
        for each_text in text:
            each_text = self.preprocess_text(each_text)
            pieces = [
                self.id_to_token(token_id)
                for token_id in super().tokenize(each_text)
            ]
            new_pieces = []
            for piece in pieces:
                # Workaround for pieces like "3,": re-tokenize the piece
                # without the trailing comma so the comma becomes its own
                # piece, matching the original XLNET tokenizer.
                if (
                    len(piece) > 1
                    and piece[-1] == str(",")
                    and piece[-2].isdigit()
                ):
                    cur_pieces = [
                        self.id_to_token(cur_piece_id)
                        for cur_piece_id in super().tokenize(
                            piece[:-1].replace("▁", "")
                        )
                    ]
                    if piece[0] != "▁" and cur_pieces[0][0] == "▁":
                        if len(cur_pieces[0]) == 1:
                            cur_pieces = cur_pieces[1:]
                        else:
                            cur_pieces[0] = cur_pieces[0][1:]
                    cur_pieces.append(piece[-1])
                    new_pieces.extend(cur_pieces)
                else:
                    new_pieces.append(piece)

            new_pieces = [
                self.token_to_id(new_piece_token)
                for new_piece_token in new_pieces
            ]
            # add sep_token and cls_token in the end.
            new_pieces.extend([self.sep_token_id, self.cls_token_id])

            tokenized_text.append(new_pieces)

        # if there are multiple examples present, then return a `tf.RaggedTensor`.
        if is_batched:
            return tf.ragged.constant(tokenized_text)

        return tokenized_text[0]

    def detokenize(self, inputs):
        """Detokenize the input_ids into text."""

        outputs = super().detokenize(inputs)

        outputs = tf.strings.regex_replace(outputs, self.cls_token, "")
        outputs = tf.strings.regex_replace(outputs, self.sep_token, "")
        outputs = tf.strings.regex_replace(outputs, self.mask_token, "")
        outputs = tf.strings.regex_replace(outputs, self.pad_token, "")
Comment on lines +217 to +220:

There is one difference in `detokenize` here compared to HF. For an example:

```python
from transformers import XLNetTokenizer
tokenizer_hf = XLNetTokenizer.from_pretrained("xlnet-base-cased")
text = "the quick brown fox"
print(tokenizer_hf.decode(tokenizer_hf(text)["input_ids"]))
```

this will give us output -

the

Please let me know if I should change this design to strictly follow the HF or not.

        return outputs
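For comparison with the Hugging Face behaviour quoted in the comment above, here is a minimal sketch of what this PR's `detokenize` returns for a similar round trip. It is illustrative only: `tokenizer` is assumed to be an `XLNetTokenizer` built from a toy SentencePiece vocab like the one trained in the test file below, and the expected string is the one asserted by `test_detokenize`.

```python
# Round trip with this PR's tokenizer: tokenize() appends the sep/cls ids and
# detokenize() strips the special tokens via the regex_replace calls above.
ids = tokenizer(["the quick brown fox"])  # each sequence ends with sep_token_id, cls_token_id
print(tokenizer.detokenize(ids))          # -> ["the quick brown fox"], no "<sep>"/"<cls>" in the text
```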
@@ -0,0 +1,105 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for XLNET tokenizer."""

import io

import sentencepiece
import tensorflow as tf

from keras_nlp.backend import keras
from keras_nlp.models.xlnet.xlnet_tokenizer import XLNetTokenizer
from keras_nlp.tests.test_case import TestCase


class XLNetTokenizerTest(TestCase):
    def setUp(self):
        bytes_io = io.BytesIO()
        vocab_data = tf.data.Dataset.from_tensor_slices(
            ["the quick brown fox", "the earth is round"]
        )
        sentencepiece.SentencePieceTrainer.train(
            sentence_iterator=vocab_data.as_numpy_iterator(),
            model_writer=bytes_io,
            vocab_size=14,
            model_type="WORD",
            pad_id=0,
            bos_id=1,
            eos_id=2,
            unk_id=3,
            pad_piece="<pad>",
            bos_piece="<s>",
            eos_piece="</s>",
            unk_piece="<unk>",
            user_defined_symbols=["<mask>", "<cls>", "<sep>"],
        )
        self.proto = bytes_io.getvalue()

        self.tokenizer = XLNetTokenizer(proto=self.proto)

    def test_tokenize(self):
        input_data = ["the quick brown fox"]
        output = self.tokenizer(input_data)
        self.assertAllEqual(output, [[7, 12, 8, 10, 6, 5]])

    def test_tokenize_batch(self):
        input_data = ["the quick brown fox", "the earth is round"]
        output = self.tokenizer(input_data)
        self.assertAllEqual(
            output, [[7, 12, 8, 10, 6, 5], [7, 9, 11, 13, 6, 5]]
        )

    def test_detokenize(self):
        input_data = [[7, 12, 8, 10, 6, 5]]
        output = self.tokenizer.detokenize(input_data)
        self.assertEqual(output, ["the quick brown fox"])

    def test_detokenize_mask_token(self):
        input_data = [[7, 12, 8, 10, 6, 5, self.tokenizer.mask_token_id]]
        output = self.tokenizer.detokenize(input_data)
        self.assertEqual(output, ["the quick brown fox"])

    def test_vocabulary_size(self):
        self.assertEqual(self.tokenizer.vocabulary_size(), 14)

    def test_get_vocabulary_mask_token(self):
        self.assertEqual(self.tokenizer.get_vocabulary()[4], "<mask>")

    def test_id_to_token_mask_token(self):
        self.assertEqual(self.tokenizer.id_to_token(4), "<mask>")

    def test_token_to_id_mask_token(self):
        self.assertEqual(self.tokenizer.token_to_id("<mask>"), 4)

    def test_errors_missing_special_tokens(self):
        bytes_io = io.BytesIO()
        sentencepiece.SentencePieceTrainer.train(
            sentence_iterator=iter(["abc"]),
            model_writer=bytes_io,
            vocab_size=5,
            pad_id=-1,
            eos_id=-1,
            bos_id=-1,
        )
        with self.assertRaises(ValueError):
            XLNetTokenizer(proto=bytes_io.getvalue())

    def test_serialization(self):
        config = keras.saving.serialize_keras_object(self.tokenizer)
        new_tokenizer = keras.saving.deserialize_keras_object(config)
        self.assertEqual(
            new_tokenizer.get_config(),
            self.tokenizer.get_config(),
        )
This does not look like it would work with `tf.data`. A key feature for our tokenizers is to be able to run `string_ds.map(tokenizer)` with a `tf.data.Dataset`, as this is really the only performant option for preprocessing we ship with the library. I would not worry about being one to one with huggingface w.r.t. string inputted special tokens right now, but we do need two things... `tokenize()` should chain to `super()` and the `tf.text` op for tokenizing text. No for loop tokenization. If we can get to that state we will be unblocked here. Why is there a need to diverge from the sentence piece routines below?
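To make the requirement concrete, a rough sketch of the usage being described (illustrative only; `tokenizer` is assumed to be an `XLNetTokenizer` instance, and this graph-mode `map` call is exactly what the current Python-level loop cannot support yet):

```python
import tensorflow as tf

# Usage target: tokenization must run on string tensors inside a tf.data
# pipeline, not on Python lists of str.
string_ds = tf.data.Dataset.from_tensor_slices(
    ["The quick brown fox jumped.", "The fox slept."]
)
token_ds = string_ds.map(tokenizer)  # the call the current implementation cannot support yet
```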
Hi @mattdangerw thanks for your comment! Yes, the tokenizer is not working with `tf.data.Dataset`. For (almost) all plain texts `super().tokenize` is enough and produces the same upstream result, but there are a few texts (such as `"ABC 0.123,"`) where we must apply the extra logic to get the same result.

`[tokenizer.id_to_token(i) for i in tokenizer._sentence_piece.tokenize("ABC 0.123,")]` gives `['▁ABC', '▁0', '.', '12', '3,']`, but the actual output is `['▁ABC', '▁0', '.', '12', '3', ',']`.

So, we must keep the extra logic in `tokenize`. (The official repo also has the same logic.)

My current plan is to replace all other `str` methods with `tf.text` and remove the outer loop.
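The mismatch described above can be checked against the reference implementation with a small snippet like the following (illustrative only; it assumes the `transformers` package and the `xlnet-base-cased` checkpoint mentioned earlier, and the expected pieces are taken from this discussion rather than re-verified here):

```python
# Compare the raw SentencePiece pieces against the reference tokenizer for
# the digit-followed-by-comma case discussed above.
from transformers import XLNetTokenizer as HFXLNetTokenizer

hf_tokenizer = HFXLNetTokenizer.from_pretrained("xlnet-base-cased")
print(hf_tokenizer.tokenize("ABC 0.123,"))
# Per the discussion above: ['▁ABC', '▁0', '.', '12', '3', ','],
# while plain `super().tokenize` gives the fused piece '3,' instead.
```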
Thanks for the explainer! Is it really just some weird workaround for digits followed by a comma?

Ideally we could figure out a way to either preprocess or postprocess the sentencepiece tokenize result so that we can still use the `tf-text` sentencepiece "black box" unaltered. Not sure if that is possible though... `tensorflow-text` and the `tf.strings` module will be the main tools here.
Yes, as of my understanding it's a workaround.
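One possible shape for the preprocessing route floated above, as a rough sketch only: the regex below is an assumption, and whether the resulting pieces exactly match the Python workaround (or the reference tokenizer) would still need to be verified.

```python
import tensorflow as tf

def separate_digit_comma(text):
    # Purely illustrative: insert a space between a digit and a following
    # comma in the raw string so SentencePiece cannot emit a fused piece
    # like "3,".
    return tf.strings.regex_replace(text, r"(\d),", r"\1 ,")

print(separate_digit_comma(tf.constant("ABC 0.123,")))
# tf.Tensor(b'ABC 0.123 ,', shape=(), dtype=string)
```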