
Commit 6d40dc3

WoosukKwon and randxie authored and committed
Implement approximate GELU kernels (vllm-project#828)
1 parent 1f6a239 commit 6d40dc3

File tree: 4 files changed (+164, −18 lines)

csrc/activation.cpp

Lines changed: 16 additions & 0 deletions
@@ -4,9 +4,25 @@ void silu_and_mul(
   torch::Tensor& out,
   torch::Tensor& input);
 
+void gelu_new(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
+void gelu_fast(
+  torch::Tensor& out,
+  torch::Tensor& input);
+
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def(
     "silu_and_mul",
     &silu_and_mul,
     "Activation function used in SwiGLU.");
+  m.def(
+    "gelu_new",
+    &gelu_new,
+    "GELU implementation used in GPT-2.");
+  m.def(
+    "gelu_fast",
+    &gelu_fast,
+    "Approximate GELU implementation.");
 }
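
Once the extension is rebuilt, the two new ops are callable from Python with the same out-parameter convention as silu_and_mul. A minimal usage sketch (mirroring the tests added below; shapes and dtype are illustrative):

import torch
from vllm import activation_ops

x = torch.randn(16, 4096, dtype=torch.half, device='cuda')  # [num_tokens, d]
out = torch.empty_like(x)
activation_ops.gelu_new(out, x)    # writes the GPT-2-style GELU result into `out`
activation_ops.gelu_fast(out, x)   # approximate (fast) GELU, same calling convention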

csrc/activation_kernels.cu

Lines changed: 68 additions & 0 deletions
@@ -46,3 +46,71 @@ void silu_and_mul(
         d);
     });
 }
+
+namespace vllm {
+
+// Element-wise activation kernel template.
+template<typename scalar_t, scalar_t (*ACT_FN)(const scalar_t&)>
+__global__ void activation_kernel(
+  scalar_t* __restrict__ out,          // [num_tokens, d]
+  const scalar_t* __restrict__ input,  // [num_tokens, d]
+  const int d) {
+  const int token_idx = blockIdx.x;
+  for (int idx = threadIdx.x; idx < d; idx += blockDim.x) {
+    const scalar_t x = __ldg(&input[token_idx * d + idx]);
+    out[token_idx * d + idx] = ACT_FN(x);
+  }
+}
+
+} // namespace vllm
+
+// Launch element-wise activation kernel.
+#define LAUNCH_ACTIVATION_KERNEL(KERNEL)                                                \
+  int num_tokens = input.size(0);                                                       \
+  int d = input.size(1);                                                                \
+  dim3 grid(num_tokens);                                                                \
+  dim3 block(std::min(d, 1024));                                                        \
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                         \
+  AT_DISPATCH_FLOATING_TYPES_AND2(                                                      \
+    at::ScalarType::Half,                                                               \
+    at::ScalarType::BFloat16,                                                           \
+    input.scalar_type(),                                                                \
+    "activation_kernel",                                                                \
+    [&] {                                                                               \
+      vllm::activation_kernel<scalar_t, KERNEL<scalar_t>><<<grid, block, 0, stream>>>(  \
+        out.data_ptr<scalar_t>(),                                                       \
+        input.data_ptr<scalar_t>(),                                                     \
+        d);                                                                             \
+    });
+
+namespace vllm {
+
+template<typename T>
+__device__ __forceinline__ T gelu_new_kernel(const T& x) {
+  const float x3 = (float) (x * x * x);
+  const T t = (T) tanhf((T) (0.79788456f * (float) (x + (T) (0.044715f * x3))));
+  return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+template<typename T>
+__device__ __forceinline__ T gelu_fast_kernel(const T& x) {
+  const float f = (float) x;
+  const T t = (T) tanhf(((T) (f * 0.79788456f)) * (((T) 1.0) + (T) (0.044715f * f) * x));
+  return ((T) 0.5) * x * (((T) 1.0) + t);
+}
+
+} // namespace vllm
+
+void gelu_new(
+  torch::Tensor& out,    // [num_tokens, d]
+  torch::Tensor& input)  // [num_tokens, d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel);
+}
+
+void gelu_fast(
+  torch::Tensor& out,    // [num_tokens, d]
+  torch::Tensor& input)  // [num_tokens, d]
+{
+  LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel);
+}
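
Both device functions implement the tanh-based GELU approximation; the constant 0.79788456 is sqrt(2/pi) to eight digits. A reference-only PyTorch sketch of the same math (not part of this commit, constants copied from the kernels above):

import torch

def gelu_new_ref(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * (x + 0.044715 * x.pow(3))))

def gelu_fast_ref(x: torch.Tensor) -> torch.Tensor:
    # Same approximation, with the cubic term factored as x * (0.044715 * x**2).
    return 0.5 * x * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))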

tests/kernels/test_activation.py

Lines changed: 43 additions & 1 deletion
@@ -1,6 +1,6 @@
 import torch
 import torch.nn.functional as F
-
+from transformers.activations import get_activation
 from vllm import activation_ops
 
 
@@ -28,3 +28,45 @@ def test_silu_and_mul() -> None:
         for d in [512, 4096, 5120, 13824]:
             print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
             run_silu_and_mul(num_tokens, d, dtype)
+
+
+@torch.inference_mode()
+def run_gelu_new(
+    num_tokens: int,
+    d: int,
+    dtype: torch.dtype,
+) -> None:
+    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
+    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    activation_ops.gelu_new(out, x)
+    ref_out = get_activation("gelu_new")(x)
+    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+
+
+def test_gelu_new() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for num_tokens in [7, 83, 2048]:
+            for d in [512, 4096, 5120, 13824]:
+                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
+                run_gelu_new(num_tokens, d, dtype)
+
+
+@torch.inference_mode()
+def run_gelu_fast(
+    num_tokens: int,
+    d: int,
+    dtype: torch.dtype,
+) -> None:
+    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
+    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    activation_ops.gelu_fast(out, x)
+    ref_out = get_activation("gelu_fast")(x)
+    assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
+
+
+def test_gelu_fast() -> None:
+    for dtype in [torch.half, torch.bfloat16, torch.float]:
+        for num_tokens in [7, 83, 2048]:
+            for d in [512, 4096, 5120, 13824]:
+                print(f'Testing dtype={dtype}, num_tokens={num_tokens}, d={d}')
+                run_gelu_fast(num_tokens, d, dtype)
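
These tests check the CUDA ops against Hugging Face's get_activation("gelu_new") and get_activation("gelu_fast") references across fp16, bf16, and fp32. Assuming a CUDA build of the vllm extensions and transformers installed, they should be runnable with, for example, pytest tests/kernels/test_activation.py -k gelu.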

vllm/model_executor/layers/activation.py

Lines changed: 37 additions & 17 deletions
@@ -4,23 +4,6 @@
 
 from vllm import activation_ops
 
-_ACTIVATION_REGISTRY = {
-    "gelu": nn.GELU(),
-    # NOTE: The following GELU functions may introduce small rounding errors.
-    "gelu_new": nn.GELU(approximate="tanh"),
-    "gelu_fast": nn.GELU(approximate="tanh"),
-    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
-    "relu": nn.ReLU(),
-}
-
-
-def get_act_fn(act_fn: str) -> nn.Module:
-    """Get an activation function by name."""
-    act_fn = act_fn.lower()
-    if act_fn in _ACTIVATION_REGISTRY:
-        return _ACTIVATION_REGISTRY[act_fn]
-    raise ValueError(f"Activation function {act_fn!r} is not supported.")
-
 
 class SiluAndMul(nn.Module):
     """An activation function for SwiGLU.
@@ -38,3 +21,40 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
         activation_ops.silu_and_mul(out, x)
         return out
+
+
+class NewGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        num_tokens = x.shape[0]
+        d = x.shape[1]
+        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        activation_ops.gelu_new(out, x)
+        return out
+
+
+class FastGELU(nn.Module):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        num_tokens = x.shape[0]
+        d = x.shape[1]
+        out = torch.empty(num_tokens, d, dtype=x.dtype, device=x.device)
+        activation_ops.gelu_fast(out, x)
+        return out
+
+
+_ACTIVATION_REGISTRY = {
+    "gelu": nn.GELU(),
+    "gelu_fast": FastGELU(),
+    "gelu_new": NewGELU(),
+    "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
+    "relu": nn.ReLU(),
+}
+
+
+def get_act_fn(act_fn: str) -> nn.Module:
+    """Get an activation function by name."""
+    act_fn = act_fn.lower()
+    if act_fn in _ACTIVATION_REGISTRY:
+        return _ACTIVATION_REGISTRY[act_fn]
+    raise ValueError(f"Activation function {act_fn!r} is not supported.")
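
With the registry now pointing "gelu_new" and "gelu_fast" at the kernel-backed modules, model code keeps resolving activations by name. A minimal usage sketch (shapes and dtype are illustrative):

import torch
from vllm.model_executor.layers.activation import get_act_fn

act = get_act_fn("gelu_new")   # returns NewGELU(), backed by activation_ops.gelu_new
x = torch.randn(8, 4096, dtype=torch.half, device='cuda')  # [num_tokens, d]
y = act(x)                     # same shape and dtype as x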
