add causal-conv1d in Triton and integrate into vLLM with test code #18218

Open · wants to merge 32 commits into base: main

32 commits
ad83738
add causal-conv1d in Triton and integrate into vLLM with test code
tmhoangt May 15, 2025
f4c56bf
add causal-conv1d in Triton and integrate into vLLM with test code
tmhoangt May 15, 2025
dfa7159
resolve merge conflict
tmhoangt May 15, 2025
61d7ed9
fix a bug when migrating code to vLLM
tmhoangt May 15, 2025
8882cef
fix a bug when migrating code to vLLM
tmhoangt May 15, 2025
775e561
refactor for code style
tmhoangt May 16, 2025
939a823
refactor for code style
tmhoangt May 16, 2025
29b7941
refactor for code style
tmhoangt May 16, 2025
7bfe0e8
refactor for code style
tmhoangt May 16, 2025
52d601c
refactor for code style
tmhoangt May 16, 2025
081a8be
Update tests/kernels/mamba/test_causal_conv1d.py
thoangtrvn Jun 2, 2025
9eb1cc3
update test code to cover more use-cases
tmhoangt Jun 2, 2025
091b31e
refactor code based on feedback
tmhoangt Jun 4, 2025
bfabaae
refactor code based on feedback
tmhoangt Jun 4, 2025
da660f0
refactor code based on feedback
tmhoangt Jun 4, 2025
7af7f58
refactor code based on feedback
tmhoangt Jun 4, 2025
10e332c
Merge branch 'main' into pr_conv1d_triton
thoangtrvn Jun 4, 2025
ecb3a2c
refactor code based on feedback
tmhoangt Jun 4, 2025
107911a
refactor code based on feedback
tmhoangt Jun 4, 2025
bfc2f28
refactor code to fix mypy codecheck
tmhoangt Jun 4, 2025
ef21b3d
refactor code to fix mypy codecheck
tmhoangt Jun 4, 2025
400e669
Merge branch 'pr_conv1d_triton' of github.com:thoangtrvn/vllm into pr…
tmhoangt Jun 4, 2025
f0be762
refactor code to fix mypy codecheck
tmhoangt Jun 4, 2025
4cfb12d
revert code change based on feedback
tmhoangt Jun 5, 2025
64ee33d
revert code change based on feedback
tmhoangt Jun 5, 2025
19586c5
revert code change based on feedback
tmhoangt Jun 5, 2025
e3192e8
migrate code change based on feedback
tmhoangt Jun 5, 2025
8aad208
migrate code change based on feedback
tmhoangt Jun 5, 2025
a0d2170
revert code change based on feedback
tmhoangt Jun 5, 2025
4d1bb63
revert code change based on feedback
tmhoangt Jun 5, 2025
679eb1c
migrate code change based on feedback
tmhoangt Jun 5, 2025
c782f25
fix merge conflict from upstream/main
tmhoangt Jun 5, 2025
221 changes: 220 additions & 1 deletion tests/kernels/mamba/test_causal_conv1d.py
@@ -6,12 +6,14 @@
import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops  # noqa: F401
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_update)
    causal_conv1d_fn, causal_conv1d_fn_triton, causal_conv1d_update,
    causal_conv1d_update_triton)
from vllm.platforms import current_platform

@@ -436,3 +438,220 @@ def test_causal_conv1d_varlen(with_padding, dim, seqlen, width, has_bias,
    causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(),
                             padded_state_indices, has_initial_states,
                             final_states, activation)


@pytest.mark.parametrize("itype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 3, 5])
@pytest.mark.parametrize("width", [2, 3, 4])
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
# tests correctness in case subset of the sequences are padded
@pytest.mark.parametrize("with_padding", [True, False])
@pytest.mark.parametrize("batch_size", [3])
def test_causal_conv1d_update_with_batch_gather_vllm(batch_size, with_padding,
dim, width, seqlen,
has_bias, silu_activation,
itype):
device = "cuda"
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2

# set seed
current_platform.seed_everything(0)

padding = 5 if with_padding else 0
padded_batch_size = batch_size + padding
# total_entries = number of cache line
total_entries = 10 * batch_size

channel_last = True
if not channel_last:
x = torch.randn(padded_batch_size,
dim,
seqlen,
device=device,
dtype=itype)
else:
# x will be (batch, dim, seqlen) with contiguous along dim-axis
x = torch.randn(padded_batch_size,
seqlen,
dim,
device=device,
dtype=itype).transpose(1, 2)

x_ref = x.clone()

conv_state_indices = torch.randperm(total_entries)[:batch_size].to(
dtype=torch.int32, device=device)
unused_states_bool = torch.ones(total_entries,
dtype=torch.bool,
device=device)
unused_states_bool[conv_state_indices] = False
padded_state_indices = torch.concat([
conv_state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device)
],
dim=0)

if not channel_last:
conv_state = torch.randn(total_entries,
dim,
width - 1,
device=device,
dtype=itype)
else:
# conv_state will be (cache_lines, dim, state_len)
# with contiguous along dim-axis
conv_state = torch.randn(total_entries,
width - 1,
dim,
device=device,
dtype=itype).transpose(1, 2)

conv_state_for_padding_test = conv_state.clone()

weight = torch.randn(dim, width, device=device, dtype=itype)
bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
activation = None if not silu_activation else "silu"

out = causal_conv1d_update_triton(x,
conv_state,
weight,
bias,
activation=activation,
conv_state_indices=padded_state_indices,
pad_slot_id=PAD_SLOT_ID)
out_ref = causal_conv1d_update_ref(x_ref[:batch_size],
conv_state_ref,
weight,
bias,
activation=activation)

assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
assert torch.equal(conv_state[unused_states_bool],
conv_state_for_padding_test[unused_states_bool])
assert torch.allclose(out[:batch_size], out_ref, rtol=rtol, atol=atol)


@pytest.mark.parametrize("itype", [torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [True])
@pytest.mark.parametrize("has_bias", [True])
@pytest.mark.parametrize("width", [4])
@pytest.mark.parametrize('seqlen', [8, 16, 784, 1024, 2048, 2049, 4096])
@pytest.mark.parametrize('dim', [64, 4096])
@pytest.mark.parametrize('with_padding', [True, False])
@pytest.mark.parametrize('batch', [4, 8, 10])
def test_causal_conv1d_varlen_vllm(batch, with_padding, dim, seqlen, width,
has_bias, silu_activation, itype):
device = "cuda"
torch.cuda.empty_cache()
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
if itype == torch.bfloat16:
rtol, atol = 1e-2, 5e-2
# set seed
current_platform.seed_everything(0)
seqlens = []
batch_size = batch
padding = 3 if with_padding else 0
padded_batch_size = batch_size + padding
nsplits = padded_batch_size - 1

eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values

seqlens.append(
torch.diff(
torch.cat(
[torch.tensor([-1]), eos_pos,
torch.tensor([seqlen - 1])])).tolist())
assert sum(seqlens[-1]) == seqlen
assert all(s > 0 for s in seqlens[-1])

total_entries = batch_size * 10
cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32)
cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum],
dim=0)
channel_last = True
if not channel_last:
x = torch.randn(1, 4096 + dim + 64, seqlen, device=device,
dtype=itype)[:, 4096:4096 + dim, :]
else:
x = rearrange(
torch.randn(1, seqlen, 4096 + dim + 64, device=device,
dtype=itype), "b s d -> b d s")[:, 4096:4096 + dim, :]

weight = torch.randn(dim, width, device=device, dtype=itype)

bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None
x_ref = x.clone()
weight_ref = weight.clone()
bias_ref = bias.clone() if bias is not None else None
activation = None if not silu_activation else "silu"
if not channel_last:
final_states = torch.randn(total_entries,
dim,
width - 1,
device=x.device,
dtype=x.dtype)
else:
final_states = torch.randn(total_entries,
width - 1,
dim,
device=x.device,
dtype=x.dtype).transpose(1, 2)
final_states_ref = final_states.clone()
has_initial_states = torch.randint(0,
2, (cumsum.shape[0] - 1, ),
dtype=torch.bool,
device=x.device)
state_indices = torch.randperm(total_entries,
dtype=torch.int32,
device=x.device)[:batch_size]
padded_state_indices = torch.concat([
state_indices,
torch.as_tensor(
[PAD_SLOT_ID] * padding, dtype=torch.int32, device=device),
],
dim=-1)
out = causal_conv1d_fn_triton(x.squeeze(0),
weight,
bias=bias,
conv_states=final_states,
query_start_loc=cumsum.cuda(),
cache_indices=padded_state_indices,
has_initial_states=has_initial_states,
activation=activation,
pad_slot_id=PAD_SLOT_ID)

out_ref = []
out_ref_b = []

splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)]
for i in range(len(seqlens[0])):
x_s = [v[i].unsqueeze(0) for v in splits][0]
if padded_state_indices[i] == PAD_SLOT_ID:
continue
out_ref_b.append(
causal_conv1d_ref(
x_s,
weight_ref,
bias_ref,
activation=activation,
return_final_states=True,
final_states_out=final_states_ref[
padded_state_indices[i]].unsqueeze(0),
initial_states=final_states_ref[padded_state_indices[i]].
unsqueeze(0) if has_initial_states[i] else None))
out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2))
out_ref_tensor = torch.cat(out_ref, dim=0)

assert torch.allclose(final_states[state_indices],
final_states_ref[state_indices],
rtol=rtol,
atol=atol)
unpadded_out = out[:, :out_ref_tensor.shape[-1]]
assert torch.allclose(unpadded_out, out_ref_tensor, rtol=rtol, atol=atol)
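
Note: the reference helpers causal_conv1d_update_ref and causal_conv1d_ref used by the new tests are defined earlier in tests/kernels/mamba/test_causal_conv1d.py and are not part of this diff. As rough orientation only, the decode-path semantics that the Triton update kernel is checked against look like the sketch below; the function name and the simplified shape/dtype handling are ours, not the PR's.

import torch
import torch.nn.functional as F


def causal_conv1d_update_sketch(x, conv_state, weight, bias=None,
                                activation=None):
    # x:          (batch, dim, seqlen) new token activations
    # conv_state: (batch, dim, width - 1) rolling history, updated in place
    # weight:     (dim, width) one causal filter per channel (depthwise conv)
    batch, dim, seqlen = x.shape
    width = weight.shape[1]
    # Prepend the cached history so the convolution stays causal across calls.
    x_full = torch.cat([conv_state.to(x.dtype), x], dim=-1)
    # Roll the state forward: keep the last (width - 1) inputs for next time.
    conv_state.copy_(x_full[:, :, -(width - 1):])
    # Depthwise convolution; keep only the outputs for the new tokens.
    out = F.conv1d(x_full, weight.unsqueeze(1), bias, groups=dim)[..., -seqlen:]
    return F.silu(out) if activation == "silu" else out

The vLLM kernels additionally take a cache holding states for many sequences plus conv_state_indices and pad_slot_id, so each request gathers and updates its own cache line while padded slots are skipped; that is exactly what the padding-aware assertions in the tests above exercise.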
5 changes: 5 additions & 0 deletions vllm/envs.py
@@ -122,6 +122,7 @@
    VLLM_ALL2ALL_BACKEND: str = "naive"
    VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
    VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
    VLLM_USE_TRITON_CONV1D: bool = False
    VLLM_SLEEP_WHEN_IDLE: bool = False


@@ -832,6 +833,10 @@ def get_vllm_port() -> Optional[int]:
    "VLLM_ALL2ALL_BACKEND":
    lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"),

    # use Triton implementation of causal-conv1d
    "VLLM_USE_TRITON_CONV1D":
    lambda: os.getenv("VLLM_USE_TRITON_CONV1D", "0") == "1",

    # Control the maximum number of tokens per expert supported by the
    # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
    # the blockscale tensor of activations NVFP4 Quantization.
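
For orientation, a minimal sketch of how the new flag could be consumed on the Python side to choose between the existing kernel and the Triton one added in this PR. The helper name select_causal_conv1d_fn is hypothetical and not part of this change; the envs attribute and the two entry points are the ones introduced or used in the test diff above.

import vllm.envs as envs
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
    causal_conv1d_fn, causal_conv1d_fn_triton)


def select_causal_conv1d_fn():
    # VLLM_USE_TRITON_CONV1D defaults to "0" (off); exporting
    # VLLM_USE_TRITON_CONV1D=1 opts into the Triton implementation.
    if envs.VLLM_USE_TRITON_CONV1D:
        return causal_conv1d_fn_triton
    return causal_conv1d_fn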