Update MaskedLMHead to support dtype=bfloat16/float16/float64 #1197
base: master
Changes to MaskedLMHead.build(), forwarding the layer's dtype policy to the sub-layers it creates:

@@ -153,9 +153,11 @@ def build(self, inputs_shape, masked_positions_shape=None):
            activation=self.intermediate_activation,
            kernel_initializer=self.kernel_initializer,
            bias_initializer=self.bias_initializer,
            dtype=self._dtype_policy,
        )
        self._layer_norm = keras.layers.LayerNormalization(
            epsilon=self.layer_norm_epsilon,
            dtype=self._dtype_policy,
        )
        if masked_positions_shape:
            gather_length = masked_positions_shape[1]

Review comment on the added dtype=self._dtype_policy: It looks like …
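For context, a minimal sketch of the behavior this change enables, with names and shapes borrowed from the tests below. Without the dtype plumbing above, the dense and layer-norm weights would stay at the global float32 default:

```python
from keras_nlp.backend import ops
from keras_nlp.layers.modeling import masked_lm_head

# Build a head under a non-default dtype; the fix forwards this policy to the
# intermediate Dense and LayerNormalization sub-layers created in build().
head = masked_lm_head.MaskedLMHead(vocabulary_size=100, dtype="bfloat16")

token_data = ops.random.uniform(shape=(4, 10, 16))
position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
head(token_data, masked_positions=position_data)  # triggers build()

# Every weight, including those of the sub-layers, should now be bfloat16.
for w in head.weights:
    print(w.name, w.dtype)
```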
Changes to the MaskedLMHead tests:

@@ -15,6 +15,9 @@
import os

import tensorflow as tf
from absl.testing import parameterized

from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.layers.modeling import masked_lm_head
The first test hunk adds a parameterized forward-pass test:

@@ -36,6 +39,30 @@ def test_valid_call(self):
        position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
        model((token_data, position_data))

    @parameterized.named_parameters(
        ("bfloat16", tf.bfloat16),
        ("float16", tf.float16),
        ("float32", tf.float32),
        ("float64", tf.float64),
    )
    def test_valid_call_with_dtype(self, dtype):
        head = masked_lm_head.MaskedLMHead(
            vocabulary_size=100,
            activation="softmax",
            dtype=dtype,
        )
        encoded_tokens = keras.Input(shape=(10, 16))
        positions = keras.Input(shape=(5,), dtype="int32")
        outputs = head(encoded_tokens, masked_positions=positions)
        model = keras.Model((encoded_tokens, positions), outputs)

        token_data = ops.random.uniform(shape=(4, 10, 16))
        position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
        model((token_data, position_data))

        for w in head.weights:
            self.assertEqual(w.dtype, dtype, "Wrong type: " + w.name)

    def test_valid_call_with_embedding_weights(self):
        embedding = keras.layers.Embedding(100, 16)
        embedding.build((4, 10))

Review comment on ("bfloat16", tf.bfloat16): Because we now run our testing suite with jax/torch/tf with keras-core, we are generally just referring to these by string name, e.g. "bfloat16". Does anything break if we switch to that?

Review comment on outputs = head(encoded_tokens, masked_positions=positions): This might need a rebase over master. This should be …
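A sketch of what the reviewer's string-name suggestion could look like, assuming MaskedLMHead's dtype argument accepts string names (as Keras layers generally do); the test class name here is hypothetical and the import of tensorflow becomes unnecessary:

```python
from absl.testing import parameterized

from keras_nlp.backend import keras
from keras_nlp.backend import ops
from keras_nlp.layers.modeling import masked_lm_head


class MaskedLMHeadDtypeTest(parameterized.TestCase):
    @parameterized.named_parameters(
        ("bfloat16", "bfloat16"),
        ("float16", "float16"),
        ("float32", "float32"),
        ("float64", "float64"),
    )
    def test_valid_call_with_dtype(self, dtype):
        head = masked_lm_head.MaskedLMHead(
            vocabulary_size=100,
            activation="softmax",
            dtype=dtype,
        )
        encoded_tokens = keras.Input(shape=(10, 16))
        positions = keras.Input(shape=(5,), dtype="int32")
        outputs = head(encoded_tokens, masked_positions=positions)
        model = keras.Model((encoded_tokens, positions), outputs)

        token_data = ops.random.uniform(shape=(4, 10, 16))
        position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
        model((token_data, position_data))

        for w in head.weights:
            # Under the TF backend w.dtype is a tf.DType, which compares equal
            # to its string name; with keras-core backends it is already the
            # plain string, so the assertion works either way.
            self.assertEqual(w.dtype, dtype, "Wrong type: " + w.name)
```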
The second test hunk adds a parameterized train-step test:

@@ -119,6 +146,32 @@ def test_one_train_step(self):
        loss = model.train_on_batch(x=(token_data, position_data), y=label_data)
        self.assertGreater(loss, 0)

    @parameterized.named_parameters(
        ("bfloat16", tf.bfloat16),
        ("float16", tf.float16),
        ("float32", tf.float32),
        ("float64", tf.float64),
    )
    def test_one_train_step_with_dtype(self, dtype):
        head = masked_lm_head.MaskedLMHead(
            vocabulary_size=100,
            dtype=dtype,
        )
        encoded_tokens = keras.Input(shape=(10, 16))
        positions = keras.Input(shape=(5,), dtype="int32")
        outputs = head(encoded_tokens, masked_positions=positions)
        model = keras.Model((encoded_tokens, positions), outputs)

        token_data = ops.random.uniform(shape=(4, 10, 16))
        position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
        label_data = ops.random.randint(minval=0, maxval=2, shape=(4, 5, 1))

        loss = keras.losses.SparseCategoricalCrossentropy(from_logits=False)
        optimizer = keras.optimizers.Adam()
        model.compile(loss=loss, optimizer=optimizer)
        loss = model.train_on_batch(x=(token_data, position_data), y=label_data)
        self.assertGreater(loss, 0)

    def test_saved_model(self):
        head = masked_lm_head.MaskedLMHead(
            vocabulary_size=100,

Review comment on @parameterized.named_parameters: I would kill this test. Compiling a real loss function can make for slower tests, and with the parameterized testing this could slow down our suite.
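If the train-step test were dropped as suggested, a cheaper check could keep dtype coverage with a single uncompiled forward pass. This is a sketch, not part of the PR, and the helper name is hypothetical:

```python
from keras_nlp.backend import ops
from keras_nlp.layers.modeling import masked_lm_head


def forward_pass_output_dtype(dtype):
    """Run one eager forward pass and report the output dtype."""
    head = masked_lm_head.MaskedLMHead(vocabulary_size=100, dtype=dtype)
    token_data = ops.random.uniform(shape=(4, 10, 16))
    position_data = ops.random.randint(minval=0, maxval=10, shape=(4, 5))
    outputs = head(token_data, masked_positions=position_data)
    # With a plain (non-mixed) dtype policy the compute dtype equals `dtype`,
    # so the output should carry it; no compile()/train_on_batch() needed.
    return outputs.dtype
```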
General review comment: It seems like we should really have this for all of our "composite" layers in KerasNLP, right?
Are you interested in following up for other layers? (same PR or split PRs fine!)
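The pattern that comment refers to, roughly, for any composite layer that creates sub-layers in build(). This is a sketch with a hypothetical layer; without the explicit dtype argument the sub-layers fall back to the global default policy:

```python
from keras_nlp.backend import keras


class SomeCompositeLayer(keras.layers.Layer):
    """Hypothetical composite layer illustrating the dtype plumbing."""

    def build(self, inputs_shape):
        # Forward the parent layer's dtype policy to every sub-layer created
        # here, mirroring the MaskedLMHead change in this PR.
        self._intermediate_dense = keras.layers.Dense(
            8,
            activation="relu",
            dtype=self._dtype_policy,
        )
        self._layer_norm = keras.layers.LayerNormalization(
            dtype=self._dtype_policy,
        )

    def call(self, inputs):
        return self._layer_norm(self._intermediate_dense(inputs))
```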