-
Notifications
You must be signed in to change notification settings - Fork 276
/
Copy pathtest_norm.py
202 lines (158 loc) · 6.74 KB
/
test_norm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""
Copyright (c) 2024 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import pytest
import torch
import flashinfer
from flashinfer.utils import get_compute_capability
def llama_rms_norm(x, w, eps=1e-6):
    """Reference LLaMA-style RMSNorm, evaluated in float32.

    Computes ``x * rsqrt(mean(x**2, dim=-1) + eps) * w`` with both ``x`` and
    ``w`` upcast to float32, then casts the result back to ``x.dtype``.
    """
    in_dtype = x.dtype
    xf = x.float()
    # Reciprocal RMS over the last dimension; keepdim so it broadcasts per row.
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (xf * inv_rms * w.float()).to(in_dtype)
def gemma_rms_norm(x, w, eps=1e-6):
    """Reference Gemma-style RMSNorm, evaluated in float32.

    Computes ``x * rsqrt(mean(x**2, dim=-1) + eps) * (1 + w)``. The weight is
    applied as an offset from one, and the result is cast back to ``x.dtype``.
    """
    in_dtype = x.dtype
    xf = x.float()
    gain = 1.0 + w.float()
    inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (xf * inv_rms * gain).to(in_dtype)
def gemma_fused_add_rms_norm(x, residual, w, eps=1e-6):
    """Reference for the fused residual-add + Gemma RMSNorm.

    Args:
        x: input activations; the result dtype follows ``x.dtype``.
        residual: tensor added to ``x`` before normalization.
        w: Gemma weight, applied as ``(1 + w)``.
        eps: epsilon added to the mean square before ``rsqrt``.

    Returns:
        ``(normed, new_residual)``:
        - ``new_residual`` is ``x + residual`` cast to ``x.dtype``.
        - ``normed`` is the Gemma RMSNorm of that sum, cast to ``x.dtype``.
    """
    orig_dtype = x.dtype
    # Accumulate the residual add in float32 and normalize the *unrounded* sum,
    # matching ``fused_add_rms_norm`` in this file. Previously the add happened
    # in x.dtype (e.g. fp16) and the rounded sum was normalized.
    # NOTE(review): presumably this is also what the fused kernel does — confirm.
    x = x.float() + residual.float()
    residual = x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    x = x * (1.0 + w.float())
    return x.to(orig_dtype), residual
def fused_add_rms_norm(x, residual, weight, eps=1e-6):
    """Reference for the fused residual-add + LLaMA-style RMSNorm.

    Args:
        x: input activations; the result dtype follows ``x.dtype``.
        residual: tensor added to ``x`` before normalization.
        weight: per-feature scale applied after normalization.
        eps: epsilon added to the mean square before ``rsqrt``. It now defaults
            to ``1e-6``, consistent with the other reference helpers in this
            file. Callers that pass it explicitly are unaffected.

    Returns:
        ``(normed, new_residual)``:
        - ``new_residual`` is the float32 sum cast to ``x.dtype``.
        - ``normed`` is the RMSNorm of the unrounded float32 sum, cast to
          ``x.dtype``.
    """
    orig_dtype = x.dtype
    x = x.to(torch.float32) + residual.to(torch.float32)
    residual = x.to(orig_dtype)
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    x = (x * inv_rms * weight.float()).to(orig_dtype)
    return x, residual
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
def test_norm(batch_size, hidden_size, dtype, specify_out, enable_pdl, contiguous):
    """Check ``flashinfer.norm.rmsnorm`` against the float32 reference.

    Exercises contiguous inputs and row-strided inputs (the left half of a
    double-width buffer), both with and without a caller-supplied ``out``.
    """
    if not contiguous:
        # Left half of a (batch, 2 * hidden) buffer, so each row is strided.
        wide = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
        x = wide[:, :hidden_size]
    else:
        x = torch.randn(batch_size, hidden_size).to(0).to(dtype)

    cc_major, _ = get_compute_capability(x.device)
    if enable_pdl and cc_major < 9:
        pytest.skip("PDL is only available for Hopper and later GPUs")

    w = torch.randn(hidden_size).to(0).to(dtype)
    y_ref = llama_rms_norm(x, w)

    if not specify_out:
        y = flashinfer.norm.rmsnorm(x, w, enable_pdl=enable_pdl)
    else:
        y = torch.empty_like(x)
        flashinfer.norm.rmsnorm(x, w, out=y, enable_pdl=enable_pdl)

    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
def test_fused_add_rmsnorm(batch_size, hidden_size, dtype, enable_pdl, contiguous):
    """Check in-place ``flashinfer.fused_add_rmsnorm`` against the reference.

    The kernel overwrites its first argument with the normalized output and its
    second with the updated residual. Both are compared to the float32 reference.
    When ``contiguous`` is False, the input passed to the kernel is a row-strided
    view.
    """
    eps = 1e-6
    if contiguous:
        x_storage = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    else:
        # Double-width buffer; taking the left half yields row-strided views.
        x_storage = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
    x = x_storage[:, :hidden_size]

    major, _ = get_compute_capability(x.device)
    if major < 9 and enable_pdl:
        pytest.skip("PDL is only available for Hopper and later GPUs")

    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

    x_native, residual_native = fused_add_rms_norm(
        x.clone(), residual.clone(), weight, eps
    )

    # ``x.clone()`` on a non-dense slice returns a *contiguous* copy, which
    # would silently reduce the contiguous=False case to the contiguous one.
    # Clone the backing buffer and re-slice so the kernel sees strided input.
    x_fused = x_storage.clone()[:, :hidden_size]
    residual_fused = residual.clone()
    flashinfer.fused_add_rmsnorm(
        x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
    )

    torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("specify_out", [True, False])
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
def test_gemma_norm(
    batch_size, hidden_size, dtype, specify_out, enable_pdl, contiguous
):
    """Check ``flashinfer.norm.gemma_rmsnorm`` against the float32 reference.

    Exercises contiguous and row-strided inputs, both with and without a
    caller-supplied ``out`` tensor.
    """
    if not contiguous:
        # Left half of a (batch, 2 * hidden) buffer, so each row is strided.
        wide = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
        x = wide[:, :hidden_size]
    else:
        x = torch.randn(batch_size, hidden_size).to(0).to(dtype)

    cc_major, _ = get_compute_capability(x.device)
    if enable_pdl and cc_major < 9:
        pytest.skip("PDL is only available for Hopper and later GPUs")

    w = torch.randn(hidden_size).to(0).to(dtype)
    y_ref = gemma_rms_norm(x, w)

    if not specify_out:
        y = flashinfer.norm.gemma_rmsnorm(x, w, enable_pdl=enable_pdl)
    else:
        y = torch.empty_like(x)
        flashinfer.norm.gemma_rmsnorm(x, w, out=y, enable_pdl=enable_pdl)

    torch.testing.assert_close(y_ref, y, rtol=1e-3, atol=1e-3)
@pytest.mark.parametrize("batch_size", [1, 19, 99, 989])
@pytest.mark.parametrize("hidden_size", [111, 500, 1024, 3072, 3584, 4096, 8192, 16384])
@pytest.mark.parametrize("dtype", [torch.float16])
@pytest.mark.parametrize("enable_pdl", [True, False])
@pytest.mark.parametrize("contiguous", [True, False])
def test_gemma_fused_add_rmsnorm(
    batch_size, hidden_size, dtype, enable_pdl, contiguous
):
    """Check in-place ``flashinfer.gemma_fused_add_rmsnorm`` against the reference.

    The kernel overwrites its first argument with the normalized output and its
    second with the updated residual. Both are compared to the float32 reference.
    When ``contiguous`` is False, the input passed to the kernel is a row-strided
    view.
    """
    eps = 1e-6
    if contiguous:
        x_storage = torch.randn(batch_size, hidden_size, dtype=dtype, device="cuda")
    else:
        # Double-width buffer; taking the left half yields row-strided views.
        x_storage = torch.randn(batch_size, hidden_size * 2, device="cuda").to(dtype)
    x = x_storage[:, :hidden_size]

    major, _ = get_compute_capability(x.device)
    if major < 9 and enable_pdl:
        pytest.skip("PDL is only available for Hopper and later GPUs")

    residual = torch.randn_like(x)
    weight = torch.randn(hidden_size, dtype=dtype, device="cuda")

    x_native, residual_native = gemma_fused_add_rms_norm(
        x.clone(), residual.clone(), weight, eps
    )

    # ``x.clone()`` on a non-dense slice returns a *contiguous* copy, which
    # would silently reduce the contiguous=False case to the contiguous one.
    # Clone the backing buffer and re-slice so the kernel sees strided input.
    x_fused = x_storage.clone()[:, :hidden_size]
    residual_fused = residual.clone()
    flashinfer.gemma_fused_add_rmsnorm(
        x_fused, residual_fused, weight, eps, enable_pdl=enable_pdl
    )

    torch.testing.assert_close(x_fused, x_native, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(residual_fused, residual_native, rtol=1e-3, atol=1e-3)
if __name__ == "__main__":
    # Ad-hoc entry point for running a single case without pytest.
    # Request PDL only on SM90+ devices. On older GPUs the test would call
    # pytest.skip(), which raises an exception outside a pytest session.
    _cc_major, _ = torch.cuda.get_device_capability()
    test_fused_add_rmsnorm(
        batch_size=1,
        hidden_size=16384,
        dtype=torch.float16,
        enable_pdl=_cc_major >= 9,
        contiguous=True,
    )