Add XLNetTokenizer #1206

Open

wants to merge 6 commits into base: master

Conversation

@susnato (Contributor) commented Aug 9, 2023

Adds XLNetTokenizer to the library. This is part 2 of adding XLNet to keras-nlp.

@susnato changed the title from "[WIP] Add XLNetTokenizer" to "Add XLNetTokenizer" on Aug 10, 2023
@susnato marked this pull request as ready for review on August 10, 2023 09:47
@susnato (Contributor, Author) commented Aug 10, 2023

This PR is ready for review.

cc: @mattdangerw

Comment on lines +200 to +203
outputs = tf.strings.regex_replace(outputs, self.cls_token, "")
outputs = tf.strings.regex_replace(outputs, self.sep_token, "")
outputs = tf.strings.regex_replace(outputs, self.mask_token, "")
outputs = tf.strings.regex_replace(outputs, self.pad_token, "")
@susnato (Contributor, Author) commented:

There is one difference in detokenize: contrary to the HuggingFace implementation, the keras_nlp tokenizer removes the <cls> or <sep> tokens at the end.

For example:

from transformers import XLNetTokenizer

tokenizer_hf = XLNetTokenizer.from_pretrained("xlnet-base-cased")
text = "the quick brown fox"
print(tokenizer_hf.decode(tokenizer_hf(text)["input_ids"]))

This gives the output:

the quick brown fox<sep><cls>

The keras-nlp implementation will remove those tokens and give us "the quick brown fox".

Please let me know whether I should change this design to strictly follow HF or not.
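
For comparison, the keras-nlp side would look roughly like this (a sketch only; it assumes the class is exposed as keras_nlp.models.XLNetTokenizer with a proto argument, like the other SentencePiece-backed tokenizers, and "xlnet_vocab.spm" is a placeholder path to a SentencePiece model file):

import keras_nlp

tokenizer = keras_nlp.models.XLNetTokenizer(proto="xlnet_vocab.spm")
text = "the quick brown fox"
# detokenize strips the special tokens, so this yields "the quick brown fox"
# with no trailing <sep><cls>.
print(tokenizer.detokenize(tokenizer(text)))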

@mattdangerw (Member) left a comment:

Thanks! One meta comment.


return outputs

def tokenize(self, text):
@mattdangerw (Member) commented:

This does not look like it would work with tf.data. A key feature for our tokenizers is the ability to run string_ds.map(tokenizer) with a tf.data.Dataset, as this is really the only performant option for preprocessing we ship with the library.

I would not worry about being one to one with huggingface w.r.t. string inputted special tokens right now, but we do need two things...

  • Plain text (ignore special tokens in both input and output), should tokenize exactly the same as the upstream implementation.
  • tokenize() should chain to super() and the tf.text op for tokenizing text. No for loop tokenization.

If we can get to that state we will be unblocked here. Why is there a need to diverge from the sentence piece routines below?
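
For reference, the usage pattern referred to above is roughly the following (a sketch, assuming the same hypothetical keras_nlp.models.XLNetTokenizer and placeholder vocabulary file as above):

import tensorflow as tf
import keras_nlp

tokenizer = keras_nlp.models.XLNetTokenizer(proto="xlnet_vocab.spm")
string_ds = tf.data.Dataset.from_tensor_slices(
    ["the quick brown fox", "ABC 0.123,"]
)
# Mapping the tokenizer runs it inside the tf.data graph, which is why
# tokenize() needs to be expressible with tf.text ops rather than Python loops.
token_ds = string_ds.map(tokenizer)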

@susnato (Contributor, Author) commented:

Hi @mattdangerw, thanks for your comment! Yes, the tokenizer does not currently work with tf.data.Dataset.

Plain text (ignore special tokens in both input and output), should tokenize exactly the same as the upstream implementation.

For (almost) all plain texts, super().tokenize is enough and produces the same result as the upstream implementation, but there are a few texts (such as "ABC 0.123,") where we must apply the extra logic to get the same result.

  • Output of [tokenizer.id_to_token(i) for i in tokenizer._sentence_piece.tokenize("ABC 0.123,")] -
    ['▁ABC', '▁0', '.', '12', '3,']

But the actual output is ['▁ABC', '▁0', '.', '12', '3', ',']
So we must keep the extra logic in tokenize. (The official repo has the same logic.)
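
For reference, that extra logic in the upstream implementation is roughly the following (a paraphrased sketch of the HuggingFace XLNetTokenizer handling; sp_model is assumed to be a loaded sentencepiece.SentencePieceProcessor):

SPIECE_UNDERLINE = "▁"

def fix_digit_comma_pieces(sp_model, pieces):
    # Re-split pieces like "3," so that a digit followed by a comma ends up
    # as [..., "3", ","] instead of [..., "3,"], matching the upstream output.
    new_pieces = []
    for piece in pieces:
        if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
            cur_pieces = sp_model.EncodeAsPieces(piece[:-1].replace(SPIECE_UNDERLINE, ""))
            if piece[0] != SPIECE_UNDERLINE and cur_pieces[0][0] == SPIECE_UNDERLINE:
                # Drop the spurious leading "▁" introduced by re-tokenizing.
                if len(cur_pieces[0]) == 1:
                    cur_pieces = cur_pieces[1:]
                else:
                    cur_pieces[0] = cur_pieces[0][1:]
            cur_pieces.append(piece[-1])
            new_pieces.extend(cur_pieces)
        else:
            new_pieces.append(piece)
    return new_pieces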

My current plan is to replace all other str methods with tf text and remove the outer loop.

@mattdangerw (Member) commented:

Thanks for the explainer! Is it really just some weird workaround for digits followed by a comma?

Ideally we could figure out a way to either preprocess or postprocess the sentencepiece tokenize result so that we can still use the tf-text sentencepiece "black box" unaltered. Not sure if that is possible though...

tensorflow-text and the tf.strings module will be the main tools here.

@susnato (Contributor, Author) commented:

Yes, to my understanding it's a workaround.

@susnato (Contributor, Author) commented Aug 12, 2023

Hi @mattdangerw, I have made some changes in this commit:

  • all string-related operations are now replaced with their tensorflow-text / tf.strings equivalents.
  • added a new postprocess method that takes care of all the complex logic. Since we must keep those inner loops, as I mentioned here, I used tf.py_function to wrap the Python function. Now the tokenizer works with ds.map(tokenizer)!
  • also added a test, test_tokenize_ds, which checks whether the tokenizer works with a tf.data.Dataset.

I believe this way we keep the workaround for digits followed by a comma while still being able to use the tokenizer with tf.data.Dataset; a sketch of the wrapping pattern is below.
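
(A minimal sketch of that wrapping pattern, with a simplified postprocess body; the real logic also re-runs SentencePiece on the split piece, as in the earlier snippet:)

import tensorflow as tf

def _postprocess(pieces):
    # Inside tf.py_function, pieces arrives as an eager string tensor.
    new_pieces = []
    for piece in pieces.numpy():
        piece = piece.decode("utf-8")
        if len(piece) > 1 and piece[-1] == "," and piece[-2].isdigit():
            new_pieces.extend([piece[:-1], ","])
        else:
            new_pieces.append(piece)
    return tf.constant(new_pieces)

def postprocess(pieces):
    # tf.py_function lets the Python loop run inside a tf.data pipeline,
    # at the cost of dropping out of the TF graph for every element.
    return tf.py_function(_postprocess, inp=[pieces], Tout=tf.string)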

Please review it and let me know whether the code style complies with the library or not.

@mattdangerw (Member) commented:

Will need to step through this more carefully soon, but at first blush this looks like it would probably be quite inefficient. py_function is quite possibly going to be slower than doing this all in Python, because of the frequent type conversions and dipping in and out of the Python interpreter. There is also a lot of going from string -> int -> string.

We should probably think about this a little more.

What happens if you just preprocess any 123, -> 123 ,?

@susnato (Contributor, Author) commented Aug 16, 2023

What happens if you just preprocess any 123, -> 123 ,?

If we don't use the logic and only use tokenizer._sentence_piece.tokenize, we get ['▁12', '3,'].
The real (upstream) output is ['▁12', '3', ','].
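
(For illustration, the preprocessing idea from the question above could be sketched as a single regex replace; whether SentencePiece then yields exactly the upstream ['▁12', '3', ','], rather than e.g. a '▁,' piece, would still need to be checked against the XLNet vocabulary:)

import tensorflow as tf

# Insert a space between a digit and a trailing comma: "ABC 0.123," -> "ABC 0.123 ,".
text = tf.constant(["ABC 0.123,"])
text = tf.strings.regex_replace(text, r"(\d),", r"\1 ,")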
