Add quantization support for Gemma, Gemma2 and PaliGemma (#1670)
* Introduce quantization support to `Gemma`

* Revert `SentencePieceTokenizer`

* Add tests for `PaliGemma`

* Add quantization support for Gemma2

* Address comments
james77777778 authored Jul 3, 2024
1 parent 4884660 commit bb423c8
Showing 9 changed files with 288 additions and 72 deletions.
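
The changes below wire int8 post-training quantization through `ReversibleEmbedding` and `Backbone` serialization. A minimal sketch of the resulting user-facing workflow, assuming a recent Keras 3 and a KerasNLP build containing this commit (the preset id and file name are illustrative, not taken from the diff):

```python
import keras
import keras_nlp

# Load a Gemma backbone (preset id is an assumption for illustration) and
# convert its float kernels in place to int8 weights plus scale factors.
backbone = keras_nlp.models.GemmaBackbone.from_preset("gemma_2b_en")
backbone.quantize("int8")

# With the `DTypePolicyMap` handling added to `backbone.py`, the quantized
# per-layer policies are recorded in `get_config()`, so the model should
# save and reload like any other Keras model.
backbone.save("gemma_int8.keras")
reloaded = keras.models.load_model("gemma_int8.keras")
```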
129 changes: 119 additions & 10 deletions keras_nlp/src/layers/modeling/reversible_embedding.py
@@ -107,8 +107,7 @@ def __init__(

def build(self, inputs_shape=None):
super().build(inputs_shape)

if not self.tie_weights:
if not self.tie_weights and self.quantization_mode != "int8":
self.reverse_embeddings = self.add_weight(
name="reverse_embeddings",
shape=(self.output_dim, self.input_dim),
@@ -143,20 +142,28 @@ def save_own_variables(self, store):
if not self.built:
return
super().save_own_variables(store)
# Before Keras 3.2, the reverse weight is saved in the super() call.
# After Keras 3.2, the reverse weight must be saved manually.
if len(store.keys()) < len(self.weights):
# Store the reverse embedding as the last weight.
store[str(len(store.keys()))] = self.reverse_embeddings
target_variables = []
if not self.tie_weights:
# Store the reverse embedding weights as the last weights.
target_variables.append(self.reverse_embeddings)
if self.quantization_mode == "int8":
target_variables.append(self.reverse_embeddings_scale)
for i, variable in enumerate(target_variables, start=len(store)):
store[str(i)] = variable

def load_own_variables(self, store):
if not self.built:
self.build()
super().load_own_variables(store)
if not self.tie_weights:
# Last weight in the store is the reverse embedding weights.
key = str(len(store.keys()) - 1)
self.reverse_embeddings.assign(store[key])
# Last weights in the store are the reverse embedding weights.
target_variables = [self.reverse_embeddings]
if self.quantization_mode == "int8":
target_variables.append(self.reverse_embeddings_scale)
for i, variable in enumerate(
target_variables, start=len(store) - len(target_variables)
):
variable.assign(store[str(i)])

def compute_output_spec(self, inputs, reverse=False):
output_shape = list(inputs.shape)
@@ -165,3 +172,105 @@ def compute_output_spec(self, inputs, reverse=False):
else:
output_shape += [self.output_dim]
return keras.KerasTensor(output_shape, dtype=self.dtype)

# Quantization-related (int8) methods

def quantized_call(self, inputs, reverse=False):
# TODO (hongyu): This function could be removed once we add `*args` and
# `**kwargs` for `Embedding.quantized_call`
if self.quantization_mode == "int8":
return self._int8_call(inputs, reverse=reverse)
else:
raise self._quantization_mode_error(self.quantization_mode)

def _int8_build(
self,
embeddings_initializer="zeros",
embeddings_scale_initializer="ones",
reverse_embeddings_initializer="zeros",
reverse_embeddings_scale_initializer="ones",
):
super()._int8_build(
embeddings_initializer, embeddings_scale_initializer
)
self.inputs_quantizer = keras.quantizers.AbsMaxQuantizer(axis=-1)
if not self.tie_weights:
self.reverse_embeddings = self.add_weight(
name="reverse_embeddings",
shape=(self.output_dim, self.input_dim),
initializer=reverse_embeddings_initializer,
dtype="int8",
trainable=False,
)
self.reverse_embeddings_scale = self.add_weight(
name="reverse_embeddings_scale",
shape=(self.input_dim,),
initializer=reverse_embeddings_scale_initializer,
trainable=False,
)

def _int8_call(self, inputs, reverse=False):
if reverse:
if self.tie_weights:
kernel = ops.transpose(self._embeddings)
scale = ops.transpose(self.embeddings_scale)
else:
kernel = self.reverse_embeddings
scale = self.reverse_embeddings_scale
inputs, inputs_scale = self.inputs_quantizer(inputs)
outputs = ops.matmul(inputs, kernel)
# De-scale outputs
outputs = ops.cast(outputs, self.compute_dtype)
outputs = ops.divide(outputs, ops.multiply(inputs_scale, scale))
return outputs

return super()._int8_call(inputs)

def quantize(self, mode):
import gc

if type(self) is not ReversibleEmbedding:
raise NotImplementedError(
f"Layer {self.__class__.__name__} does not have a `quantize()` "
"method implemented."
)
self._check_quantize_args(mode, self.compute_dtype)

self._tracker.unlock()
if mode == "int8":
embeddings, embeddings_scale = keras.quantizers.abs_max_quantize(
self._embeddings, axis=-1
)
embeddings_scale = ops.squeeze(embeddings_scale, axis=-1)
self._untrack_variable(self._embeddings)
del self._embeddings
if not self.tie_weights:
reverse_embeddings, reverse_embeddings_scale = (
keras.quantizers.abs_max_quantize(
self.reverse_embeddings, axis=0
)
)
reverse_embeddings_scale = ops.squeeze(
reverse_embeddings_scale, axis=0
)
self._untrack_variable(self.reverse_embeddings)
del self.reverse_embeddings
else:
reverse_embeddings = None
reverse_embeddings_scale = None
self._int8_build(
lambda shape, dtype: embeddings,
lambda shape, dtype: embeddings_scale,
lambda shape, dtype: reverse_embeddings,
lambda shape, dtype: reverse_embeddings_scale,
)
else:
raise self._quantization_mode_error(mode)
self._tracker.lock()

if self.dtype_policy.quantization_mode is None:
policy = keras.dtype_policies.get(
f"{mode}_from_{self.dtype_policy.name}"
)
self.dtype_policy = policy
gc.collect()
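
For reference, the arithmetic behind `_int8_call` is the standard abs-max scheme: activations are quantized per row, the reverse-embedding kernel per output column, the matmul runs on integer values, and the result is divided by the product of the two scales. A standalone NumPy sketch of that de-scaling, assuming the quantizer convention `scale = 127 / abs_max` (so dequantization divides by the scale):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 32)).astype("float32")    # activations passed with reverse=True
w = rng.normal(size=(32, 100)).astype("float32")  # stand-in for reverse_embeddings

# Abs-max scales: per row for activations, per output column for the kernel.
x_scale = 127.0 / np.abs(x).max(axis=-1, keepdims=True)
w_scale = 127.0 / np.abs(w).max(axis=0, keepdims=True)
x_q = np.round(x * x_scale).astype("int8")
w_q = np.round(w * w_scale).astype("int8")

# Integer matmul, then de-scale; this mirrors
# `ops.divide(outputs, ops.multiply(inputs_scale, scale))` in `_int8_call`.
y_int = x_q.astype("int32") @ w_q.astype("int32")
y = y_int.astype("float32") / (x_scale * w_scale)

print(np.mean((y - x @ w) ** 2))  # small relative to the variance of x @ w
```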
68 changes: 68 additions & 0 deletions keras_nlp/src/layers/modeling/reversible_embedding_test.py
@@ -98,3 +98,71 @@ def test_reverse_dtype(self):
output_data = embedding(input_data, reverse=True)
self.assertEqual(output_data.shape, (4, 10, 100))
self.assertDTypeEqual(output_data, "float16")

@parameterized.named_parameters(
("tie_weights", True), ("untie_weights", False)
)
def test_quantize_int8(self, tie_weights):
layer_config = dict(
input_dim=100, output_dim=32, tie_weights=tie_weights
)
layer = ReversibleEmbedding(**layer_config)
layer.build()
x = random.randint(shape=(64, 100), minval=0, maxval=9)
x_reverse = random.uniform(shape=(64, 32))
y_float = layer(x)
y_reverse_float = layer(x_reverse, reverse=True)
layer.quantize("int8")

# Verify weights dtype
if not tie_weights:
self.assertEqual(
keras.backend.standardize_dtype(layer.reverse_embeddings.dtype),
"int8",
)
self.assertEqual(
keras.backend.standardize_dtype(
layer.reverse_embeddings_scale.dtype
),
layer.variable_dtype,
)

# Try eager call and verify output correctness
y_quantized = layer(x)
y_reverse_quantized = layer(x_reverse, reverse=True)
mse = ops.mean(ops.square(y_float - y_quantized))
mse_reverse = ops.mean(
ops.square(y_reverse_float - y_reverse_quantized)
)
self.assertLess(mse, 1e-3) # A weak correctness test
self.assertLess(mse_reverse, 1e-3) # A weak correctness test

# Try saving and reloading the model
model = keras.models.Sequential([layer])
temp_filepath = os.path.join(
self.get_temp_dir(), "quantized_model.keras"
)
model.save(temp_filepath)
new_model = keras.models.load_model(temp_filepath)
self.assertAllClose(model.predict(x), new_model.predict(x))

@parameterized.named_parameters(
("tie_weights", True),
("untie_weights", False),
)
def test_quantize_dtype_argument(self, tie_weights):
self.run_layer_test(
cls=ReversibleEmbedding,
init_kwargs={
"input_dim": 100,
"output_dim": 32,
"tie_weights": tie_weights,
"embeddings_initializer": "HeNormal",
"dtype": "int8_from_float32",
},
input_data=random.randint(minval=0, maxval=100, shape=(4, 10)),
expected_output_shape=(4, 10, 32),
expected_num_trainable_weights=0,
expected_num_non_trainable_weights=2 if tie_weights else 4,
expected_num_non_trainable_variables=2 if tie_weights else 4,
)
17 changes: 11 additions & 6 deletions keras_nlp/src/models/backbone.py
@@ -75,11 +75,7 @@ def __init__(self, *args, dtype=None, **kwargs):
id(layer) for layer in self._flatten_layers()
)
self._initialized = True
if dtype is not None:
if isinstance(dtype, keras.DTypePolicy):
self.dtype_policy = dtype
else:
self.dtype_policy = keras.DTypePolicy(dtype)
self.dtype_policy = keras.dtype_policies.get(dtype)

def __setattr__(self, name, value):
# Work around setattr issues for Keras 2 and Keras 3 torch backend.
@@ -107,11 +103,20 @@ def token_embedding(self, value):
def get_config(self):
# Don't chain to super here. `get_config()` for functional models is
# a nested layer config and cannot be passed to Backbone constructors.
return {
config = {
"name": self.name,
"trainable": self.trainable,
}

# Add quantization support by utilizing `DTypePolicyMap`
policy_map = keras.dtype_policies.DTypePolicyMap()
for layer in self._flatten_layers():
if layer.quantization_mode is not None:
policy_map[layer.path] = layer.dtype_policy
if len(policy_map) > 0:
config.update({"dtype": policy_map})
return config

@classmethod
def from_config(cls, config):
# The default `from_config()` for functional models will return a
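
The loop added to `get_config()` above records a dtype policy only for layers that were actually quantized, keyed by layer path, which is what lets `from_config()` rebuild each layer under its quantized policy. A small self-contained sketch of that bookkeeping on a toy functional model (layer names are illustrative; assumes a Keras 3 release that ships `DTypePolicyMap` and int8 `Dense` quantization):

```python
import keras

# Toy model: quantize one layer, leave the other in float.
inputs = keras.Input((8,))
x = keras.layers.Dense(16, name="hidden")(inputs)
outputs = keras.layers.Dense(4, name="head")(x)
model = keras.Model(inputs, outputs)
model.get_layer("head").quantize("int8")

# Same pattern as the new `Backbone.get_config()`: only layers that carry a
# quantization mode go into the map, keyed by `layer.path`.
policy_map = keras.dtype_policies.DTypePolicyMap()
for layer in model._flatten_layers():
    if layer.quantization_mode is not None:
        policy_map[layer.path] = layer.dtype_policy

config = {"dtype": policy_map} if len(policy_map) > 0 else {}
print(config)  # one entry, for the "head" layer, e.g. "int8_from_float32"
```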
3 changes: 3 additions & 0 deletions keras_nlp/src/models/bloom/bloom_backbone_test.py
@@ -39,6 +39,9 @@ def test_backbone_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 8),
# TODO: Set to `True`. Error msg: Layer LayerNormalization does not
# have a `quantized_call()` method implemented.
run_quantization_check=False,
)

@pytest.mark.large
1 change: 0 additions & 1 deletion keras_nlp/src/models/gemma/gemma_backbone_test.py
@@ -171,7 +171,6 @@ def test_distribution_with_lora(self):
self.assertEqual(tuple(w.value.sharding.spec), (None, None))


@pytest.mark.keras_3_only
class Gemma2BackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
4 changes: 4 additions & 0 deletions keras_nlp/src/models/opt/opt_backbone_test.py
@@ -40,6 +40,10 @@ def test_backbone_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output_shape=(2, 5, 2),
# TODO: Set to `True`. Error msg: Layer 'token_embedding' expected 1
# variables, but received 0 variables during loading. Expected:
# ['embeddings']
run_quantization_check=False,
)

@pytest.mark.large