Skip to content

Commit e6b7eae

Browse files
authored
Improve RMSNorm to support 2D inputs (vllm-project#784)
When the input is 2D, we unsqueeze it to 3D to meet HPUFusedRMSNorm requirements
1 parent 7a16eb9 commit e6b7eae

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

vllm/model_executor/layers/layernorm.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,11 +113,18 @@ def forward_hpu(
113113
orig_shape = x.shape
114114
residual += x.view(residual.shape)
115115
# Note: HPUFusedRMSNorm requires 3D tensors as inputs
116+
residual_shape = residual.shape
117+
if len(residual_shape) == 2:
118+
residual = residual.unsqueeze(0)
116119
x = HPUFusedRMSNorm.apply(residual, self.weight,
117120
self.variance_epsilon)
118-
return x.view(orig_shape), residual
121+
return x.view(orig_shape), residual.view(residual_shape)
119122

123+
orig_shape = x.shape
124+
if len(orig_shape) == 2:
125+
x = x.unsqueeze(0)
120126
x = HPUFusedRMSNorm.apply(x, self.weight, self.variance_epsilon)
127+
x = x.view(orig_shape)
121128
return x
122129

123130
def forward_xpu(

0 commit comments

Comments
 (0)