Skip to content

Commit ff427e5

Browse files
[Keras Ops and Layer] Add keras.ops.rms_norm() and keras.layers.RMSNormalization() (#20911)
* Add RMSNorm and rms_norm * math.square -> numpy.square * Update docstrings * Add RMSNormalization Layer * Update docstrings * Lint with new ruff version * Add tests for layer * Address comments * Convert to tensor if not - avoid openvino and torch typing issues if scale is scalar * address comments * Fix tests * Add reference to paper * Fix docstring to remove input_dim argument * Update layer_normalization.py --------- Co-authored-by: François Chollet <[email protected]>
1 parent f7115c2 commit ff427e5

File tree

12 files changed

+260
-1
lines changed

12 files changed

+260
-1
lines changed

keras/api/_tf_keras/keras/layers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
from keras.src.layers.normalization.layer_normalization import (
8484
LayerNormalization,
8585
)
86+
from keras.src.layers.normalization.rms_normalization import RMSNormalization
8687
from keras.src.layers.normalization.spectral_normalization import (
8788
SpectralNormalization,
8889
)

keras/api/_tf_keras/keras/ops/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
from keras.src.ops.nn import psnr
9292
from keras.src.ops.nn import relu
9393
from keras.src.ops.nn import relu6
94+
from keras.src.ops.nn import rms_normalization
9495
from keras.src.ops.nn import selu
9596
from keras.src.ops.nn import separable_conv
9697
from keras.src.ops.nn import sigmoid

keras/api/_tf_keras/keras/ops/nn/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from keras.src.ops.nn import psnr
3636
from keras.src.ops.nn import relu
3737
from keras.src.ops.nn import relu6
38+
from keras.src.ops.nn import rms_normalization
3839
from keras.src.ops.nn import selu
3940
from keras.src.ops.nn import separable_conv
4041
from keras.src.ops.nn import sigmoid

keras/api/layers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
from keras.src.layers.normalization.layer_normalization import (
8484
LayerNormalization,
8585
)
86+
from keras.src.layers.normalization.rms_normalization import RMSNormalization
8687
from keras.src.layers.normalization.spectral_normalization import (
8788
SpectralNormalization,
8889
)

keras/api/ops/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
from keras.src.ops.nn import psnr
9292
from keras.src.ops.nn import relu
9393
from keras.src.ops.nn import relu6
94+
from keras.src.ops.nn import rms_normalization
9495
from keras.src.ops.nn import selu
9596
from keras.src.ops.nn import separable_conv
9697
from keras.src.ops.nn import sigmoid

keras/api/ops/nn/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from keras.src.ops.nn import psnr
3636
from keras.src.ops.nn import relu
3737
from keras.src.ops.nn import relu6
38+
from keras.src.ops.nn import rms_normalization
3839
from keras.src.ops.nn import selu
3940
from keras.src.ops.nn import separable_conv
4041
from keras.src.ops.nn import sigmoid

keras/src/layers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from keras.src.layers.normalization.layer_normalization import (
5858
LayerNormalization,
5959
)
60+
from keras.src.layers.normalization.rms_normalization import RMSNormalization
6061
from keras.src.layers.normalization.spectral_normalization import (
6162
SpectralNormalization,
6263
)

keras/src/layers/normalization/layer_normalization.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ class LayerNormalization(Layer):
8686
rms_scaling: If True, `center` and `scale` are ignored, and the
8787
inputs are scaled by `gamma` and the inverse square root
8888
of the square of all inputs. This is an approximate and faster
89-
approach that avoids ever computing the mean of the input.
89+
approach that avoids ever computing the mean of the input. Note that
90+
this *isn't* equivalent to the computation that the
91+
`keras.layers.RMSNormalization` layer performs.
9092
beta_initializer: Initializer for the beta weight. Defaults to zeros.
9193
gamma_initializer: Initializer for the gamma weight. Defaults to ones.
9294
beta_regularizer: Optional regularizer for the beta weight.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from keras.src import ops
2+
from keras.src.api_export import keras_export
3+
from keras.src.layers.layer import Layer
4+
5+
6+
@keras_export("keras.layers.RMSNormalization")
7+
class RMSNormalization(Layer):
8+
"""Root Mean Square (RMS) Normalization layer.
9+
10+
This layer normalizes the input tensor based on its RMS value.
11+
12+
The Keras layer performs the operation as described in
13+
[Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467)
14+
by Biao Zhang et al.
15+
16+
17+
If `scale` is enabled, the layer will scale the normalized outputs via
18+
a learnable scaling factor.
19+
20+
So, with scaling enabled, the normalization equations
21+
are as follows:
22+
23+
Let the intermediate activations for a mini-batch to be the `inputs`.
24+
25+
```python
26+
rms_normalization(x) = x * rsqrt(mean(square(x))) * scale
27+
```
28+
29+
For example:
30+
31+
>>> layer = keras.layers.RMSNormalization()
32+
>>> layer.build([5, 20, 30, 10])
33+
>>> print(layer.scale.shape)
34+
(10,)
35+
>>> layer(np.random.rand(1, 10)).numpy()
36+
array([[0.35098287, 1.0495652 , 1.4645109 , 1.2944688 , 0.31124955,
37+
1.2768592 , 1.184331 , 0.17474432, 0.49955517, 1.2428929 ]],
38+
dtype=float32)
39+
40+
Args:
41+
axis: int. The axis on which to perform the normalization.
42+
epsilon: float. A small number to add to avoid division by zero.
43+
"""
44+
45+
def __init__(self, axis=-1, epsilon=1e-6, **kwargs):
46+
super().__init__(**kwargs)
47+
self.axis = axis
48+
self.epsilon = epsilon
49+
50+
def build(self, input_shape):
51+
if isinstance(self.axis, list):
52+
shape = tuple([input_shape[dim] for dim in self.axis])
53+
else:
54+
shape = (input_shape[self.axis],)
55+
self.axis = [self.axis]
56+
57+
self.scale = self.add_weight(
58+
name="scale", shape=shape, initializer="ones"
59+
)
60+
61+
self.built = True
62+
63+
def call(self, x):
64+
"""Applies RMS normalization to the input tensor.
65+
66+
Args:
67+
x: Input tensor of shape (batch_size, input_dim).
68+
69+
Returns:
70+
The RMS-normalized tensor of the same shape (batch_size, input_dim),
71+
scaled by the learned `scale` parameter.
72+
"""
73+
return ops.rms_normalization(
74+
x, scale=self.scale, axis=self.axis, epsilon=self.epsilon
75+
)
76+
77+
def compute_output_shape(self, input_shape):
78+
if isinstance(self.axis, int):
79+
axes = [self.axis]
80+
else:
81+
axes = self.axis
82+
83+
for axis in axes:
84+
if axis >= len(input_shape) or axis < -len(input_shape):
85+
raise ValueError(
86+
f"Axis {axis} is out of bounds for "
87+
f"input shape {input_shape}. "
88+
f"Received: axis={self.axis}"
89+
)
90+
return input_shape
91+
92+
def get_config(self):
93+
config = {
94+
"axis": self.axis,
95+
"epsilon": self.epsilon,
96+
}
97+
base_config = super().get_config()
98+
return {**base_config, **config}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import numpy as np
2+
import pytest
3+
4+
from keras.src import layers
5+
from keras.src import ops
6+
from keras.src import testing
7+
8+
9+
class RMSNormalizationTest(testing.TestCase):
10+
@pytest.mark.requires_trainable_backend
11+
def test_ln_basics(self):
12+
self.run_layer_test(
13+
layers.RMSNormalization,
14+
init_kwargs={},
15+
input_shape=(4, 2),
16+
expected_output_shape=(4, 2),
17+
expected_num_trainable_weights=1,
18+
expected_num_seed_generators=0,
19+
)
20+
self.run_layer_test(
21+
layers.RMSNormalization,
22+
init_kwargs={
23+
"axis": -1,
24+
},
25+
input_shape=(4, 2),
26+
expected_output_shape=(4, 2),
27+
expected_num_trainable_weights=1,
28+
expected_num_seed_generators=0,
29+
)
30+
31+
def test_correctness(self):
32+
layer = layers.RMSNormalization()
33+
layer.build(input_shape=(2, 2, 2))
34+
inputs = np.random.normal(
35+
loc=5.0, scale=10.0, size=(1000, 2, 2, 2)
36+
).astype("float32")
37+
38+
inputs = ops.convert_to_tensor(inputs)
39+
40+
out = layer(inputs)
41+
expected = (
42+
inputs
43+
* ops.rsqrt(ops.mean(ops.square(inputs), axis=-1, keepdims=True))
44+
* layer.scale
45+
)
46+
47+
self.assertAllClose(out, expected, atol=1e-1)
48+
49+
def test_output(self):
50+
layer = layers.RMSNormalization()
51+
inputs = np.arange(10).astype("float32")[None, :]
52+
out = layer(inputs)
53+
self.assertAllClose(
54+
out,
55+
[
56+
[
57+
0.0,
58+
0.18731716,
59+
0.37463433,
60+
0.5619515,
61+
0.74926865,
62+
0.9365858,
63+
1.123903,
64+
1.3112202,
65+
1.4985373,
66+
1.6858544,
67+
]
68+
],
69+
)

keras/src/ops/nn.py

+77
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from keras.src.backend.common.backend_utils import (
1111
compute_conv_transpose_output_shape,
1212
)
13+
from keras.src.backend.common.keras_tensor import is_keras_tensor
1314
from keras.src.ops import operation_utils
1415
from keras.src.ops.operation import Operation
1516
from keras.src.ops.operation_utils import reduce_shape
@@ -2653,6 +2654,82 @@ def dot_product_attention(
26532654
)
26542655

26552656

2657+
class RMSNorm(Operation):
2658+
def __init__(self, scale, axis=-1, epsilon=None):
2659+
super().__init__()
2660+
self.axis = axis
2661+
self.scale = scale
2662+
self.epsilon = epsilon
2663+
2664+
def compute_output_spec(self, x):
2665+
return KerasTensor(shape=x.shape)
2666+
2667+
def call(self, x):
2668+
return _rms_normalization(
2669+
x, scale=self.scale, axis=self.axis, epsilon=self.epsilon
2670+
)
2671+
2672+
2673+
@keras_export(
2674+
[
2675+
"keras.ops.rms_normalization",
2676+
"keras.ops.nn.rms_normalization",
2677+
]
2678+
)
2679+
def rms_normalization(x, scale=1, axis=-1, epsilon=None):
2680+
"""Performs Root Mean Square (RMS) normalization on `x`.
2681+
2682+
The Keras operation implements the operation as described in
2683+
[Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467)
2684+
by Biao Zhang et al.
2685+
2686+
The operation is different from LayerNormalization with RMS scaling.
2687+
2688+
It is defined as `rms_normalization(x) = x * rsqrt(mean(square(x))) * scale`
2689+
2690+
Args:
2691+
x: Input tensor.
2692+
axis: The axis or axes along which to perform normalization.
2693+
Default to -1.
2694+
scale: Optional scaling factor for the normalization.
2695+
epsilon: A lower bound value for the norm.
2696+
Defaults to `backend.epsilon()`.
2697+
2698+
Returns:
2699+
The normalized array.
2700+
2701+
Example:
2702+
2703+
>>> x = np.random.rand(1, 10)
2704+
>>> x_norm = keras.ops.rms_normalization(x, (10,))
2705+
>>> print(x_norm)
2706+
array([[0.69384296, 0.94444374, 0.16551171, 0.05749961, 1.11008865,
2707+
0.52475186, 1.57686807, 1.69893307, 1.27292764, 0.30819128]])
2708+
"""
2709+
if any_symbolic_tensors((x,)):
2710+
return RMSNorm(scale=scale, axis=axis, epsilon=epsilon).symbolic_call(x)
2711+
return _rms_normalization(x, scale=scale, axis=axis, epsilon=epsilon)
2712+
2713+
2714+
def _rms_normalization(x, scale=1, axis=-1, epsilon=None):
2715+
x = backend.convert_to_tensor(x)
2716+
if len(x.shape) == 0:
2717+
x = backend.numpy.expand_dims(x, axis=0)
2718+
if epsilon is None:
2719+
epsilon = backend.epsilon()
2720+
2721+
if not is_keras_tensor(scale):
2722+
scale = backend.convert_to_tensor(scale, dtype=x.dtype)
2723+
if not is_keras_tensor(epsilon):
2724+
epsilon = backend.convert_to_tensor(epsilon, dtype=x.dtype)
2725+
2726+
rrms = backend.math.rsqrt(
2727+
backend.numpy.mean(backend.numpy.square(x), axis=axis, keepdims=True)
2728+
+ epsilon
2729+
)
2730+
return (x * rrms) * scale
2731+
2732+
26562733
class Polar(Operation):
26572734
def __init__(self):
26582735
super().__init__()

keras/src/ops/nn_test.py

+6
Original file line numberDiff line numberDiff line change
@@ -3142,3 +3142,9 @@ def test_invalid_strategy_ctc_decode(self):
31423142
beam_width=beam_width,
31433143
top_paths=top_paths,
31443144
)
3145+
3146+
def test_rms_normalization(self):
3147+
x = KerasTensor([None, 2, 3])
3148+
self.assertEqual(
3149+
knn.rms_normalization(x, (None, 2, 3)).shape, (None, 2, 3)
3150+
)

0 commit comments

Comments
 (0)