Layout map for Llama (keras-team#1923)
* added test for the way BytePairTokenizer handles the \n\n sequence, which is important in Llama chat templates (see the rough sketch below)

* uncommented the test lines that had been commented out by mistake

* fixed linter errors

* added default layout map for Llama

* minor fixes in tests
martin-gorner authored and ushareng committed Oct 24, 2024
1 parent d422e0f commit d46754a
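As context for the first bullet above, here is a rough, hypothetical sketch of such a check. The tiny vocabulary and merge list are invented for illustration; the real test lives in the BytePairTokenizer test suite.

```
# Hypothetical sketch: "\n" is byte-encoded as "Ċ" by the BPE byte mapping,
# so a "Ċ Ċ" merge rule plus a "ĊĊ" vocabulary entry lets us observe whether
# two consecutive newlines come out as one merged token or as two.
import keras_hub

vocab = {"Ċ": 0, "ĊĊ": 1}
merges = ["Ċ Ċ"]
tokenizer = keras_hub.tokenizers.BytePairTokenizer(
    vocabulary=vocab, merges=merges
)
print(tokenizer(["\n\n"]))  # ideally a single id for the merged "\n\n" token
```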
Showing 4 changed files with 217 additions and 8 deletions.
14 changes: 11 additions & 3 deletions keras_hub/src/models/gemma/gemma_backbone.py
@@ -224,19 +224,27 @@ def get_layout_map(
Example:
```
# Feel free to change the mesh shape to balance data and model parallel
# Feel free to change the mesh shape to balance data and model parallelism
mesh = keras.distribution.DeviceMesh(
shape=(1, 8), axis_names=('batch', 'model'),
devices=keras.distribution.list_devices())
layout_map = GemmaBackbone.get_layout_map(
mesh, model_parallel_dim_name="model")
distribution = keras.distribution.ModelParallel(
mesh, layout_map, batch_dim_name='batch')
layout_map=layout_map, batch_dim_name='batch')
with distribution.scope():
gemma_model = keras_hub.models.GemmaCausalLM.from_preset()
```
To see how the layout map was applied, load the model and then run the following (for one decoder block):
```
embedding_layer = gemma_model.backbone.get_layer("token_embedding")
decoder_block_1 = gemma_model.backbone.get_layer('decoder_block_1')
for variable in embedding_layer.weights + decoder_block_1.weights:
print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
```
Args:
device_mesh: The `keras.distribution.DeviceMesh` instance for
distribution.
@@ -246,7 +254,7 @@ def get_layout_map(
the data should be partitioned.
Return:
`keras.distribution.LayoutMap` that contains the sharding spec
of all the model weights.
for all the model weights.
"""
# The weight paths and shapes of the Gemma backbone are listed below (for 2G)
# token_embedding/embeddings, (256128, 2048), 524550144
8 changes: 3 additions & 5 deletions keras_hub/src/models/gemma/gemma_backbone_test.py
@@ -74,19 +74,18 @@ def test_architecture_characteristics(self):

def test_distribution(self):
if keras.backend.backend() != "jax":
return
self.skipTest("`ModelParallel` testing requires the Jax backend.")
devices = keras.distribution.list_devices("CPU")
if len(devices) == 1:
# Need more than 1 device for distribution testing.
return
self.skipTest("`ModelParallel` testing requires multiple devices.")
device_mesh = keras.distribution.DeviceMesh(
shape=(1, len(devices)),
axis_names=("batch", "model"),
devices=devices,
)

layout_map = GemmaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(device_mesh, layout_map)
distribution = keras.distribution.ModelParallel(layout_map=layout_map)
with distribution.scope():
model = GemmaBackbone(**self.init_kwargs)

@@ -129,7 +128,6 @@ def test_distribution_with_lora(self):
self.skipTest("`ModelParallel` testing requires the Jax backend.")
devices = keras.distribution.list_devices("CPU")
if len(devices) == 1:
# Need more than 1 device for distribution testing.
self.skipTest("`ModelParallel` testing requires multiple devices.")
device_mesh = keras.distribution.DeviceMesh(
shape=(1, len(devices)),
118 changes: 118 additions & 0 deletions keras_hub/src/models/llama/llama_backbone.py
@@ -175,3 +175,121 @@ def get_config(self):
}
)
return config

@staticmethod
def get_layout_map(
device_mesh,
model_parallel_dim_name="model",
data_parallel_dim_name="batch",
):
"""Get a `keras.distribution.LayoutMap` for model parallel distribution.
The returned `LayoutMap` contains the sharding spec for the Llama
backbone weights, so that you can use it to distribute weights across
the accelerators.
Example:
```
# Feel free to change the mesh shape to balance data and model parallelism
mesh = keras.distribution.DeviceMesh(
shape=(1, 8),
axis_names=('batch', 'model'),
devices=keras.distribution.list_devices(),
)
layout_map = LlamaBackbone.get_layout_map(
mesh,
model_parallel_dim_name="model",
)
distribution = keras.distribution.ModelParallel(
layout_map=layout_map,
batch_dim_name='batch',
)
with distribution.scope():
llama_model = keras_hub.models.LlamaCausalLM.from_preset()
```
To see how the layout map was applied, load the model and then run the following (for one transformer layer):
```
embedding_layer = llama_model.backbone.get_layer("token_embedding")
transformer_layer_0 = llama_model.backbone.get_layer('transformer_layer_0')
for variable in embedding_layer.weights + transformer_layer_0.weights:
print(f'{variable.path:<58} {str(variable.shape):<16} {str(variable.value.sharding.spec)}')
```
Args:
device_mesh: The `keras.distribution.DeviceMesh` instance for
distribution.
model_parallel_dim_name: The axis name of the device mesh, where
the weights should be partitioned.
data_parallel_dim_name: The axis name of the device mesh, where
the data should be partitioned.
Return:
`keras.distribution.LayoutMap` that contains the sharding spec
for all the model weights.
"""
# The weight paths and shapes of the Llama backbone are listed below
# token_embedding/embeddings (128256, 2048)
# repeat block for decoder
# transformer_layer_0/self_attention/query/kernel (2048, 32, 64)
# transformer_layer_0/self_attention/key/kernel (2048, 8, 64)
# transformer_layer_0/self_attention/value/kernel (2048, 8, 64)
# transformer_layer_0/self_attention/attention_output/kernel (32, 64, 2048)
# transformer_layer_0/self_attention_layernorm/scale (2048,)
# transformer_layer_0/feedforward_intermediate_dense/kernel (2048, 8192)
# transformer_layer_0/feedforward_gate_dense/kernel (2048, 8192)
# transformer_layer_0/feedforward_output_dense/kernel (8192, 2048)
# transformer_layer_0/feedforward_layernorm/scale (2048,)
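# For scale, the token embedding alone is 128256 * 2048 = 262,668,288
# parameters, which is why it gets its own sharding entry below.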

if not isinstance(device_mesh, keras.distribution.DeviceMesh):
raise ValueError(
"Invalid device_mesh type. Expected `keras.distribution.Device`,"
f" got {type(device_mesh)}"
)
if model_parallel_dim_name not in device_mesh.axis_names:
raise ValueError(
f"{model_parallel_dim_name} is not found in the "
f"device_mesh.axis_names. {device_mesh.axis_name=}"
)
if data_parallel_dim_name not in device_mesh.axis_names:
raise ValueError(
f"{data_parallel_dim_name} is not found in the "
f"device_mesh.axis_names. {device_mesh.axis_name=}"
)
# Note that it is possible to further configure the mesh to be 3D, e.g.
# (data, seq, model). We leave it as 2D for now for simplicity.
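# A 3D mesh would be built roughly as follows (hypothetical sketch, not
# what this method produces):
#   mesh = keras.distribution.DeviceMesh(
#       shape=(1, 4, 2),
#       axis_names=("batch", "seq", "model"),
#       devices=keras.distribution.list_devices())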
data_dim = data_parallel_dim_name
model_dim = model_parallel_dim_name
# The sharding config is based on the Gemma team training config.
# See https://arxiv.org/abs/2403.08295
layout_map = keras.distribution.LayoutMap(device_mesh)
layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
layout_map[
"transformer_layer.*self_attention.*(query|key|value).kernel"
] = (
model_dim,
data_dim,
None,
)
layout_map["transformer_layer.*attention_output.kernel"] = (
model_dim,
None,
data_dim,
)
layout_map[
"transformer_layer.*feedforward_intermediate_dense.kernel"
] = (
data_dim,
model_dim,
)
layout_map["transformer_layer.*feedforward_gate_dense.kernel"] = (
data_dim,
model_dim,
)
layout_map["transformer_layer.*feedforward_output_dense.kernel"] = (
model_dim,
data_dim,
)

return layout_map
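A note on how these entries resolve: `LayoutMap` keys are regular expressions matched against variable paths, so the gate-dense entry above lines up with the concrete path listed in the comment block at the top of the method. A minimal, self-contained sketch of that correspondence, using Python's `re` module purely for illustration (not necessarily the exact matching Keras performs internally):

```
import re

# LayoutMap key from the code above, and a weight path from its comment block.
key = "transformer_layer.*feedforward_gate_dense.kernel"
path = "transformer_layer_0/feedforward_gate_dense/kernel"
assert re.search(key, path) is not None
```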
85 changes: 85 additions & 0 deletions keras_hub/src/models/llama/llama_backbone_test.py
@@ -1,3 +1,4 @@
import keras
import pytest
from keras import ops

@@ -66,3 +67,87 @@ def test_all_presets(self):
preset=preset,
input_data=self.input_data,
)

def test_distribution(self):
if keras.backend.backend() != "jax":
self.skipTest("`ModelParallel` testing requires the Jax backend.")
devices = keras.distribution.list_devices("CPU")
if len(devices) == 1:
self.skipTest("`ModelParallel` testing requires multiple devices.")
device_mesh = keras.distribution.DeviceMesh(
shape=(1, len(devices)),
axis_names=("batch", "model"),
devices=devices,
)

layout_map = LlamaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(layout_map=layout_map)
with distribution.scope():
model = LlamaBackbone(**self.init_kwargs)

for w in model.weights:
if "token_embedding/embeddings" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("model", "batch")
)
if "self_attention/query/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("model", "batch", None)
)
if "self_attention/key/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("model", "batch", None)
)
if "self_attention/value/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("model", "batch", None)
)
if "self_attention/attention_output/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("model", None, "batch")
)
if "feedforward_intermediate_dense/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("batch", "model")
)
if "feedforward_gate_dense/kernel" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("batch", "model")
)
if "feedforward_output_dense" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), ("model", "batch")
)

def test_distribution_with_lora(self):
if keras.backend.backend() != "jax":
self.skipTest("`ModelParallel` testing requires the Jax backend.")
devices = keras.distribution.list_devices("CPU")
if len(devices) == 1:
# Need more than 1 device for distribution testing.
self.skipTest("`ModelParallel` testing requires multiple devices.")
device_mesh = keras.distribution.DeviceMesh(
shape=(1, len(devices)),
axis_names=("batch", "model"),
devices=devices,
)

layout_map = LlamaBackbone.get_layout_map(device_mesh)
distribution = keras.distribution.ModelParallel(layout_map=layout_map)
with distribution.scope():
model = LlamaBackbone(**self.init_kwargs)
model.enable_lora(rank=4)

for w in model.weights:
if "self_attention/query/lora_kernel_a" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, None, None)
)
if "self_attention/query/lora_kernel_b" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, None))
if "self_attention/value/lora_kernel_a" in w.path:
self.assertEqual(
tuple(w.value.sharding.spec), (None, None, None)
)
if "self_attention/value/lora_kernel_b" in w.path:
self.assertEqual(tuple(w.value.sharding.spec), (None, None))
