@@ -67,25 +67,36 @@ def _mamba_chunk_scan_combined_fwd(x,
    D = D.contiguous()
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)
-    # # (batch, nchunks, chunk_size, chunk_size) or (batch, nchunks, nheads, chunk_size, chunk_size)
-    # dA_cumsum_tmp0, dt_tmp0 = _chunk_cumsum_fwd(dt[:, :147], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
-    # dA_cumsum_tmp1, dt_tmp1 = _chunk_cumsum_fwd(dt[:, 147:], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
-    # dA_cumsum_tmp2, dt_tmp2 = _chunk_cumsum_fwd(dt[:, 147:256], A, chunk_size, dt_bias=dt_bias, dt_softplus=dt_softplus)
+
+    # This function executes 5 sub-functions for computing mamba
+    # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/
+    #   which has a minimal implementation that helps in understanding the operations below
+    # - as explained in the blog, mamba is a special case of causal attention
+    # - the idea is to chunk the attention matrix and compute each
+    #   submatrix separately using different optimizations.
+    # - see the blog and paper for a visualization of the submatrices
+    #   which we refer to in the comments below
+
+    # 1. Compute chunked cumsum of A * dt
+    # - here dt may go through a softplus activation
    dA_cumsum, dt = _chunk_cumsum_fwd(dt,
                                      A,
                                      chunk_size,
                                      dt_bias=dt_bias,
                                      dt_softplus=dt_softplus,
                                      dt_limit=dt_limit)
+
+    # 2. Compute the state for each intra-chunk
+    #    (right term of low-rank factorization of off-diagonal blocks; B terms)
    states = _chunk_state_fwd(B,
                              x,
                              dt,
                              dA_cumsum,
                              seq_idx=seq_idx,
                              states_in_fp32=True)
-    # states_tmp0 = _chunk_state_fwd(B[:, :147], x[:, :147], dt_tmp0, dA_cumsum_tmp0, states_in_fp32=True)
-    # states_tmp1 = _chunk_state_fwd(B[:, 147:], x[:, 147:], dt_tmp1, dA_cumsum_tmp1, states_in_fp32=True)
-    # states_tmp2 = _chunk_state_fwd(B[:, 147:256], x[:, 147:256], dt_tmp2, dA_cumsum_tmp2, states_in_fp32=True)
+
+    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
+    #    (middle term of factorization of off-diag blocks; A terms)
    states, final_states = _state_passing_fwd(
        rearrange(states, "... p n -> ... (p n)"),
        dA_cumsum[:, :, :, -1],
@@ -96,13 +107,16 @@ def _mamba_chunk_scan_combined_fwd(x,
        out_dtype=C.dtype)
    states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
                            for t in [states, final_states])
-    # states_tmp0 = rearrange(_state_passing_fwd(rearrange(states_tmp0, "... p n -> ... (p n)"), dA_cumsum_tmp0[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
-    # states_tmp1 = rearrange(_state_passing_fwd(rearrange(states_tmp1, "... p n -> ... (p n)"), dA_cumsum_tmp1[:, :, :, -1], chunk_size=chunk_size), "... (p n) -> ... p n", n=dstate)
+
+    # 4. Compute batched matrix multiply for C_j^T B_i terms
    CB = _bmm_chunk_fwd(C,
                        B,
                        chunk_size,
                        seq_idx=seq_idx,
                        output_dtype=torch.float32)
+
+    # 5. Scan and compute the diagonal blocks, taking into
+    #    account past causal states.
    out, out_x = _chunk_scan_fwd(CB,
                                 x,
                                 dt,
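
For readers following the five numbered steps in the new comments, below is a rough, unfused PyTorch sketch of the same chunked (SSD) algorithm, adapted from the minimal implementation in the goombalab blog post linked above. It is illustrative only: the Triton kernels called in this file (_chunk_cumsum_fwd, _chunk_state_fwd, _state_passing_fwd, _bmm_chunk_fwd, _chunk_scan_fwd) fuse and reorder these operations and additionally handle seq_idx, dt bias/softplus/limits, D, and z, none of which appear here. The sketch assumes A has already been multiplied by dt (the quantity whose per-chunk cumsum step 1 produces) and that seqlen is a multiple of chunk_size; the names ssd_reference and segsum are illustrative, not part of this file.

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


def segsum(x):
    # Segment sum: out[..., i, j] = x[..., j+1] + ... + x[..., i] for i >= j, else -inf.
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=0)
    return x_segsum.masked_fill(~mask, -torch.inf)


def ssd_reference(x, A, B, C, chunk_size, initial_states=None):
    # x: (batch, seqlen, nheads, headdim); A: (batch, seqlen, nheads), already A * dt
    # B, C: (batch, seqlen, nheads, dstate); seqlen must be a multiple of chunk_size
    x, A, B, C = [rearrange(t, "b (c l) ... -> b c l ...", l=chunk_size) for t in (x, A, B, C)]
    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)  # step 1: per-chunk cumsum of A * dt

    # steps 4 + 5, diagonal blocks: C_j^T B_i terms masked by the intra-chunk decay L
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, x)

    # step 2: state contributed by each chunk (right / B term of the off-diagonal factorization)
    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, x)

    # step 3: inter-chunk recurrence giving the true state at each chunk boundary (A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_states = new_states[:, :-1], new_states[:, -1]

    # step 5, off-diagonal blocks: convert boundary states to outputs with C (left / C term)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)

    return rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p"), final_states

Comparing such a reference against _mamba_chunk_scan_combined_fwd on small random inputs (with D, z, dt_bias, and seq_idx disabled) can be a convenient sanity check when modifying these kernels.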