Complete run_small_preset test for electra
pranavvp16 committed Mar 18, 2024
1 parent 4be8d50 commit dada198
Showing 4 changed files with 18 additions and 15 deletions.
11 changes: 7 additions & 4 deletions keras_nlp/models/electra/electra_backbone_test.py
@@ -59,7 +59,7 @@ def test_saved_model(self):
     def test_smallest_preset(self):
         self.run_preset_test(
             cls=ElectraBackbone,
-            preset="electra-small-generator",
+            preset="electra_small_discriminator_en",
             input_data={
                 "token_ids": ops.array([[101, 1996, 4248, 102]], dtype="int32"),
                 "segment_ids": ops.zeros((1, 4), dtype="int32"),
@@ -70,10 +70,13 @@
"pooled_output": (1, 256),
},
# The forward pass from a preset should be stable!
# TODO: Add sequence and pooled output trimmed to 5 tokens.
expected_partial_output={
"sequence_output": (ops.array()),
"pooled_output": (ops.array()),
"sequence_output": (
ops.array([0.32287, 0.18754, -0.22272, -0.24177, 1.18977])
),
"pooled_output": (
ops.array([-0.02974, 0.23383, 0.08430, -0.19471, 0.14822])
),
},
)

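For context, the updated test above pins down both the output shapes and the first few output values produced by the small discriminator preset. Below is a minimal sketch of the same check outside the test harness. It assumes the `electra_small_discriminator_en` preset is published, that `ElectraBackbone` is exported under `keras_nlp.models`, and that a `padding_mask` input (not shown in the excerpt above) is also expected.

import numpy as np
import keras_nlp

# Load the small ELECTRA discriminator backbone named in the test above.
# This downloads the preset weights on first use.
backbone = keras_nlp.models.ElectraBackbone.from_preset(
    "electra_small_discriminator_en"
)

# Same toy input as the test: [CLS] the quick [SEP] as WordPiece ids.
inputs = {
    "token_ids": np.array([[101, 1996, 4248, 102]], dtype="int32"),
    "segment_ids": np.zeros((1, 4), dtype="int32"),
    "padding_mask": np.ones((1, 4), dtype="int32"),
}
outputs = backbone(inputs)
print(outputs["sequence_output"].shape)  # (1, 4, 256)
print(outputs["pooled_output"].shape)    # (1, 256)
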
4 changes: 2 additions & 2 deletions keras_nlp/models/electra/electra_preprocessor.py
@@ -38,14 +38,14 @@ class ElectraPreprocessor(Preprocessor):
     2. Pack the inputs together using a `keras_nlp.layers.MultiSegmentPacker`.
        with the appropriate `"[CLS]"`, `"[SEP]"` and `"[PAD]"` tokens.
     3. Construct a dictionary of with keys `"token_ids"` and `"padding_mask"`,
-       that can be passed directly to a DistilBERT model.
+       that can be passed directly to a ELECTRA model.
 
     This layer can be used directly with `tf.data.Dataset.map` to preprocess
     string data in the `(x, y, sample_weight)` format used by
     `keras.Model.fit`.
 
     Args:
-        tokenizer: A `keras_nlp.models.DistilBertTokenizer` instance.
+        tokenizer: A `keras_nlp.models.ElectraTokenizer` instance.
         sequence_length: The length of the packed inputs.
         truncate: string. The algorithm to truncate a list of batched segments
             to fit within `sequence_length`. The value can be either
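
The corrected docstring describes the three preprocessing steps (tokenize, pack, build the feature dict). A short usage sketch follows, assuming `ElectraTokenizer` and `ElectraPreprocessor` are exported under `keras_nlp.models` and using a tiny hypothetical vocabulary in place of a real preset.

import keras_nlp

# Hypothetical toy WordPiece vocabulary; a real one would come from a preset.
vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]", "the", "quick", "brown", "fox"]

tokenizer = keras_nlp.models.ElectraTokenizer(vocabulary=vocab)
preprocessor = keras_nlp.models.ElectraPreprocessor(
    tokenizer=tokenizer,
    sequence_length=8,
)

# Returns a dict keyed by "token_ids" and "padding_mask" (and, in practice,
# "segment_ids"), packed with "[CLS]"/"[SEP]" and padded to sequence_length.
features = preprocessor("the quick brown fox")
print(features["token_ids"])

Because the preprocessor is a Keras layer, the same call can be mapped over a `tf.data.Dataset`, as the docstring notes.
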
16 changes: 8 additions & 8 deletions keras_nlp/models/electra/electra_presets.py
@@ -20,10 +20,10 @@
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
"This is base discriminator model with 12 layers."
),
"params": "109482240",
"params": 109482240,
"official_name": "ELECTRA",
"path": "electra",
"model_card": "https://huggingface.co/google/electra-base-discriminator",
"model_card": "https://github.com/google-research/electra",
},
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_base_discriminator_en/1",
},
@@ -33,10 +33,10 @@
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
"This is small discriminator model with 12 layers."
),
"params": "13,548,800",
"params": 13548800,
"official_name": "ELECTRA",
"path": "electra",
"model_card": "https://huggingface.co/google/electra-small-discriminator",
"model_card": "https://github.com/google-research/electra",
},
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_small_discriminator_en/1",
},
@@ -46,10 +46,10 @@
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
"This is small generator model with 12 layers."
),
"params": "13548800",
"params": 13548800,
"official_name": "ELECTRA",
"path": "electra",
"model_card": "https://huggingface.co/google/electra-small-generator",
"model_card": "https://github.com/google-research/electra",
},
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_small_generator_en/1",
},
@@ -59,10 +59,10 @@
"ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators"
"This is base generator model with 12 layers."
),
"params": "33576960",
"params": 33576960,
"official_name": "ELECTRA",
"path": "electra",
"model_card": "https://huggingface.co/google/electra-base-generator",
"model_card": "https://github.com/google-research/electra",
},
"kaggle_handle": "kaggle://pranavprajapati16/electra/keras/electra_base_generator_en/1",
},
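
These preset changes switch `params` from strings to integers and point `model_card` at the official ELECTRA repository. A brief sketch of why the integer type is convenient, assuming the module exposes the usual KerasNLP `backbone_presets` dict (the import path here is an assumption, not confirmed by the diff):

from keras_nlp.models.electra.electra_presets import backbone_presets

# With "params" stored as an int rather than a string, numeric formatting
# and comparisons work without any parsing.
for name, config in backbone_presets.items():
    meta = config["metadata"]
    print(f"{name}: {meta['params']:,} parameters -> {meta['model_card']}")
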
2 changes: 1 addition & 1 deletion keras_nlp/models/electra/electra_tokenizer_test.py
@@ -47,7 +47,7 @@ def test_errors_missing_special_tokens(self):
     def test_smallest_preset(self):
         self.run_preset_test(
             cls=ElectraTokenizer,
-            preset="distil_bert_base_en_uncased",
+            preset="electra_base_discriminator_en",
             input_data=["The quick brown fox."],
             expected_output=[[1996, 4248, 2829, 4419, 1012]],
         )
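
A minimal sketch of the behaviour this tokenizer test pins down, assuming the `electra_base_discriminator_en` preset is available and `ElectraTokenizer` is exported under `keras_nlp.models`:

import keras_nlp

tokenizer = keras_nlp.models.ElectraTokenizer.from_preset(
    "electra_base_discriminator_en"
)
# Per the test above, this should produce [[1996, 4248, 2829, 4419, 1012]].
print(tokenizer(["The quick brown fox."]))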
