Skip to content

Commit 9ad9e20

Browse files
committed
more comments
Signed-off-by: Yu Chin Fabian Lim <[email protected]>
1 parent b2dc5ca commit 9ad9e20

File tree

2 files changed

+23
-11
lines changed

2 files changed

+23
-11
lines changed

tests/models/decoder_only/language/test_bamba.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@
2020
# choke on the model_kwarg 'attention_mask' if hf_model.generate_greedy is used.
2121
def generate_greedy(model_name, example_prompts, max_tokens):
2222
# Create a text generation pipeline
23-
# - in the original test_mamba.py they do not put the model to cuda
24-
# maybe this affects the test.
2523
tokenizer = AutoTokenizer.from_pretrained(model_name)
2624
model = AutoModelForCausalLM.from_pretrained(model_name)
2725

vllm/model_executor/layers/mamba/ops/ssd_combined.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -67,25 +67,36 @@ def _mamba_chunk_scan_combined_fwd(x,
6767
D = D.contiguous()
6868
if initial_states is not None:
6969
assert initial_states.shape == (batch, nheads, headdim, dstate)
70-
# # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
71-
# dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
72-
# dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
73-
# dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
70+
71+
# This function executes 5 sub-functions for computing Mamba
72+
# - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
73+
# which has a minimal implementation to understand the below operations
74+
# - as explained by the blog, mamba is a special case of causal attention
75+
# - the idea is to chunk the attention matrix and compute each
76+
# submatrix separately using different optimizations.
77+
# - see the blog and paper for a visualization of the submatrices
78+
# which we refer to in the comments below
79+
80+
# 1. Compute chunked cumsum of A * dt
81+
# - here dt may go through a softplus activation
7482
dA_cumsum, dt = _chunk_cumsum_fwd(dt,
7583
A,
7684
chunk_size,
7785
dt_bias=dt_bias,
7886
dt_softplus=dt_softplus,
7987
dt_limit=dt_limit)
88+
89+
# 2. Compute the state for each intra-chunk
90+
# (right term of low-rank factorization of off-diagonal blocks; B terms)
8091
states = _chunk_state_fwd(B,
8192
x,
8293
dt,
8394
dA_cumsum,
8495
seq_idx=seq_idx,
8596
states_in_fp32=True)
86-
# states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
87-
# states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
88-
# states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
97+
98+
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
99+
# (middle term of factorization of off-diag blocks; A terms)
89100
states, final_states = _state_passing_fwd(
90101
rearrange(states, "... p n -> ... (p n)"),
91102
dA_cumsum[:, :, :, -1],
@@ -96,13 +107,16 @@ def _mamba_chunk_scan_combined_fwd(x,
96107
out_dtype=C.dtype)
97108
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
98109
for t in [states, final_states])
99-
# states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
100-
# states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
110+
111+
# 4. Compute batched matrix multiply for C_j^T B_i terms
101112
CB = _bmm_chunk_fwd(C,
102113
B,
103114
chunk_size,
104115
seq_idx=seq_idx,
105116
output_dtype=torch.float32)
117+
118+
# 5. Scan and compute the diagonal blocks, taking into
119+
# account past causal states.
106120
out, out_x = _chunk_scan_fwd(CB,
107121
x,
108122
dt,

0 commit comments

Comments (0)