Commit a4998a0
Add param_axis argument to RMSNorm to allow setting scale param shape.
Note: This is mostly a direct fork of the corresponding feature in LayerNorm.

PiperOrigin-RevId: 488655300
seb5666 authored and copybara-github committed Nov 15, 2022
1 parent dbc0b1f commit a4998a0
Showing 3 changed files with 66 additions and 4 deletions.
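For orientation before the diff, a minimal usage sketch of the new keyword (a hedged example using the public hk.RMSNorm and hk.transform APIs; the axis choices and shapes are illustrative, not prescribed by the commit):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
  # param_axis only controls the shape of the learnable scale parameter;
  # the normalization itself is still computed over `axis`.
  norm = hk.RMSNorm(axis=-1, param_axis=(-2, -1))
  return norm(x)

forward = hk.transform(forward)
x = jnp.ones([3, 4, 5, 6])
params = forward.init(jax.random.PRNGKey(0), x)
# With param_axis=(-2, -1) the scale is created with broadcastable
# shape (1, 1, 5, 6) rather than the default (6,).
out = forward.apply(params, None, x)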
1 change: 1 addition & 0 deletions haiku/_src/BUILD
@@ -910,6 +910,7 @@ hk_py_library(
deps = [
":base",
":initializers",
":layer_norm",
":module",
# pip: jax
],
31 changes: 27 additions & 4 deletions haiku/_src/rms_norm.py
@@ -23,6 +23,7 @@

from haiku._src import base
from haiku._src import initializers
from haiku._src import layer_norm
from haiku._src import module
import jax
import jax.numpy as jnp
@@ -34,6 +35,8 @@
hk.Module = module.Module
del base, module, initializers

AxisOrAxes = Union[int, Sequence[int], slice]


class RMSNorm(hk.Module):
"""RMSNorm module.
@@ -47,11 +50,14 @@ class RMSNorm(hk.Module):

def __init__(
self,
axis: Union[int, Sequence[int], slice],
axis: AxisOrAxes,
eps: float = 1e-5,
scale_init: Optional[hk.initializers.Initializer] = None,
name: Optional[str] = None,
create_scale: bool = True):
create_scale: bool = True,
*,
param_axis: Optional[AxisOrAxes] = None,
):
"""Constructs a RMSNorm module.
Args:
@@ -62,6 +68,10 @@ def __init__(
name: The module name.
create_scale: Bool, defines whether to create a trainable scale
per channel applied after the normalization.
param_axis: Axis used to determine the parameter shape of the learnable
scale/offset. Sonnet sets this to the channel/feature axis (e.g. to
``-1`` for ``NHWC``). Other libraries set this to the same as the
reduction axis (e.g. ``axis=param_axis``). `None` defaults to (-1,).
"""
super().__init__(name=name)
if not create_scale and scale_init is not None:
@@ -79,6 +89,10 @@ def __init__(
self.eps = eps
self.create_scale = create_scale
self.scale_init = scale_init or jnp.ones
if param_axis is None:
self.param_axis = (-1,)
else:
self.param_axis = layer_norm.to_axes_or_slice(param_axis)

def __call__(self, inputs: jnp.ndarray):
"""Connects the layer norm.
@@ -93,9 +107,18 @@ def __call__(self, inputs: jnp.ndarray):
if isinstance(axis, slice):
axis = tuple(range(inputs.ndim)[axis])

param_axis = layer_norm.to_abs_axes(self.param_axis, inputs.ndim)
if param_axis == (inputs.ndim - 1,):
# For param_axis=-1 we store non-broadcast param shape for compatibility
# with older checkpoints.
param_shape = inputs.shape[-1:]
else:
param_shape = tuple(
(inputs.shape[i] if i in param_axis else 1)
for i in range(inputs.ndim))
if self.create_scale:
scale = hk.get_parameter("scale", inputs.shape[-1:], inputs.dtype,
init=self.scale_init)
scale = hk.get_parameter(
"scale", param_shape, inputs.dtype, init=self.scale_init)
scale = jnp.broadcast_to(scale, inputs.shape)
else:
scale = 1.
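To make the parameter-shape rule in __call__ concrete, a small standalone sketch of the same computation (the function name is illustrative, not part of the library):

def rms_norm_scale_shape(input_shape, param_axis=(-1,)):
  # Mirrors the param_shape logic above: the default/-1 case keeps the
  # old non-broadcast shape so existing checkpoints still load.
  ndim = len(input_shape)
  abs_axes = tuple(sorted(a % ndim for a in param_axis))
  if abs_axes == (ndim - 1,):
    return input_shape[-1:]
  return tuple(input_shape[i] if i in abs_axes else 1 for i in range(ndim))

assert rms_norm_scale_shape((3, 4, 5, 6)) == (6,)
assert rms_norm_scale_shape((3, 4, 5, 6), param_axis=(-2, -1)) == (1, 1, 5, 6)
assert rms_norm_scale_shape((3, 4, 5, 6), param_axis=(0, 1)) == (3, 4, 1, 1)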
38 changes: 38 additions & 0 deletions haiku/_src/rms_norm_test.py
@@ -38,6 +38,10 @@ def test_connection(self):
norms.append(rms_norm.RMSNorm(axis=slice(2, None))(data))
norms.append(rms_norm.RMSNorm(axis=slice(1, -1))(data))

norms.append(rms_norm.RMSNorm(axis=-1, param_axis=(-1,))(data))
norms.append(rms_norm.RMSNorm(axis=-1, param_axis=(-2, -1))(data))
norms.append(rms_norm.RMSNorm(axis=-1, param_axis=(0, 1))(data))
norms.append(rms_norm.RMSNorm(axis=-1, param_axis=(0, 1, 2, 3))(data))
return norms

def test_bf16(self):
@@ -104,5 +108,39 @@ def test_simple_case_without_scale(self):
_ = layer(inputs)
assert "scale" not in layer.params_dict()

@parameterized.parameters(
(None, (6,)),
(-1, (6,)),
(-2, (1, 1, 5, 1)),
(-3, (1, 4, 1, 1)),
(-4, (3, 1, 1, 1)),
(0, (3, 1, 1, 1)),
(1, (1, 4, 1, 1)),
(2, (1, 1, 5, 1)),
(3, (6,)),
(slice(1, 3), (1, 4, 5, 1)),
(slice(0, 3, 2), (3, 1, 5, 1)),
(slice(-1, 0, -1), (1, 4, 5, 6)),
)
@test_utils.transform_and_run
def test_param_axis_sets_param_shape(self, param_axis, param_shape):
ln = rms_norm.RMSNorm(axis=-1, param_axis=param_axis)
ln(jnp.ones([3, 4, 5, 6]))
self.assertEqual(ln.params_dict()["rms_norm/scale"].shape, param_shape)

@parameterized.parameters(
((0, 1, 2), (3, 4, 5, 1)),
((-4, -2, -3), (3, 4, 5, 1)),
((0, 1), (3, 4, 1, 1)),
((0, 3), (3, 1, 1, 6)),
((-4, -1), (3, 1, 1, 6)),
((-1, -4), (3, 1, 1, 6)),
)
@test_utils.transform_and_run
def test_multiple_param_axis(self, param_axis, param_shape):
ln = rms_norm.RMSNorm(axis=-1, param_axis=param_axis)
ln(jnp.ones([3, 4, 5, 6]))
self.assertEqual(ln.params_dict()["rms_norm/scale"].shape, param_shape)

if __name__ == "__main__":
absltest.main()
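The slice cases in test_param_axis_sets_param_shape are resolved to absolute axes before the shape rule applies. A hedged sketch of that resolution (the real helpers are to_axes_or_slice/to_abs_axes in haiku._src.layer_norm; this standalone version only approximates them):

def resolve_param_axes(param_axis, ndim):
  # Slices are expanded against range(ndim); ints and iterables are
  # normalised to a sorted tuple of non-negative axes.
  if isinstance(param_axis, slice):
    return tuple(sorted(range(ndim)[param_axis]))
  if isinstance(param_axis, int):
    param_axis = (param_axis,)
  return tuple(sorted(a % ndim for a in param_axis))

# For inputs of shape (3, 4, 5, 6), as in the tests above:
assert resolve_param_axes(slice(1, 3), 4) == (1, 2)          # scale shape (1, 4, 5, 1)
assert resolve_param_axes(slice(0, 3, 2), 4) == (0, 2)       # scale shape (3, 1, 5, 1)
assert resolve_param_axes(slice(-1, 0, -1), 4) == (1, 2, 3)  # scale shape (1, 4, 5, 6)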
