Added an Example for BLEU. #806
base: master
Changes from 1 commit
@@ -4,7 +4,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -63,8 +63,8 @@ class Bleu(keras.metrics.Metric):
n-grams so as to not give a high score to a (reference, prediction) pair
with redundant, repeated tokens. Secondly, BLEU score tends to reward
shorter predictions more, which is why a brevity penalty is applied to
penalise short predictions. For more details, see the following article:
https://cloud.google.com/translate/automl/docs/evaluate#bleu.
penalise short predictions. For more details, see
[this article](https://cloud.google.com/translate/automl/docs/evaluate#bleu).

Note on input shapes:
For unbatched inputs, `y_pred` should be a tensor of shape `()`, and
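Not part of the diff, but as background for the clipping and brevity-penalty behaviour described in the hunk above: for a single (reference, prediction) pair the computation can be sketched in a few lines of plain Python. This is a simplified sketch; any smoothing or corpus-level aggregation the real metric applies is omitted.

```python
import math
from collections import Counter

def toy_bleu(reference, prediction, max_order=4):
    """Illustrative single-sentence BLEU: clipped n-gram precisions * brevity penalty."""
    ref, pred = reference.split(), prediction.split()
    precisions = []
    for n in range(1, max_order + 1):
        ref_ngrams = Counter(tuple(ref[i:i + n]) for i in range(len(ref) - n + 1))
        pred_ngrams = Counter(tuple(pred[i:i + n]) for i in range(len(pred) - n + 1))
        # "Clipping": each predicted n-gram only counts up to the number of times
        # it appears in the reference, so repeated tokens are not rewarded.
        overlap = sum(min(count, ref_ngrams[ng]) for ng, count in pred_ngrams.items())
        precisions.append(overlap / max(sum(pred_ngrams.values()), 1))
    if min(precisions) == 0:
        return 0.0
    geo_mean = math.exp(sum(math.log(p) for p in precisions) / max_order)
    # Brevity penalty: predictions shorter than the reference are penalised.
    brevity_penalty = 1.0 if len(pred) > len(ref) else math.exp(1 - len(ref) / len(pred))
    return geo_mean * brevity_penalty

print(toy_bleu(
    "the quick brown fox jumps over the lazy dog",
    "the quick brown fox jumps over the box",
))
```

On the sentence pair used in the examples further down, this gives roughly 0.742, which lines up with the values shown there.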
@@ -80,8 +80,8 @@ class Bleu(keras.metrics.Metric):
(of any shape), and tokenizes the strings in the tensor. If the
tokenizer is not specified, the default tokenizer is used. The
default tokenizer replicates the behaviour of SacreBLEU's
`"tokenizer_13a"` tokenizer
(https://github.com/mjpost/sacrebleu/blob/v2.1.0/sacrebleu/tokenizers/tokenizer_13a.py).
`"tokenizer_13a"` tokenizer, see
[SacreBLEU's `"tokenizer_13a"` tokenizer](https://github.com/mjpost/sacrebleu/blob/v2.1.0/sacrebleu/tokenizers/tokenizer_13a.py).
max_order: int. The maximum n-gram order to use. For example, if
`max_order` is set to 3, unigrams, bigrams, and trigrams will be
considered. Defaults to 4.
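Not part of the diff, but since the default tokenizer is documented as replicating SacreBLEU's `tokenizer_13a`, one way to sanity-check scores locally (assuming the `sacrebleu` package, v2.x, is installed) is to compare against `sacrebleu.corpus_bleu`, which also defaults to the `13a` tokenizer and reports BLEU on a 0-100 scale:

```python
import sacrebleu

# Hypothetical cross-check against SacreBLEU (not part of this PR).
hypotheses = ["the quick brown fox jumps over the box"]
references = [["the quick brown fox jumps over the lazy dog"]]  # one reference stream

result = sacrebleu.corpus_bleu(hypotheses, references)
print(result.score / 100.0)  # roughly comparable; smoothing/length details may differ
```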
@@ -92,7 +92,86 @@ class Bleu(keras.metrics.Metric):
dtype: string or tf.dtypes.Dtype. Precision of metric computation. If
not specified, it defaults to tf.float32.
name: string. Name of the metric instance.
**kwargs: Other keyword arguments.

Examples:
1. Various Input Types.

1.1. Python string.
>>> bleu = keras_nlp.metrics.Bleu(max_order=4)
>>> ref_sentence = "the quick brown fox jumps over the lazy dog"
>>> pred_sentence = "the quick brown fox jumps over the box"
>>> score = bleu([ref_sentence], [pred_sentence])
Review comment: Since we mention that this example is with Python string inputs, this should ideally be
<tf.Tensor(0.7420885, shape=(), dtype=float32)>
1.2. Rank 1 inputs.

Review comment: We should change these to
or something like that?

Reply: I checked the Rouge L metric documentation, where this same convention is used. Is it okay to make these changes?
a. Python list.
>>> bleu = keras_nlp.metrics.Bleu(max_order=4)
>>> ref_sentence = [
    "the quick brown fox jumps over the lazy dog",
    "the quick brown fox jumps over the lazy frog"
]
>>> pred_sentence = ["the quick brown fox jumps over the box"]
>>> score = bleu(ref_sentence, pred_sentence)
Review comment: Instead of assigning the score to a variable, let's just have
Please do the same in other places as well.

Review comment: Although the above example works, the ideal input formula is:
I think we should probably stick to this formula in our examples. Otherwise, it becomes confusing for the reader. Could you please make this change everywhere?

Reply: Correct me if I am wrong, for this, do I need to change `bleu(ref_sentence, pred_sentence)` to `bleu([ref_sentence], pred_sentence)`?
<tf.Tensor(0.7420885, shape=(), dtype=float32)>
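The exact snippet the reviewer had in mind is not shown above, but the question being settled is how references nest relative to predictions. In the usual corpus-BLEU convention (stated here as background, not as a claim about this class's final API), there is one prediction per sample and a parallel list of references per sample; the extra sentences below are made up for illustration:

```python
# Usual corpus-BLEU layout: predictions[i] is scored against references[i],
# which may contain any number of reference strings. Sentences are hypothetical.
predictions = [
    "the quick brown fox jumps over the box",  # sample 0
    "he ate the apple",                        # sample 1 (hypothetical)
]
references = [
    [
        "the quick brown fox jumps over the lazy dog",
        "the quick brown fox jumps over the lazy frog",
    ],
    [
        "he ate the apple",
    ],
]
assert len(predictions) == len(references)
```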
b. Tensor.
>>> bleu = keras_nlp.metrics.Bleu(max_order=4)
>>> ref_sentence = tf.constant([
    "the quick brown fox jumps over the lazy dog",
    "the quick brown fox jumps over the lazy frog"
])
>>> pred_sentence = tf.constant(["the quick brown fox jumps over the box"])
>>> score = bleu(ref_sentence, pred_sentence)
<tf.Tensor(0.7420885, shape=(), dtype=float32)>
c. RaggedTensor.
>>> bleu = keras_nlp.metrics.Bleu(max_order=4)
>>> ref_sentence = tf.ragged.constant([
    [
        "the quick brown fox jumps over the lazy dog",
        "the quick brown fox jumps over the lazy frog"
    ]
])
>>> pred_sentence = tf.ragged.constant([
    ["the quick brown fox jumps over the box"]
])
>>> score = bleu(ref_sentence, pred_sentence)
<tf.Tensor(0.7420885, shape=(), dtype=float32)>
Review comment: Since we mention that we are using ragged tensors here, it would be great if we could actually have a tensor with variable dimension along an axis. Also, I think your example will throw an error; so, something like:
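The reviewer's suggested replacement is not shown above. Purely to illustrate "variable dimension along an axis" (a hypothetical sketch, not the reviewer's exact snippet), a ragged reference tensor with a different number of references per sample could look like this:

```python
import tensorflow as tf

# Hypothetical ragged inputs: sample 0 has two references, sample 1 has one,
# so the second axis of `ref_sentence` genuinely varies in length.
ref_sentence = tf.ragged.constant([
    [
        "the quick brown fox jumps over the lazy dog",
        "the quick brown fox jumps over the lazy frog",
    ],
    [
        "he ate a red apple",
    ],
])
pred_sentence = tf.constant([
    "the quick brown fox jumps over the box",
    "he ate a green apple",
])
```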
1.3. Rank 2 inputs.
a. Python list.
>>> bleu = keras_nlp.metrics.Bleu(max_order=4)
>>> ref_sentence = [
    ["the quick brown fox jumps over the lazy dog", "the quick brown fox jumps over the lazy frog"],
    ["the quick brown fox jumps over the lazy dog", "the quick brown fox jumps over the lazy frog"]
]
>>> pred_sentence = [
    ["the quick brown fox jumps over the box"],
    ["the quick brown fox jumps over the box"]
]
>>> score = bleu(ref_sentence, pred_sentence)
<tf.Tensor(0.7420885, shape=(), dtype=float32)>
2. Passing a custom tokenizer.
>>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "dog", "."]
>>> tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(vocabulary=vocab, lowercase=True)
>>> bleu = keras_nlp.metrics.Bleu(max_order=4, tokenizer=tokenizer)
>>> ref_sentence = "the quick brown fox"
>>> pred_sentence = "the quick brown dog"
>>> score = bleu([ref_sentence], [pred_sentence])
<tf.Tensor(0.75983566, shape=(), dtype=float32)>
3. Pass the metric to `model.compile()`.
>>> inputs = keras.Input(shape=(), dtype='string')
>>> outputs = tf.strings.lower(inputs)
>>> model = keras.Model(inputs, outputs)
>>> model.compile(metrics=[keras_nlp.metrics.Bleu()])
>>> ref_sentence = tf.constant(["the quick brown fox jumps over the lazy dog"])
>>> pred_sentence = tf.constant(["the quick brown fox jumps over the box"])
>>> metric_dict = model.evaluate(ref_sentence, pred_sentence, return_dict=True)
>>> metric_dict['bleu']
0.7259795069694519
References:
- [Papineni et al., 2002](https://aclanthology.org/P02-1040/)
Review comment: Please remove extra spaces.