10 | 10 | from keras.src.backend.common.backend_utils import (
11 | 11 |     compute_conv_transpose_output_shape,
12 | 12 | )
| 13 | +from keras.src.backend.common.keras_tensor import is_keras_tensor
13 | 14 | from keras.src.ops import operation_utils
14 | 15 | from keras.src.ops.operation import Operation
15 | 16 | from keras.src.ops.operation_utils import reduce_shape
@@ -2653,6 +2654,82 @@ def dot_product_attention(
2653 | 2654 | )
2654 | 2655 |
2655 | 2656 |
| 2657 | +class RMSNorm(Operation):
| 2658 | +    def __init__(self, scale, axis=-1, epsilon=None):
| 2659 | +        super().__init__()
| 2660 | +        self.axis = axis
| 2661 | +        self.scale = scale
| 2662 | +        self.epsilon = epsilon
| 2663 | +
| 2664 | +    def compute_output_spec(self, x):
| 2665 | +        return KerasTensor(shape=x.shape, dtype=x.dtype)
| 2666 | +
| 2667 | +    def call(self, x):
| 2668 | +        return _rms_normalization(
| 2669 | +            x, scale=self.scale, axis=self.axis, epsilon=self.epsilon
| 2670 | +        )
| 2671 | +
| 2672 | +
| 2673 | +@keras_export(
| 2674 | +    [
| 2675 | +        "keras.ops.rms_normalization",
| 2676 | +        "keras.ops.nn.rms_normalization",
| 2677 | +    ]
| 2678 | +)
| 2679 | +def rms_normalization(x, scale=1, axis=-1, epsilon=None):
| 2680 | +    """Performs Root Mean Square (RMS) normalization on `x`.
| 2681 | +
| 2682 | +    This operation implements RMS normalization as described in
| 2683 | +    [Root Mean Square Layer Normalization](https://arxiv.org/pdf/1910.07467)
| 2684 | +    by Biao Zhang and Rico Sennrich.
| 2685 | +
| 2686 | +    Note that this operation is not the same as `LayerNormalization`
| 2687 | +    with RMS scaling. Along `axis`, it computes:
| 2688 | +    `rms_normalization(x) = x * rsqrt(mean(square(x)) + epsilon) * scale`
| 2689 | +
| 2690 | +    Args:
| 2691 | +        x: Input tensor.
| 2692 | +        scale: Optional scaling factor for the normalization. Defaults to 1.
| 2693 | +        axis: The axis or axes along which to perform normalization.
| 2694 | +            Defaults to `-1`.
| 2695 | +        epsilon: A lower bound value for the norm.
| 2696 | +            Defaults to `backend.epsilon()`.
| 2697 | +
| 2698 | +    Returns:
| 2699 | +        The normalized array.
| 2700 | +
| 2701 | +    Example:
| 2702 | +
| 2703 | +    >>> x = np.random.rand(1, 10)
| 2704 | +    >>> x_norm = keras.ops.rms_normalization(x)
| 2705 | +    >>> print(x_norm)
| 2706 | +    array([[0.69384296, 0.94444374, 0.16551171, 0.05749961, 1.11008865,
| 2707 | +            0.52475186, 1.57686807, 1.69893307, 1.27292764, 0.30819128]])
| 2708 | +    """
| 2709 | +    if any_symbolic_tensors((x,)):
| 2710 | +        return RMSNorm(scale=scale, axis=axis, epsilon=epsilon).symbolic_call(x)
| 2711 | +    return _rms_normalization(x, scale=scale, axis=axis, epsilon=epsilon)
| 2712 | +
| 2713 | +
| 2714 | +def _rms_normalization(x, scale=1, axis=-1, epsilon=None):
| 2715 | +    x = backend.convert_to_tensor(x)
| 2716 | +    if len(x.shape) == 0:
| 2717 | +        x = backend.numpy.expand_dims(x, axis=0)
| 2718 | +    if epsilon is None:
| 2719 | +        epsilon = backend.epsilon()
| 2720 | +
| 2721 | +    if not is_keras_tensor(scale):
| 2722 | +        scale = backend.convert_to_tensor(scale, dtype=x.dtype)
| 2723 | +    if not is_keras_tensor(epsilon):
| 2724 | +        epsilon = backend.convert_to_tensor(epsilon, dtype=x.dtype)
| 2725 | +
| 2726 | +    rrms = backend.math.rsqrt(
| 2727 | +        backend.numpy.mean(backend.numpy.square(x), axis=axis, keepdims=True)
| 2728 | +        + epsilon
| 2729 | +    )
| 2730 | +    return (x * rrms) * scale
| 2731 | +
| 2732 | +
2656 | 2733 | class Polar(Operation):
2657 | 2734 |     def __init__(self):
2658 | 2735 |         super().__init__()
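Usage sketch (not part of the diff): the snippet below checks the new eager path against the formula in the docstring. It assumes this PR's export `keras.ops.rms_normalization` with the defaults `scale=1` and `epsilon=backend.epsilon()`; everything else is standard NumPy/Keras API.

```python
import numpy as np
import keras

x = np.random.rand(1, 10).astype("float32")

# Eager path: `x` is a concrete tensor, so `_rms_normalization` runs directly.
x_norm = keras.ops.rms_normalization(x)

# Manual reference for x * rsqrt(mean(square(x)) + epsilon) with scale=1.
eps = keras.config.epsilon()
expected = x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)

np.testing.assert_allclose(
    keras.ops.convert_to_numpy(x_norm), expected, rtol=1e-5
)
```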
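Symbolic path sketch: when `x` is a `KerasTensor`, `any_symbolic_tensors((x,))` is true and the call is routed through `RMSNorm(...).symbolic_call(x)`, so only `compute_output_spec` runs and the result keeps the input's shape. A minimal check, again assuming this PR's export name:

```python
import keras

# `keras.Input` produces a symbolic KerasTensor, so no computation happens;
# the op only infers the output spec via `RMSNorm.compute_output_spec`.
inputs = keras.Input(shape=(10,))
outputs = keras.ops.rms_normalization(inputs)
print(outputs.shape)  # (None, 10)
```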