add causal-conv1d in Triton and integrate into vLLM with test code #18218
base: main
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
This pull request has merge conflicts that must be resolved before it can be merged.
Could you report some lm_eval scores running gsm8k, as well as make sure it runs correctly without --enforce-eager?
vllm/model_executor/models/bamba.py
It looks like the changes to bamba.py, granitemoehybrid.py, mamba2.py and zamba2.py are pretty spurious. Could you revert those?
In the Mamba2 metadata, there are settings that are updated once (at model initialization) and settings that are updated for every input (while staying the same across layers). Adding self.mamba2_metadata provides a way to reuse the updated-once data. If you don't like this level of optimization, please let me know @tlrmchlsmth. This is optional; I can revert the changes to the other models and keep the change only in bamba.py.
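For illustration only, a minimal sketch of the reuse pattern described above; the class and field names here are hypothetical, not the actual vLLM API:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Mamba2MetadataSketch:
    # Set once at model initialization and never touched again.
    conv_kernel_width: Optional[int] = None
    head_dim: Optional[int] = None
    # Refreshed once per forward pass, identical for every Mamba layer.
    cu_seqlens: Optional[torch.Tensor] = None
    has_initial_states: Optional[torch.Tensor] = None


class HybridModelSketch(torch.nn.Module):
    def __init__(self, conv_kernel_width: int, head_dim: int):
        super().__init__()
        # Created once, so the "updated-once" fields are filled a single time.
        self.mamba2_metadata = Mamba2MetadataSketch(
            conv_kernel_width=conv_kernel_width, head_dim=head_dim)

    def prepare_metadata(self, cu_seqlens: torch.Tensor,
                         has_initial_states: torch.Tensor):
        # Only the per-step fields are refreshed; all layers then share this
        # one object instead of rebuilding the metadata per layer.
        self.mamba2_metadata.cu_seqlens = cu_seqlens
        self.mamba2_metadata.has_initial_states = has_initial_states
        return self.mamba2_metadata
```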
path = os.environ.get("VLLM_USE_TRITON_CONV1D", None)
if path is not None:
If we add an environment variable, this should be handled in vllm/envs.py.
However, generally we should avoid adding an environment variable where possible. In which cases should we be using the Triton conv1d kernels vs CUDA?
So far, testing shows that CUDA's lower launch overhead makes it better at short context lengths (~300-400 tokens with short output lengths); otherwise, the Triton kernel is better. The end-to-end slowdown does not come from the kernel itself but from the overall launch overhead. I just want to maintain the two pathways until a final decision is made, based on testing beyond what I have done.
I'll provide more details of the test cases in the PR threads.
In theory, some of the Triton launch overhead could be reduced with Triton's JIT cache mechanism, but that is not tested here.
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("silu_activation", [False, True])
@pytest.mark.parametrize("has_bias", [False, True])
@pytest.mark.parametrize("seqlen", [1])
Should we test more than just seqlen 1?
Ah yes, I can add more to the test code.
"""Create a weight loader for mamba v2. This ensures that the projections | ||
are correctly sharded so that they can be split into x, B, C. It also | ||
ensures that all the groups corresponding to a head shard is placed | ||
"""Create a weight loader for mamba v2. This ensures that the projections | ||
are correctly sharded so that they can be split into x, B, C. It also | ||
ensures the the all the groups corresponding to a head shard is placed |
nit: please revert "the the" back to "that"
batch_ptr = torch.full(
    (MAX_NUM_PROGRAMS, ), PAD_SLOT_ID, dtype=torch.int32,
    device='cpu')  # tracking which seq-idx the Triton program is handling
token_chunk_offset_ptr = torch.full(
    (MAX_NUM_PROGRAMS, ), PAD_SLOT_ID, dtype=torch.int32, device='cpu'
)  # tracking BLOCK_M-based index in the sequence the Triton program is handling
Are these stateful global variables? Would it be better for these to go in the mamba2_metadata instead?
Explanation: the parallelism of the causal-conv1d kernel (prefill stage) is 3D: over the batch, feature, and sequence-length dimensions. This means each Triton program can handle a group of tokens within a sequence. The token range (start/stop) each program handles in each sequence is tracked by these two tensors. In vLLM, the tensors are generated as part of the metadata construction, so the end-to-end runtime does not use these module-level tensors.
However, for kernel-level runs, e.g. microbenchmarking or kernel testing, they need to be provided. Here we have three choices:
- create a metadata object just like in the vLLM e2e setting;
- when metadata is not provided, declare the tensors inside the function that uses them (which creates them each time the kernel is invoked);
- or declare them at module level (created once at module loading, which saves some tensor-allocation overhead).
Please let me know how you think it should be revised @tlrmchlsmth.
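For illustration, a minimal sketch of the fallback pattern behind the last two options; the constants and helper name are illustrative, not the exact kernel code:

```python
import torch

PAD_SLOT_ID = -1          # placeholder value; the real constant lives in vLLM
MAX_NUM_PROGRAMS = 1024   # upper bound on concurrently tracked programs

# Module-level scratch tensors: allocated once at import time so that
# kernel-only runs (tests, microbenchmarks) avoid a per-call allocation.
_batch_ptr = torch.full((MAX_NUM_PROGRAMS, ), PAD_SLOT_ID, dtype=torch.int32)
_token_chunk_offset_ptr = torch.full((MAX_NUM_PROGRAMS, ), PAD_SLOT_ID,
                                     dtype=torch.int32)


def _get_program_scratch(metadata=None):
    """Prefer the tensors carried in the metadata (vLLM e2e path); otherwise
    fall back to the module-level buffers (kernel-level testing path)."""
    if metadata is not None and metadata.batch_ptr is not None:
        return metadata.batch_ptr, metadata.token_chunk_offset_ptr
    return _batch_ptr, _token_chunk_offset_ptr
```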
num_cache_lines: Optional[int] = None
stride_istate_seq: Optional[int] = None
stride_istate_dim: Optional[int] = None
stride_istate_token: Optional[int] = None
seqlens: Optional[np.ndarray] = None
padded_batch: Optional[int] = None
nums_dict: Optional[dict] = None
is_channel_last: bool = True
stride_w_dim: Optional[int] = None
stride_w_width: Optional[int] = None
width: Optional[int] = None
np2_statelen: Optional[int] = None
stride_x_seq: Optional[int] = 0
stride_x_dim: Optional[int] = None
stride_x_token: Optional[int] = None
dim: Optional[int] = None
cu_seqlen: Optional[int] = None
out: Optional[torch.Tensor] = None
stride_o_seq: Optional[int] = 0
stride_o_dim: Optional[int] = None
stride_o_token: Optional[int] = None
MAX_NUM_PROGRAMS: int = 1024
batch_ptr: Optional[torch.tensor] = None
token_chunk_offset_ptr: Optional[torch.tensor] = None
There is a lot of stuff in here, and it's not really clear what most of it is for. At first glance it seems like most of this should be accessed on the fly instead of stored in the metadata here. Could you take a stab at cleaning this up?
The information here is reused across Mamba layers; even a stride() call triggers Torch calls, which adds unnecessary overhead. I can add a description as needed. Please let me know what you think @tlrmchlsmth.
I think most of this should be removed. For CPU overheads in decode, we can rely on CUDA graphs, and for prefill they are amortized.
# NOTE: currently it is assumed prefill requests come before decode requests -> we can use ':num_prefills' slicing
# TODO: maybe revert back to the original code (below) if above no longer holds
# has_initial_states = attn_metadata.context_lens_tensor > 0
# zero_init_indices = mamba_cache_params.state_indices_tensor[~has_initial_states]
# mamba_cache_params.ssm_state[zero_init_indices] = 0
# initial_states = mamba_cache_params.ssm_state[mamba_cache_params.state_indices_tensor]
We can rely on batch reordering and require that it be used for this implementation.
The change from a previous PR (cuda-split) implies this assumption; this PR doesn't have this assumption, which makes it more suitable for the vLLM v1 design. The comment I added here is to clarify the code path from the previous PR. I can remove the comment as needed. Please let me know @tlrmchlsmth.
Yes, please remove the comment.
We can rely on batch reordering even in vLLM V1, so this is a non-issue.
def is_conv_in_Triton(self):
    import os
    path = os.environ.get("VLLM_USE_TRITON_CONV1D", None)
    if path is not None:
        print("mamba_mixer2 - VLLM_USE_TRITON_CONV1D")
        return True
    return False
A few things here:
- Remove the print before landing
- env variables should be defined in vllm/envs.py
- The function would be better named use_triton_causal_conv_1d (more descriptive & proper capitalization)
- We should work hard to avoid proliferation of environment variables. Could you come up with a reliable heuristic to choose between this Triton implementation and the CUDA kernel instead of adding an env that exposes complexity to the end-user? (A rough sketch of the first two points follows below.)
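For reference only, a rough sketch of what those two suggestions (registering the flag in vllm/envs.py and renaming the helper) might look like; the registration here is simplified and is not the literal vLLM envs.py mechanism:

```python
import os

# Hypothetical, simplified flag definition in the spirit of vllm/envs.py.
VLLM_USE_TRITON_CONV1D: bool = os.environ.get(
    "VLLM_USE_TRITON_CONV1D", "0").lower() in ("1", "true")


def use_triton_causal_conv_1d() -> bool:
    # Descriptively named, no print(); callers simply branch on the flag.
    return VLLM_USE_TRITON_CONV1D
```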
I updated the code accordingly. The Triton version uses a conv_state cache layout that matches the layout of the input tensor hidden_states_B_C, i.e. contiguous along the feature dimension. The CUDA version in vLLM uses a different conv_state cache layout that is contiguous along the sequence dimension. So there is no good way to switch from one to the other within the same session; it is a choice to be made based on the expected workload. The benefit of Mamba-based models is at long context lengths and long output responses, where we show the Triton version takes over. Also, the Triton kernel follows the vLLM v1 design, allowing prefill and decode requests to be mixed.
RESULT: long-context latency measurement (compared with Llama)
- Llama-3.1-8B-Instruct (reference)
python benchmarks/benchmark_latency.py --model /net/storage149/autofs/css22/nmg/models/hf/meta-llama/Llama-3.1-8B-Instruct/main --input-len=131072 --output-len=1 --batch-size=1 --max_num_batched_tokens=2048
Avg latency: 11.214325266804856 seconds
10% percentile latency: 11.202042526123114 seconds
25% percentile latency: 11.206939334078925 seconds
50% percentile latency: 11.212064623483457 seconds
75% percentile latency: 11.220630767958937 seconds
90% percentile latency: 11.2278619370074 seconds
99% percentile latency: 11.24528882814222 seconds
- Main branch
# Default (max_num_batched_tokens=2048)
python benchmarks/benchmark_latency.py --model ibm-ai-platform/Bamba-9B-v2 --input-len=131072 --output-len=1 --batch-size=1
Avg latency: 6.231618080474436 seconds
10% percentile latency: 6.204561746446416 seconds
25% percentile latency: 6.216710253240308 seconds
50% percentile latency: 6.219352831016295 seconds
75% percentile latency: 6.223808606999228 seconds
90% percentile latency: 6.227424801979214 seconds
99% percentile latency: 6.519547982601217 seconds
- Current PR:
export VLLM_USE_TRITON_CONV1D="1"
python benchmarks/benchmark_latency.py --model ibm-ai-platform/Bamba-9B-v2 --input-len=131072 --output-len=1 --batch-size=1
Avg latency: 5.757278195097267 seconds
10% percentile latency: 5.734804188809358 seconds
25% percentile latency: 5.739403567742556 seconds
50% percentile latency: 5.743007940007374 seconds
75% percentile latency: 5.748229099262971 seconds
90% percentile latency: 5.751799210254103 seconds
99% percentile latency: 6.068630096325651 seconds
GSM8K RESULT
COMMAND TO RUN:
This PR adds a Triton-based causal-conv1d, moving Mamba-based models in vLLM closer to a Triton-only pathway.
There are two kernels implemented.
It also performs better than the CUDA-split pathway, which was merged as PR #17146.

[Data benchmarking the runtime for processing the same batch of mixed requests: first sending the batch to the single Triton kernel, and then using the CUDA-split pathway, where requests are first separated and prefill-only requests are sent to one kernel while decode-only requests are sent to a second kernel.]
ALGORITHMIC CHOICE OF TRITON KERNEL: Unlike the CUDA kernel, which is parallelized in 2D (along the feature dimension and the batch size), the Triton kernel is parallelized in 3D, i.e. also along the sequence dimension. The Triton kernels also make no changes to the layout of the input data, which is contiguous along the feature dimension. Another key difference is that the Triton kernels expect the conv-state to be contiguous along the feature dimension, while the existing CUDA implementation expects the conv-state cache to be contiguous along the kernel-width (i.e. sequence-length) axis. The two CUDA kernels are therefore tied to an incompatible conv-state cache layout, which prevents efficient processing of decode-only requests or mixed prefill/decode requests.
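As a rough illustration of the 3D launch grid described above (a sketch only; the actual kernel's grid computation and block names differ):

```python
import triton


def conv1d_prefill_grid(batch_size: int, dim: int, max_seqlen: int,
                        BLOCK_N: int, BLOCK_M: int):
    # 3D parallelism: one program per (sequence, feature block, token block),
    # versus the CUDA kernel's 2D grid over (sequence, feature block) only.
    return (batch_size,
            triton.cdiv(dim, BLOCK_N),         # feature-dimension blocks
            triton.cdiv(max_seqlen, BLOCK_M))  # sequence-dimension blocks
```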
Some other improvements to reduce overhead are also incorporated.
Even though the binary code generated by Triton is faster, the launch overhead is a known issue and needs further optimization to make E2E Triton-only Mamba models in vLLM performant. Here, we also incorporate such improvements by using metadata that can be reused across layers.
In our benchmark on the ShareGPT dataset, which has short input prompts (a few hundred tokens):
- default setting (generating a short output, i.e. 256 tokens): CUDA-backed Bamba (ibm-ai-platform/Bamba-9B) is still faster; the Triton version is 10% slower in total token throughput, yet only 2% slower in output token throughput and 2% slower in TTFT.
- generating 1024 tokens: Triton-backed Bamba is now faster, with 5% higher token throughput and 11% better TTFT. The benefit of the faster Triton kernels now exceeds the overall cost of the Triton launch overhead.
At longer context lengths and/or larger numbers of generated tokens, the Triton-only Mamba-based model is expected to outperform the CUDA-split approach. However, the PR keeps the existing CUDA pathway as the default until the Triton pathway is adopted by the vLLM maintainers. Currently, the code is added as an optional alternative to the CUDA-split pathway, enabled by setting the VLLM_USE_TRITON_CONV1D environment variable to 1. This is one step closer to compatibility with the vLLM v1 design, i.e. without splitting the batch into prefill-only and decode-only requests for CUDA-split processing.
Test code is also added.