
add causal-conv1d in Triton and integrate into vLLM with test code #18218

Open

wants to merge 17 commits into base: main

Conversation


@thoangtrvn thoangtrvn commented May 15, 2025

This PR adds Triton-based causal-conv1d kernels, making Mamba-based models in vLLM

  1. runnable on a fully Triton-only backend, and
  2. one step closer to compatibility with the vLLM v1 design, i.e. without splitting the batch into prefill-only and decode-only requests for CUDA-split processing.

Two kernels are implemented:

  • causal_conv1d_update_triton: outperforms the corresponding CUDA kernel in handling decode-only requests.
image [benchmark of the two kernels' runtime as the number of decode-only requests in a batch increases]
  • causal_conv1d_fn_triton: outperforms the CUDA kernel on batches of mixed prefill/decode requests, e.g. 27x faster in the microbenchmark below on the same batch of mixed prefill/decode requests.
image

It also performs better than the CUDA-split pathway that was merged in PR #17146.
image
[benchmark of the runtime for processing the same batch of mixed requests, first sending the whole batch to the single Triton kernel, and then using the CUDA-split pathway, where requests are first separated: prefill-only requests go to one kernel and decode-only requests go to a second kernel]

ALGORITHMIC CHOICE OF TRITON KERNEL: Unlike the CUDA kernel, which is parallelized in 2D (along the feature dimension and the batch size), the Triton kernel is parallelized in 3D, i.e. also along the sequence dimension. The Triton kernels also make no changes to the layout of the input data, which is contiguous along the feature dimension. Another key difference is that the Triton kernels expect the conv-state cache to be contiguous along the feature dimension, while the existing CUDA implementation expects the conv-state cache to be contiguous along the kernel-width (i.e. sequence-length) axis. As a result, the CUDA kernels are not compatible with this conv-state cache layout, which prevents them from efficiently processing decode-only or mixed prefill/decode batches.
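For illustration, a minimal sketch of the two conv-state cache layouts described above; the shapes and allocation strategy here are assumptions for the example, not the PR's actual code:

import torch

# Hypothetical sizes: number of cached sequences, number of channels
# (feature dimension), and conv-state length (kernel width - 1).
batch, dim, state_len = 4, 2048, 3

# Layout expected by the Triton kernels in this PR: contiguous along the
# feature dimension ("channel-last"), i.e. the same layout as the
# hidden_states_B_C input.
conv_state_triton = torch.zeros(batch, state_len, dim).transpose(1, 2)
print(conv_state_triton.shape)     # torch.Size([4, 2048, 3])
print(conv_state_triton.stride())  # (6144, 1, 2048) -> feature axis has stride 1

# Layout expected by the existing CUDA implementation: contiguous along the
# kernel-width (sequence-length) axis.
conv_state_cuda = torch.zeros(batch, dim, state_len)
print(conv_state_cuda.stride())    # (6144, 3, 1) -> width axis has stride 1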

Some other overhead-reducing improvements are also incorporated.
Even though the binary code generated by Triton is faster, kernel-launch overhead is a known issue and needs further optimization for end-to-end Triton-only Mamba models in vLLM to be performant. Here, we reduce this overhead by using metadata that can be reused across layers.
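As a rough sketch of this reuse-across-layers idea (field and function names here are illustrative assumptions, not the PR's actual API):

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class Conv1dMetadata:
    # Batch-level values computed once per forward pass and shared by every
    # Mamba layer, instead of being recomputed layer by layer.
    cu_seqlen: Optional[torch.Tensor] = None  # cumulative sequence lengths
    dim: Optional[int] = None                 # feature dimension
    stride_x_dim: Optional[int] = None        # input strides: each .stride()
    stride_x_token: Optional[int] = None      # query is an extra Torch call

def build_conv1d_metadata(x: torch.Tensor,
                          cu_seqlen: torch.Tensor) -> Conv1dMetadata:
    # Compute once; each layer then reads these fields instead of issuing
    # its own .shape/.stride() calls.
    return Conv1dMetadata(cu_seqlen=cu_seqlen,
                          dim=x.shape[0],
                          stride_x_dim=x.stride(0),
                          stride_x_token=x.stride(1))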

In our benchmark on the ShareGPT dataset, which has short input prompts (a few hundred tokens):

  • default setting (short generations, i.e. 256 tokens): the CUDA-backed Bamba (ibm-ai-platform/Bamba-9B) is still faster; the Triton version is 10% slower in total token throughput, but only 2% slower in output token throughput and 2% slower in TTFT.

  • generating 1024 tokens: the Triton-backed Bamba is now faster, with 5% higher token throughput and 11% faster TTFT. The benefit of the faster Triton kernels now exceeds the overall cost of the Triton launch overhead.

image

At longer context lengths and/or larger numbers of generated tokens, Triton-only Mamba-based models are expected to outperform the CUDA-split approach. However, this PR keeps the existing CUDA pathway as the default until the Triton pathway is adopted by the vLLM maintainers. Currently, the Triton code is added as an optional alternative to the CUDA-split pathway, enabled by setting the VLLM_USE_TRITON_CONV1D environment variable to 1. As noted above, this brings vLLM one step closer to the v1 design, since the batch no longer needs to be split into prefill-only and decode-only requests.

Test code is also added.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

tmhoangt added 7 commits May 15, 2025 17:52
Signed-off-by: Tuan M. Hoang-Trong <[email protected]>
Signed-off-by: Tuan M. Hoang-Trong <[email protected]>
Signed-off-by: Tuan M. Hoang-Trong <[email protected]>
Signed-off-by: Tuan M. Hoang-Trong <[email protected]>
Signed-off-by: Tuan M. Hoang-Trong <[email protected]>
Signed-off-by: Tuan M. Hoang-Trong <[email protected]>
Signed-off-by: Tuan M. Hoang-Trong <[email protected]>

mergify bot commented May 27, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @thoangtrvn.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label May 27, 2025
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Could you report some lm_eval scores running gsm8k, as well as make sure it runs correctly without --enforce-eager?

Collaborator

It looks like the changes to bamba.py, granitemoehybrid.py, mamba2.py and zamba2.py are pretty spurious. Could you revert those?

Author

@thoangtrvn thoangtrvn Jun 2, 2025

In Mamba2's metadata, there are settings that are updated once (at model initialization) and settings that are updated for every input (while staying the same across layers).
Adding self.mamba2_metadata provides a way to reuse the updated-once data. If you don't like this level of optimization, please let me know @tlrmchlsmth; it is optional. I can revert the changes to the other models and keep the change only in bamba.py.

Comment on lines 39 to 40
path = os.environ.get("VLLM_USE_TRITON_CONV1D", None)
if path is not None:
Collaborator

If we add an environment variable, this should be handled in vllm/envs.py.

However, we should generally avoid adding environment variables where possible. In which cases should we use the Triton conv1d kernels vs. the CUDA ones?

Author

@thoangtrvn thoangtrvn Jun 2, 2025

So far, testing shows that CUDA's lower overhead makes it better at short context lengths (~300-400 tokens, with short output lengths); otherwise the Triton kernel is better. The end-to-end slowdown does not come from the kernel itself, but from the overall launch overhead. I just want to maintain both pathways until a final decision is made, based on testing beyond what I have done.

I'll provide more details of the test cases in the PR threads.

While in theory some Triton launch overhead can be reduced using the Triton JIT cache mechanism, that is not tested here.

[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
Collaborator

Should we test more than just seqlen 1?

Author

ah yes, I can add more to the test code
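For example, the parametrization could be broadened along these lines; the seqlen values and the test body below are placeholders, not the PR's actual test:

import pytest
import torch

@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1, 4, 128])  # cover more than decode-only
def test_causal_conv1d_update_param_grid(itype, silu_activation, has_bias,
                                         seqlen):
    # Placeholder body: the real test would compare the Triton kernel output
    # against a reference implementation for each parameter combination.
    assert seqlen >= 1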

Comment on lines 143 to 148
"""Create a weight loader for mamba v2. This ensures that the projections
are correctly sharded so that they can be split into x, B, C. It also
ensures that all the groups corresponding to a head shard is placed
"""Create a weight loader for mamba v2. This ensures that the projections
are correctly sharded so that they can be split into x, B, C. It also
ensures the the all the groups corresponding to a head shard is placed
Collaborator

nit: please revert "the the" back to "that"

Comment on lines +18 to +23
batch_ptr = torch.full(
(MAX_NUM_PROGRAMS, ), PAD_SLOT_ID, dtype=torch.int32,
device='cpu') # tracking which seq-idx the Triton program is handling
token_chunk_offset_ptr = torch.full(
(MAX_NUM_PROGRAMS, ), PAD_SLOT_ID, dtype=torch.int32, device='cpu'
) # tracking BLOCK_M-based index in the sequence the Triton program is handling
Collaborator

Are these stateful global variables? Would it be better for these to go in the mamba2_metadata instead?

Author

@thoangtrvn thoangtrvn Jun 2, 2025

To explain: the parallelism of the causal-conv1d kernel (prefill stage) is 3D, over the batch, feature, and seqlen dimensions. This means each Triton program can handle a group of tokens within a sequence. The token range (start/stop) that each program handles within its sequence is tracked by these two tensors. In vLLM, the tensors are generated as part of the metadata construct, so the end-to-end runtime does not use these module-level tensors.
However, for kernel-level runs, e.g. microbenchmarking or kernel testing, they need to be provided. Here, we have three choices:

  • create a metadata object, just like in the vLLM end-to-end setting;
  • when metadata is not provided, declare the tensors inside the function that uses them (which creates them each time the kernel is invoked); or
  • declare them at module level (created once at module loading, saving some tensor-allocation overhead).

Please let me know how you think this should be revised, @tlrmchlsmth.
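For reference, a minimal sketch of how each Triton program could be mapped to a (sequence index, BLOCK_M chunk) pair through these two tensors; the constants and the helper name are illustrative assumptions, not the PR's actual code:

import torch

PAD_SLOT_ID = -1        # assumed sentinel for unused program slots
BLOCK_M = 8             # assumed number of tokens handled per program
MAX_NUM_PROGRAMS = 1024

def build_program_map(seqlens):
    # Map program id -> (sequence index, BLOCK_M chunk within that sequence)
    # for the 3D prefill grid described above.
    batch_ptr = torch.full((MAX_NUM_PROGRAMS, ), PAD_SLOT_ID,
                           dtype=torch.int32)
    token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ), PAD_SLOT_ID,
                                        dtype=torch.int32)
    pid = 0
    for seq_idx, seqlen in enumerate(seqlens):
        for chunk in range((seqlen + BLOCK_M - 1) // BLOCK_M):
            batch_ptr[pid] = seq_idx             # which sequence
            token_chunk_offset_ptr[pid] = chunk  # which BLOCK_M chunk
            pid += 1
    return batch_ptr, token_chunk_offset_ptr

# Example: one prefill sequence of 20 tokens plus one decode token.
bp, offs = build_program_map([20, 1])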

Comment on lines +27 to +50
num_cache_lines: Optional[int] = None
stride_istate_seq: Optional[int] = None
stride_istate_dim: Optional[int] = None
stride_istate_token: Optional[int] = None
seqlens: Optional[np.ndarray] = None
padded_batch: Optional[int] = None
nums_dict: Optional[dict] = None
is_channel_last: bool = True
stride_w_dim: Optional[int] = None
stride_w_width: Optional[int] = None
width: Optional[int] = None
np2_statelen: Optional[int] = None
stride_x_seq: Optional[int] = 0
stride_x_dim: Optional[int] = None
stride_x_token: Optional[int] = None
dim: Optional[int] = None
cu_seqlen: Optional[int] = None
out: Optional[torch.Tensor] = None
stride_o_seq: Optional[int] = 0
stride_o_dim: Optional[int] = None
stride_o_token: Optional[int] = None
MAX_NUM_PROGRAMS: int = 1024
batch_ptr: Optional[torch.tensor] = None
token_chunk_offset_ptr: Optional[torch.tensor] = None
Collaborator

There is a lot of stuff in here, and it's not really clear what most of it is for. At first glance it seems like most of this should be accessed on the fly instead of stored in the metadata here. Could you take a stab at cleaning this up?

Author

The information here is reused across Mamba layers; even a .stride() call triggers Torch calls, which adds unnecessary overhead. I can add a description as needed. Please let me know what you think @tlrmchlsmth.

Collaborator

I think most of this should be removed. For CPU overheads in decode we can rely on CUDA graphs, and for prefill they are amortized.

Comment on lines 110 to 115
# NOTE: currently it is assumed prefill requests come before decode requests -> we can use ':num_prefills' slicing
# TODO: maybe revert back to the original code (below) if above no longer holds
# has_initial_states = attn_metadata.context_lens_tensor > 0
# zero_init_indices = mamba_cache_params.state_indices_tensor[~has_initial_states]
# mamba_cache_params.ssm_state[zero_init_indices] = 0
# initial_states = mamba_cache_params.ssm_state[mamba_cache_params.state_indices_tensor]
Collaborator

We can rely on batch reordering and require that it be used for this implementation.

Author

@thoangtrvn thoangtrvn Jun 2, 2025

The change from a previous PR (CUDA-split) implies this assumption; this PR does not make this assumption, which makes it more suitable for the vLLM v1 design. The comment I added here is to clarify the code path from the previous PR. I can remove the comment as needed. Please let me know @tlrmchlsmth.

Collaborator

Yes, please remove the comment.

We can rely on batch reordering even in vLLM V1, so this is a non-issue.

Comment on lines 388 to 394
def is_conv_in_Triton(self):
import os
path = os.environ.get("VLLM_USE_TRITON_CONV1D", None)
if path is not None:
print("mamba_mixer2 - VLLM_USE_TRITON_CONV1D")
return True
return False
Collaborator

A few things here:

  1. Remove the print before landing.
  2. Env variables should be defined in vllm/envs.py (see the sketch after this list).
  3. The function would be better named use_triton_causal_conv_1d (more descriptive & proper capitalization).
  4. We should work hard to avoid a proliferation of environment variables. Could you come up with a reliable heuristic to choose between this Triton implementation and the CUDA kernel, instead of adding an env var that exposes complexity to the end user?
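For illustration, items 1-3 might look like the sketch below; the exact vllm/envs.py conventions are an assumption here, not the PR's final code:

import os

# Sketch: read the flag once, e.g. alongside the other flags in vllm/envs.py.
VLLM_USE_TRITON_CONV1D: bool = os.getenv("VLLM_USE_TRITON_CONV1D",
                                         "0") in ("1", "true", "True")

def use_triton_causal_conv_1d() -> bool:
    # No print statement; the decision is a pure read of the centralized flag.
    return VLLM_USE_TRITON_CONV1D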

Author

@thoangtrvn thoangtrvn Jun 4, 2025

I updated the code accordingly. The Triton version uses a conv_state cache layout that matches the layout of the input tensor hidden_states_B_C, i.e. contiguous along the feature dimension. The CUDA version in vLLM uses a different conv_state cache layout that is contiguous along the sequence dimension. So there is no good way to switch from one to the other within the same session; it is a choice to be made based on the expected workload. The benefit of Mamba-based models is at long context lengths and long output responses, where we show the Triton version takes over. Also, the Triton kernel follows the vLLM v1 design, allowing prefill and decode requests to be mixed.

RESULT: long-context latency measurement (compared with Llama)

  • Reference: meta-llama/Llama-3.1-8B-Instruct
python benchmarks/benchmark_latency.py --model /net/storage149/autofs/css22/nmg/models/hf/meta-llama/Llama-3.1-8B-Instruct/main --input-len=131072 --output-len=1 --batch-size=1 --max_num_batched_tokens=2048

Avg latency: 11.214325266804856 seconds
10% percentile latency: 11.202042526123114 seconds
25% percentile latency: 11.206939334078925 seconds
50% percentile latency: 11.212064623483457 seconds
75% percentile latency: 11.220630767958937 seconds
90% percentile latency: 11.2278619370074 seconds
99% percentile latency: 11.24528882814222 seconds
  • Main branch
# Default (max_num_batched_tokens=2048)
python benchmarks/benchmark_latency.py --model ibm-ai-platform/Bamba-9B-v2 --input-len=131072 --output-len=1 --batch-size=1
Avg latency: 6.231618080474436 seconds
10% percentile latency: 6.204561746446416 seconds
25% percentile latency: 6.216710253240308 seconds
50% percentile latency: 6.219352831016295 seconds
75% percentile latency: 6.223808606999228 seconds
90% percentile latency: 6.227424801979214 seconds
99% percentile latency: 6.519547982601217 seconds
  • Current PR:
export VLLM_USE_TRITON_CONV1D="1"
python benchmarks/benchmark_latency.py --model ibm-ai-platform/Bamba-9B-v2 --input-len=131072 --output-len=1 --batch-size=1

Avg latency: 5.757278195097267 seconds    
10% percentile latency: 5.734804188809358 seconds 
25% percentile latency: 5.739403567742556 seconds  
50% percentile latency: 5.743007940007374 seconds  
75% percentile latency: 5.748229099262971 seconds   
90% percentile latency: 5.751799210254103 seconds 
99% percentile latency: 6.068630096325651 seconds 

thoangtrvn and others added 3 commits June 2, 2025 12:05
@thoangtrvn
Author

GSM8K RESULT


#ibm-ai-platform/Bamba-9B
# (current) CUDA-SPLIT code
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2335|±  |0.0117|
|     |       |strict-match    |     5|exact_match|↑  |0.3442|±  |0.0131|
# PR code
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.2456|±  |0.0119|
|     |       |strict-match    |     5|exact_match|↑  |0.3495|±  |0.0131|

#Zyphra/Zamba2-2.7B
# (current) CUDA-SPLIT code
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5330|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.5466|±  |0.0137|
# PR code
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5330|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.5466|±  |0.0137|



#mistralai/Mamba-Codestral-7B-v0.1
# (current) CUDA-SPLIT code
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4647|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.4549|±  |0.0137|
# PR code
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.4655|±  |0.0137|
|     |       |strict-match    |     5|exact_match|↑  |0.4526|±  |0.0137|

COMMANDS TO RUN:

echo 'ibm-ai-platform/Bamba-9B'
lm_eval --model vllm     --model_args pretrained=ibm-ai-platform/Bamba-9B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto  --cache_requests true --tasks gsm8k


echo "DONE RUN (CUDA-SPLIT)"

export VLLM_USE_TRITON_CONV1D="1"
lm_eval --model vllm     --model_args pretrained=ibm-ai-platform/Bamba-9B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto  --cache_requests true --tasks gsm8k


echo "DONE RUN (PR)"

echo 'Zyphra/Zamba2-2.7B'
unset VLLM_USE_TRITON_CONV1D  # ensure the baseline run uses the CUDA-split pathway
lm_eval --model vllm     --model_args pretrained=Zyphra/Zamba2-2.7B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k
echo "DONE RUN (CUDA-SPLIT)"

export VLLM_USE_TRITON_CONV1D="1"
lm_eval --model vllm     --model_args pretrained=Zyphra/Zamba2-2.7B,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k
echo "DONE RUN (PR)"

echo 'Mamba-Codestral-7B-v0.1'
unset VLLM_USE_TRITON_CONV1D  # ensure the baseline run uses the CUDA-split pathway
lm_eval --model vllm     --model_args pretrained=mistralai/Mamba-Codestral-7B-v0.1,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k
echo "DONE RUN (CUDA-SPLIT)"

export VLLM_USE_TRITON_CONV1D="1"
lm_eval --model vllm     --model_args pretrained=mistralai/Mamba-Codestral-7B-v0.1,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.9 --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k
echo "DONE RUN (PR)"

@mergify mergify bot removed the needs-rebase label Jun 4, 2025