From bb423c87a0ce35bb6e8f2f07ca3c48e0fead50b0 Mon Sep 17 00:00:00 2001 From: james77777778 <20734616+james77777778@users.noreply.github.com> Date: Thu, 4 Jul 2024 07:40:10 +0800 Subject: [PATCH] Add quantization support for `Gemma`, `Gemma2` and `PaliGemma` (#1670) * Introduce quantization support to `Gemma` * Revert `SentencePieceTokenizer` * Add tests for `PaliGemma` * Add quantization support for Gemma2 * Address comments --- .../layers/modeling/reversible_embedding.py | 129 ++++++++++++++++-- .../modeling/reversible_embedding_test.py | 68 +++++++++ keras_nlp/src/models/backbone.py | 17 ++- .../src/models/bloom/bloom_backbone_test.py | 3 + .../src/models/gemma/gemma_backbone_test.py | 1 - keras_nlp/src/models/opt/opt_backbone_test.py | 4 + .../pali_gemma/pali_gemma_backbone_test.py | 104 +++++++------- .../src/models/pali_gemma/pali_gemma_vit.py | 1 + keras_nlp/src/tests/test_case.py | 33 ++++- 9 files changed, 288 insertions(+), 72 deletions(-) diff --git a/keras_nlp/src/layers/modeling/reversible_embedding.py b/keras_nlp/src/layers/modeling/reversible_embedding.py index a06dccca59..548a84e842 100644 --- a/keras_nlp/src/layers/modeling/reversible_embedding.py +++ b/keras_nlp/src/layers/modeling/reversible_embedding.py @@ -107,8 +107,7 @@ def __init__( def build(self, inputs_shape=None): super().build(inputs_shape) - - if not self.tie_weights: + if not self.tie_weights and self.quantization_mode != "int8": self.reverse_embeddings = self.add_weight( name="reverse_embeddings", shape=(self.output_dim, self.input_dim), @@ -143,20 +142,28 @@ def save_own_variables(self, store): if not self.built: return super().save_own_variables(store) - # Before Keras 3.2, the reverse weight is saved in the super() call. - # After Keras 3.2, the reverse weight must be saved manually. - if len(store.keys()) < len(self.weights): - # Store the reverse embedding as the last weight. - store[str(len(store.keys()))] = self.reverse_embeddings + target_variables = [] + if not self.tie_weights: + # Store the reverse embedding weights as the last weights. + target_variables.append(self.reverse_embeddings) + if self.quantization_mode == "int8": + target_variables.append(self.reverse_embeddings_scale) + for i, variable in enumerate(target_variables, start=len(store)): + store[str(i)] = variable def load_own_variables(self, store): if not self.built: self.build() super().load_own_variables(store) if not self.tie_weights: - # Last weight in the store is the reverse embedding weights. - key = str(len(store.keys()) - 1) - self.reverse_embeddings.assign(store[key]) + # Last weights in the stores are the reverse embedding weights. 
+ target_variables = [self.reverse_embeddings] + if self.quantization_mode == "int8": + target_variables.append(self.reverse_embeddings_scale) + for i, variable in enumerate( + target_variables, start=len(store) - len(target_variables) + ): + variable.assign(store[str(i)]) def compute_output_spec(self, inputs, reverse=False): output_shape = list(inputs.shape) @@ -165,3 +172,105 @@ def compute_output_spec(self, inputs, reverse=False): else: output_shape += [self.output_dim] return keras.KerasTensor(output_shape, dtype=self.dtype) + + # Quantization-related (int8) methods + + def quantized_call(self, inputs, reverse=False): + # TODO (hongyu): This function could be removed once we add `*args` and + # `**kwargs` for `Embedding.quantized_call` + if self.quantization_mode == "int8": + return self._int8_call(inputs, reverse=reverse) + else: + self._quantization_mode_error(self.quantization_mode) + + def _int8_build( + self, + embeddings_initializer="zeros", + embeddings_scale_initializer="ones", + reverse_embeddings_initializer="zeros", + reverse_embeddings_scale_initializer="ones", + ): + super()._int8_build( + embeddings_initializer, embeddings_scale_initializer + ) + self.inputs_quantizer = keras.quantizers.AbsMaxQuantizer(axis=-1) + if not self.tie_weights: + self.reverse_embeddings = self.add_weight( + name="reverse_embeddings", + shape=(self.output_dim, self.input_dim), + initializer=reverse_embeddings_initializer, + dtype="int8", + trainable=False, + ) + self.reverse_embeddings_scale = self.add_weight( + name="reverse_embeddings_scale", + shape=(self.input_dim,), + initializer=reverse_embeddings_scale_initializer, + trainable=False, + ) + + def _int8_call(self, inputs, reverse=False): + if reverse: + if self.tie_weights: + kernel = ops.transpose(self._embeddings) + scale = ops.transpose(self.embeddings_scale) + else: + kernel = self.reverse_embeddings + scale = self.reverse_embeddings_scale + inputs, inputs_scale = self.inputs_quantizer(inputs) + outputs = ops.matmul(inputs, kernel) + # De-scale outputs + outputs = ops.cast(outputs, self.compute_dtype) + outputs = ops.divide(outputs, ops.multiply(inputs_scale, scale)) + return outputs + + return super()._int8_call(inputs) + + def quantize(self, mode): + import gc + + if type(self) is not ReversibleEmbedding: + raise NotImplementedError( + f"Layer {self.__class__.__name__} does not have a `quantize()` " + "method implemented." 
+ ) + self._check_quantize_args(mode, self.compute_dtype) + + self._tracker.unlock() + if mode == "int8": + embeddings, embeddings_scale = keras.quantizers.abs_max_quantize( + self._embeddings, axis=-1 + ) + embeddings_scale = ops.squeeze(embeddings_scale, axis=-1) + self._untrack_variable(self._embeddings) + del self._embeddings + if not self.tie_weights: + reverse_embeddings, reverse_embeddings_scale = ( + keras.quantizers.abs_max_quantize( + self.reverse_embeddings, axis=0 + ) + ) + reverse_embeddings_scale = ops.squeeze( + reverse_embeddings_scale, axis=0 + ) + self._untrack_variable(self.reverse_embeddings) + del self.reverse_embeddings + else: + reverse_embeddings = None + reverse_embeddings_scale = None + self._int8_build( + lambda shape, dtype: embeddings, + lambda shape, dtype: embeddings_scale, + lambda shape, dtype: reverse_embeddings, + lambda shape, dtype: reverse_embeddings_scale, + ) + else: + raise self._quantization_mode_error(mode) + self._tracker.lock() + + if self.dtype_policy.quantization_mode is None: + policy = keras.dtype_policies.get( + f"{mode}_from_{self.dtype_policy.name}" + ) + self.dtype_policy = policy + gc.collect() diff --git a/keras_nlp/src/layers/modeling/reversible_embedding_test.py b/keras_nlp/src/layers/modeling/reversible_embedding_test.py index 0816f0be6a..ead0cd8ea2 100644 --- a/keras_nlp/src/layers/modeling/reversible_embedding_test.py +++ b/keras_nlp/src/layers/modeling/reversible_embedding_test.py @@ -98,3 +98,71 @@ def test_reverse_dtype(self): output_data = embedding(input_data, reverse=True) self.assertEqual(output_data.shape, (4, 10, 100)) self.assertDTypeEqual(output_data, "float16") + + @parameterized.named_parameters( + ("tie_weights", True), ("untie_weights", False) + ) + def test_quantize_int8(self, tie_weights): + layer_config = dict( + input_dim=100, output_dim=32, tie_weights=tie_weights + ) + layer = ReversibleEmbedding(**layer_config) + layer.build() + x = random.randint(shape=(64, 100), minval=0, maxval=9) + x_reverse = random.uniform(shape=(64, 32)) + y_float = layer(x) + y_reverse_float = layer(x_reverse, reverse=True) + layer.quantize("int8") + + # Verify weights dtype + if not tie_weights: + self.assertEqual( + keras.backend.standardize_dtype(layer.reverse_embeddings.dtype), + "int8", + ) + self.assertEqual( + keras.backend.standardize_dtype( + layer.reverse_embeddings_scale.dtype + ), + layer.variable_dtype, + ) + + # Try eager call and verify output correctness + y_quantized = layer(x) + y_reverse_quantized = layer(x_reverse, reverse=True) + mse = ops.mean(ops.square(y_float - y_quantized)) + mse_reverse = ops.mean( + ops.square(y_reverse_float - y_reverse_quantized) + ) + self.assertLess(mse, 1e-3) # A weak correctness test + self.assertLess(mse_reverse, 1e-3) # A weak correctness test + + # Try saving and reloading the model + model = keras.models.Sequential([layer]) + temp_filepath = os.path.join( + self.get_temp_dir(), "quantized_model.keras" + ) + model.save(temp_filepath) + new_model = keras.models.load_model(temp_filepath) + self.assertAllClose(model.predict(x), new_model.predict(x)) + + @parameterized.named_parameters( + ("tie_weights", True), + ("untie_weights", False), + ) + def test_quantize_dtype_argument(self, tie_weights): + self.run_layer_test( + cls=ReversibleEmbedding, + init_kwargs={ + "input_dim": 100, + "output_dim": 32, + "tie_weights": tie_weights, + "embeddings_initializer": "HeNormal", + "dtype": "int8_from_float32", + }, + input_data=random.randint(minval=0, maxval=100, shape=(4, 10)), + 
expected_output_shape=(4, 10, 32), + expected_num_trainable_weights=0, + expected_num_non_trainable_weights=2 if tie_weights else 4, + expected_num_non_trainable_variables=2 if tie_weights else 4, + ) diff --git a/keras_nlp/src/models/backbone.py b/keras_nlp/src/models/backbone.py index 468a24f0ca..12e33e98ed 100644 --- a/keras_nlp/src/models/backbone.py +++ b/keras_nlp/src/models/backbone.py @@ -75,11 +75,7 @@ def __init__(self, *args, dtype=None, **kwargs): id(layer) for layer in self._flatten_layers() ) self._initialized = True - if dtype is not None: - if isinstance(dtype, keras.DTypePolicy): - self.dtype_policy = dtype - else: - self.dtype_policy = keras.DTypePolicy(dtype) + self.dtype_policy = keras.dtype_policies.get(dtype) def __setattr__(self, name, value): # Work around setattr issues for Keras 2 and Keras 3 torch backend. @@ -107,11 +103,20 @@ def token_embedding(self, value): def get_config(self): # Don't chain to super here. `get_config()` for functional models is # a nested layer config and cannot be passed to Backbone constructors. - return { + config = { "name": self.name, "trainable": self.trainable, } + # Add quantization support by utilizing `DTypePolicyMap` + policy_map = keras.dtype_policies.DTypePolicyMap() + for layer in self._flatten_layers(): + if layer.quantization_mode is not None: + policy_map[layer.path] = layer.dtype_policy + if len(policy_map) > 0: + config.update({"dtype": policy_map}) + return config + @classmethod def from_config(cls, config): # The default `from_config()` for functional models will return a diff --git a/keras_nlp/src/models/bloom/bloom_backbone_test.py b/keras_nlp/src/models/bloom/bloom_backbone_test.py index 746d1f6f63..6055226cbd 100644 --- a/keras_nlp/src/models/bloom/bloom_backbone_test.py +++ b/keras_nlp/src/models/bloom/bloom_backbone_test.py @@ -39,6 +39,9 @@ def test_backbone_basics(self): init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output_shape=(2, 5, 8), + # TODO: Set to `True`. Error msg: Layer LayerNormalization does not + # have a `quantized_call()` method implemented. + run_quantization_check=False, ) @pytest.mark.large diff --git a/keras_nlp/src/models/gemma/gemma_backbone_test.py b/keras_nlp/src/models/gemma/gemma_backbone_test.py index b6bfc7c71c..74e44abc84 100644 --- a/keras_nlp/src/models/gemma/gemma_backbone_test.py +++ b/keras_nlp/src/models/gemma/gemma_backbone_test.py @@ -171,7 +171,6 @@ def test_distribution_with_lora(self): self.assertEqual(tuple(w.value.sharding.spec), (None, None)) -@pytest.mark.keras_3_only class Gemma2BackboneTest(TestCase): def setUp(self): self.init_kwargs = { diff --git a/keras_nlp/src/models/opt/opt_backbone_test.py b/keras_nlp/src/models/opt/opt_backbone_test.py index 72f3c4b49a..1e31d46582 100644 --- a/keras_nlp/src/models/opt/opt_backbone_test.py +++ b/keras_nlp/src/models/opt/opt_backbone_test.py @@ -40,6 +40,10 @@ def test_backbone_basics(self): init_kwargs=self.init_kwargs, input_data=self.input_data, expected_output_shape=(2, 5, 2), + # TODO: Set to `True`. Error msg: Layer 'token_embedding' expected 1 + # variables, but received 0 variables during loading. 
Expected: + # ['embeddings'] + run_quantization_check=False, ) @pytest.mark.large diff --git a/keras_nlp/src/models/pali_gemma/pali_gemma_backbone_test.py b/keras_nlp/src/models/pali_gemma/pali_gemma_backbone_test.py index 596b0f2d9d..e39333a7ec 100644 --- a/keras_nlp/src/models/pali_gemma/pali_gemma_backbone_test.py +++ b/keras_nlp/src/models/pali_gemma/pali_gemma_backbone_test.py @@ -33,12 +33,7 @@ def setUp(self): self.vocabulary_size = 256 self.text_sequence_length = 64 self.image_size = 16 - self.dummy_text = [ - "the quick brown fox" for _ in range(self.batch_size) - ] - self.dummy_images = np.random.uniform( - size=(self.batch_size, self.image_size, self.image_size, 3) - ) + self.image_sequence_length = int((self.image_size / 4) ** 2) proto = "gemma_test_vocab.spm" tokenizer = PaliGemmaTokenizer( @@ -48,68 +43,69 @@ def setUp(self): tokenizer, self.text_sequence_length, False, False ) - self.backbone = PaliGemmaBackbone( - vocabulary_size=self.vocabulary_size, - image_size=self.image_size, - num_layers=2, - num_query_heads=2, - num_key_value_heads=1, - hidden_dim=8, - intermediate_dim=16, - head_dim=4, - vit_patch_size=4, - vit_num_layers=2, - vit_num_heads=2, - vit_hidden_dim=8, - vit_intermediate_dim=16, - ) - self.dummy_imgs = np.random.rand( + self.init_kwargs = { + "vocabulary_size": self.vocabulary_size, + "image_size": self.image_size, + "num_layers": 2, + "num_query_heads": 2, + "num_key_value_heads": 1, + "hidden_dim": 8, + "intermediate_dim": 16, + "head_dim": 4, + "vit_patch_size": 4, + "vit_num_layers": 2, + "vit_num_heads": 2, + "vit_hidden_dim": 8, + "vit_intermediate_dim": 16, + } + + dummy_images = np.random.rand( self.batch_size, self.image_size, self.image_size, 3 ) - self.dummy_text_token_ids = np.random.rand( + dummy_text_token_ids = np.random.rand( self.batch_size, self.text_sequence_length ) - self.dummy_text = [ - "answer en the quick brown fox" for i in range(self.batch_size) - ] + dummy_text = ["answer en the quick brown fox"] * self.batch_size + self.input_data = { + "token_ids": dummy_text_token_ids, + "images": dummy_images, + "padding_mask": np.ones( + (self.batch_size, self.text_sequence_length), + dtype="int32", + ), + "response_mask": np.zeros( + (self.batch_size, self.text_sequence_length), + dtype="int32", + ), + } + self.raw_input_data = { + "images": dummy_images, + "prompts": dummy_text, + "responses": dummy_text, + } - def test_pali_gemma_backbone(self): - output = self.backbone( - { - "token_ids": self.dummy_text_token_ids, - "images": self.dummy_imgs, - "padding_mask": np.ones( - (self.batch_size, self.text_sequence_length), - dtype="int32", - ), - "response_mask": np.zeros( - (self.batch_size, self.text_sequence_length), - dtype="int32", - ), - } - ) - self.assertEqual( - ( + def test_backbone_basics(self): + self.run_backbone_test( + cls=PaliGemmaBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=( self.batch_size, - self.text_sequence_length + self.backbone.image_sequence_length, + self.text_sequence_length + self.image_sequence_length, 8, ), - output.shape, + variable_length_data=[self.input_data], + run_mixed_precision_check=False, # TODO: Set to `True` ) def test_pali_gemma_backbone_with_preprocessing(self): - x, _, _ = self.preprocessor( - { - "images": self.dummy_images, - "prompts": self.dummy_text, - "responses": self.dummy_text, - } - ) - output = self.backbone(x) + model = PaliGemmaBackbone(**self.init_kwargs) + x, _, _ = self.preprocessor(self.raw_input_data) + output = model(x) 
self.assertEqual( ( self.batch_size, - self.text_sequence_length + self.backbone.image_sequence_length, + self.text_sequence_length + self.image_sequence_length, 8, ), output.shape, diff --git a/keras_nlp/src/models/pali_gemma/pali_gemma_vit.py b/keras_nlp/src/models/pali_gemma/pali_gemma_vit.py index 50a1a5eb6c..90b32bdb3d 100644 --- a/keras_nlp/src/models/pali_gemma/pali_gemma_vit.py +++ b/keras_nlp/src/models/pali_gemma/pali_gemma_vit.py @@ -536,6 +536,7 @@ def __init__( classifier_activation ) self.image_sequence_length = int((image_size / patch_size) ** 2) + self.dtype_policy = keras.dtype_policies.get(dtype) def get_config(self): config = super().get_config() diff --git a/keras_nlp/src/tests/test_case.py b/keras_nlp/src/tests/test_case.py index 1031cbb551..8a66345793 100644 --- a/keras_nlp/src/tests/test_case.py +++ b/keras_nlp/src/tests/test_case.py @@ -336,6 +336,30 @@ def run_precision_test(self, cls, init_kwargs, input_data): self.assertEqual(policy.compute_dtype, sublayer.compute_dtype) self.assertEqual(policy.variable_dtype, sublayer.variable_dtype) + def run_quantization_test(self, cls, init_kwargs, input_data): + policy = keras.DTypePolicy("float32") + for mode in ["int8", "float8"]: + layer = cls(**{**init_kwargs, "dtype": policy}) + layer.quantize(mode) + # Try eager call + if isinstance(layer, keras.Model): + _ = layer(input_data) + elif isinstance(input_data, dict): + _ = layer(**input_data) + else: + _ = layer(input_data) + # Verify sublayer's dtype policy + for sublayer in layer._flatten_layers(): + if type(sublayer) is keras.layers.Dense: + self.assertEqual( + f"{mode}_from_float32", sublayer.dtype_policy.name + ) + # Try saving and reloading the model + temp_filepath = os.path.join(self.get_temp_dir(), "layer.keras") + layer.save(temp_filepath) + reloaded_layer = keras.models.load_model(temp_filepath) + self.assertAllClose(layer(input_data), reloaded_layer(input_data)) + def run_model_saving_test( self, cls, @@ -364,6 +388,7 @@ def run_backbone_test( expected_output_shape, variable_length_data=None, run_mixed_precision_check=True, + run_quantization_check=True, ): """Run basic tests for a backbone, including compilation.""" backbone = cls(**init_kwargs) @@ -405,7 +430,13 @@ def run_backbone_test( name = re.sub("([a-z])([A-Z])", r"\1_\2", name).lower() self.assertRegexpMatches(backbone.name, name) - self.run_precision_test(cls, init_kwargs, input_data) + # Check mixed precision. + if run_mixed_precision_check: + self.run_precision_test(cls, init_kwargs, input_data) + + # Check quantization. + if run_quantization_check: + self.run_quantization_test(cls, init_kwargs, input_data) def run_task_test( self,
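
Usage sketch (not part of the patch): the snippet below illustrates the post-training int8 flow this change enables, closely mirroring the new `test_quantize_int8` test above — build a `ReversibleEmbedding`, quantize it with `quantize("int8")`, and round-trip it through Keras saving. The import path and the temporary-file handling are assumptions for illustration only, and the quantized outputs are only approximately equal to the float outputs.

# Illustrative sketch only -- not part of the patch above. It mirrors the
# new `test_quantize_int8` test: quantize a ReversibleEmbedding to int8 and
# round-trip it through Keras saving/loading.
import os
import tempfile

import keras
import numpy as np

from keras_nlp.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)

layer = ReversibleEmbedding(input_dim=100, output_dim=32, tie_weights=False)
layer.build()

token_ids = np.random.randint(0, 100, size=(4, 10))
hidden_states = np.random.uniform(size=(4, 10, 32)).astype("float32")

# Reference float outputs before quantization.
y_embed_float = layer(token_ids)
y_logits_float = layer(hidden_states, reverse=True)

# Post-training int8 quantization added by this patch: the embeddings (and
# the untied reverse embeddings) are replaced by int8 weights plus scales.
layer.quantize("int8")
y_embed_int8 = layer(token_ids)  # forward lookup
y_logits_int8 = layer(hidden_states, reverse=True)  # reverse projection

# The quantized layer serializes through the standard Keras saving APIs.
model = keras.models.Sequential([layer])
path = os.path.join(tempfile.mkdtemp(), "quantized_embedding.keras")
model.save(path)
restored = keras.models.load_model(path)
np.testing.assert_allclose(
    model.predict(token_ids), restored.predict(token_ids)
)

The same entry point is what `run_quantization_test` in `test_case.py` exercises for whole backbones: it calls `quantize(mode)` on the model built from `init_kwargs`, runs an eager call, and checks that saving and reloading reproduce the outputs.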