Commit

style fix
susnato committed Aug 10, 2023
1 parent f1fe2d8 commit 983d93d
Showing 2 changed files with 73 additions and 37 deletions.
98 changes: 66 additions & 32 deletions keras_nlp/models/xlnet/xlnet_tokenizer.py
@@ -33,7 +33,7 @@ class XLNetTokenizer(SentencePieceTokenizer):
     is based on `keras_nlp.tokenizers.SentencePieceTokenizer`. Unlike the
     underlying tokenizer, it will check for all special tokens needed by
     XLNET models and provides a `from_preset()` method to automatically
-    download a matching vocabulary for a ALBERT preset.
+    download a matching vocabulary for a XLNET preset.

     This tokenizer does not provide truncation or padding of inputs. It can be
     combined with a `keras_nlp.models.XLNetPreprocessor` layer for input
@@ -68,22 +68,21 @@ class XLNetTokenizer(SentencePieceTokenizer):
     bytes_io = io.BytesIO()
     ds = tf.data.Dataset.from_tensor_slices(["The quick brown fox jumped."])
     sentencepiece.SentencePieceTrainer.train(
-        sentence_iterator=ds.as_numpy_iterator(),
+        sentence_iterator=vocab_data.as_numpy_iterator(),
         model_writer=bytes_io,
-        vocab_size=10,
+        vocab_size=14,
         model_type="WORD",
         pad_id=0,
-        unk_id=1,
-        bos_id=2,
-        eos_id=3,
+        bos_id=1,
+        eos_id=2,
+        unk_id=3,
         pad_piece="<pad>",
-        sep_piece="<sep>",
-        unk_piece="<unk>",
         bos_piece="<s>",
         eos_piece="</s>",
-        user_defined_symbols="[MASK]",
+        unk_piece="<unk>",
+        user_defined_symbols=["<mask>", "<cls>", "<sep>"]
     )
-    tokenizer = keras_nlp.models.AlbertTokenizer(
+    tokenizer = keras_nlp.models.XLNetTokenizer(
         proto=bytes_io.getvalue(),
     )
     tokenizer("The quick brown fox jumped.")
@@ -94,32 +93,40 @@ def __init__(self, proto, **kwargs):
         super().__init__(proto=proto, **kwargs)

         # Check for necessary special tokens.
-        cls_token = "<cls>"
-        sep_token = "<sep>"
-        pad_token = "<pad>"
-        mask_token = "<mask>"
-        bos_token = "<s>"
-        eos_token = "</s>"
-        unk_token = "<unk>"
-
-        for token in [cls_token, sep_token, pad_token, mask_token, bos_token, eos_token, unk_token]:
+        self.cls_token = "<cls>"
+        self.sep_token = "<sep>"
+        self.pad_token = "<pad>"
+        self.mask_token = "<mask>"
+        self.bos_token = "<s>"
+        self.eos_token = "</s>"
+        self.unk_token = "<unk>"
+
+        for token in [
+            self.cls_token,
+            self.sep_token,
+            self.pad_token,
+            self.mask_token,
+            self.bos_token,
+            self.eos_token,
+            self.unk_token,
+        ]:
             if token not in self.get_vocabulary():
                 raise ValueError(
                     f"Cannot find token `'{token}'` in the provided "
                     f"`vocabulary`. Please provide `'{token}'` in your "
                     "`vocabulary` or use a pretrained `vocabulary` name."
                 )

-        self.cls_token_id = self.token_to_id(cls_token)
-        self.sep_token_id = self.token_to_id(sep_token)
-        self.pad_token_id = self.token_to_id(pad_token)
-        self.mask_token_id = self.token_to_id(mask_token)
-        self.bos_token_id = self.token_to_id(bos_token)
-        self.eos_token_id = self.token_to_id(eos_token)
-        self.unk_token_id = self.token_to_id(unk_token)
+        self.cls_token_id = self.token_to_id(self.cls_token)
+        self.sep_token_id = self.token_to_id(self.sep_token)
+        self.pad_token_id = self.token_to_id(self.pad_token)
+        self.mask_token_id = self.token_to_id(self.mask_token)
+        self.bos_token_id = self.token_to_id(self.bos_token)
+        self.eos_token_id = self.token_to_id(self.eos_token)
+        self.unk_token_id = self.token_to_id(self.unk_token)

     def preprocess_text(self, inputs):
-        """Preprocesses the text. This method removes spaces and accents."""
+        """Preprocesses the text. This method removes spaces and accents from the text."""

         # remove space
         outputs = " ".join(inputs.strip().split())
@@ -134,20 +141,32 @@ def preprocess_text(self, inputs):
     def tokenize(self, text):
         """Tokenize a string."""

-        # check if there are multiple batches present or not
+        # check if there are multiple examples present or not
         is_batched = isinstance(text, list)
         if not is_batched:
             text = [text]

         tokenized_text = []
         for each_text in text:
             each_text = self.preprocess_text(each_text)
-            pieces = [self.id_to_token(token_id) for token_id in super().tokenize(each_text)]
+            pieces = [
+                self.id_to_token(token_id)
+                for token_id in super().tokenize(each_text)
+            ]

             new_pieces = []
             for piece in pieces:
-                if len(piece) > 1 and piece[-1] == str(",") and piece[-2].isdigit():
-                    cur_pieces = [self.id_to_token(cur_piece_id) for cur_piece_id in super().tokenize(piece[:-1].replace("▁", ""))]
+                if (
+                    len(piece) > 1
+                    and piece[-1] == str(",")
+                    and piece[-2].isdigit()
+                ):
+                    cur_pieces = [
+                        self.id_to_token(cur_piece_id)
+                        for cur_piece_id in super().tokenize(
+                            piece[:-1].replace("▁", "")
+                        )
+                    ]
                     if piece[0] != "▁" and cur_pieces[0][0] == "▁":
                         if len(cur_pieces[0]) == 1:
                             cur_pieces = cur_pieces[1:]
@@ -158,14 +177,29 @@ def tokenize(self, text):
                 else:
                     new_pieces.append(piece)

-            new_pieces = [self.token_to_id(new_piece_token) for new_piece_token in new_pieces]
+            new_pieces = [
+                self.token_to_id(new_piece_token)
+                for new_piece_token in new_pieces
+            ]
             # add sep_token and cls_token in the end.
             new_pieces.extend([self.sep_token_id, self.cls_token_id])

             tokenized_text.append(new_pieces)

         # if there are multiple examples present, then return a `tf.RaggedTensor`.
         if is_batched:
             return tf.ragged.constant(tokenized_text)

         return tokenized_text[0]
+
+    def detokenize(self, inputs):
+        """Detokenize the input_ids into text."""
+
+        outputs = super().detokenize(inputs)
+
+        outputs = tf.strings.regex_replace(outputs, self.cls_token, "")
+        outputs = tf.strings.regex_replace(outputs, self.sep_token, "")
+        outputs = tf.strings.regex_replace(outputs, self.mask_token, "")
+        outputs = tf.strings.regex_replace(outputs, self.pad_token, "")
+
+        return outputs
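
For context, a minimal end-to-end sketch of the behaviour this file implements after the change — not part of the commit itself. It assumes `sentencepiece` and `tensorflow` are installed and that this branch's tokenizer is importable from `keras_nlp.models.xlnet.xlnet_tokenizer`; the trainer settings mirror the diff above, and the two training sentences are borrowed from `test_tokenize_batch` in the test file below.

import io

import sentencepiece
import tensorflow as tf

from keras_nlp.models.xlnet.xlnet_tokenizer import XLNetTokenizer

# Train a tiny word-level SentencePiece model with the special-token layout
# used in the tests: pad=0, <s>=1, </s>=2, <unk>=3, plus the user-defined
# <mask>, <cls> and <sep> pieces.
bytes_io = io.BytesIO()
vocab_data = tf.data.Dataset.from_tensor_slices(
    ["the quick brown fox", "the earth is round"]
)
sentencepiece.SentencePieceTrainer.train(
    sentence_iterator=vocab_data.as_numpy_iterator(),
    model_writer=bytes_io,
    vocab_size=14,
    model_type="WORD",
    pad_id=0,
    bos_id=1,
    eos_id=2,
    unk_id=3,
    pad_piece="<pad>",
    bos_piece="<s>",
    eos_piece="</s>",
    unk_piece="<unk>",
    user_defined_symbols=["<mask>", "<cls>", "<sep>"],
)
tokenizer = XLNetTokenizer(proto=bytes_io.getvalue())

# A single string comes back as a plain list of ids; `tokenize` appends the
# <sep> and <cls> ids at the end of every example.
token_ids = tokenizer("the quick brown fox")
print(token_ids)

# A list of strings comes back as a tf.RaggedTensor, one row per example.
print(tokenizer(["the quick brown fox", "the earth is round"]))

# The new `detokenize` override strips the <cls>, <sep>, <mask> and <pad>
# pieces from the decoded string.
print(tokenizer.detokenize([token_ids]))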
12 changes: 7 additions & 5 deletions keras_nlp/models/xlnet/xlnet_tokenizer_test.py
@@ -15,6 +15,7 @@
 """Tests for XLNET tokenizer."""

 import io
+
 import sentencepiece
 import tensorflow as tf

@@ -42,12 +43,11 @@ def setUp(self):
             bos_piece="<s>",
             eos_piece="</s>",
             unk_piece="<unk>",
-            user_defined_symbols=["<mask>", "<cls>", "<sep>"]
+            user_defined_symbols=["<mask>", "<cls>", "<sep>"],
         )
         self.proto = bytes_io.getvalue()

         self.tokenizer = XLNetTokenizer(proto=self.proto)
-        print(self.tokenizer.get_vocabulary())

     def test_tokenize(self):
         input_data = "the quick brown fox"
@@ -57,17 +57,19 @@ def test_tokenize_batch(self):
     def test_tokenize_batch(self):
         input_data = ["the quick brown fox", "the earth is round"]
         output = self.tokenizer(input_data)
-        self.assertAllEqual(output, [[7, 12, 8, 10, 6, 5], [7, 9, 11, 13, 6, 5]])
+        self.assertAllEqual(
+            output, [[7, 12, 8, 10, 6, 5], [7, 9, 11, 13, 6, 5]]
+        )

     def test_detokenize(self):
         input_data = [[7, 12, 8, 10, 6, 5]]
         output = self.tokenizer.detokenize(input_data)
-        self.assertEqual(output, ["the quick brown fox<sep><cls>"])
+        self.assertEqual(output, ["the quick brown fox"])

     def test_detokenize_mask_token(self):
         input_data = [[7, 12, 8, 10, 6, 5, self.tokenizer.mask_token_id]]
         output = self.tokenizer.detokenize(input_data)
-        self.assertEqual(output, ["the quick brown fox<sep><cls><mask>"])
+        self.assertEqual(output, ["the quick brown fox"])

     def test_vocabulary_size(self):
         self.assertEqual(self.tokenizer.vocabulary_size(), 14)
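
Given the two rewritten detokenize expectations above, a tokenize/detokenize round trip is a natural companion check — sketched below purely as an illustration, not part of this commit. The method name is hypothetical and it assumes the same `setUp` shown in the diff:

    def test_detokenize_round_trip(self):
        # Ids produced by `tokenize` (which appends <sep> and <cls>) should
        # detokenize back to the plain input text.
        input_data = "the quick brown fox"
        token_ids = self.tokenizer(input_data)
        output = self.tokenizer.detokenize([token_ids])
        self.assertEqual(output, ["the quick brown fox"])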