@@ -1240,15 +1240,20 @@ def fused_experts_impl(hidden_states: torch.Tensor,
 
     config = get_config_func(M)
 
-    intermediate_cache1 = torch.empty((M, top_k_num, N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    # We can reuse the memory between these because by the time we need
+    # cache3, we're done with cache1
+    cache13 = torch.empty(M * top_k_num * max(N, w2.shape[1]),
+                          device=hidden_states.device,
+                          dtype=hidden_states.dtype)
+    intermediate_cache1 = cache13[:M * top_k_num * N].view(
+        (M, topk_ids.shape[1], N))
+    intermediate_cache3 = cache13[:M * top_k_num * w2.shape[1]].view(
+        (M, topk_ids.shape[1], w2.shape[1]))
+
+    # This needs separate memory since it's used concurrently with cache1
     intermediate_cache2 = torch.empty((M * top_k_num, N // 2),
                                       device=hidden_states.device,
                                       dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, top_k_num, w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
 
     if hidden_states.dtype == torch.bfloat16:
         compute_type = tl.bfloat16
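
For context, here is a minimal standalone sketch of the buffer-aliasing pattern this change applies. The sizes below are illustrative placeholders, not values from the diff: one flat allocation is sized for the larger of the two tensors, and both caches are views into it, so cache3 reuses cache1's memory once cache1 is dead.

import torch

# Hypothetical sizes: M tokens, top_k_num experts per token, N = width of the
# first GEMM's output, K = width of the second GEMM's output (w2.shape[1]).
M, top_k_num, N, K = 4, 2, 16, 8

# One flat buffer big enough for either tensor; both views below alias it.
cache13 = torch.empty(M * top_k_num * max(N, K))

intermediate_cache1 = cache13[:M * top_k_num * N].view(M, top_k_num, N)
intermediate_cache3 = cache13[:M * top_k_num * K].view(M, top_k_num, K)

# No second allocation happens: both views point at cache13's storage.
assert intermediate_cache1.data_ptr() == cache13.data_ptr()
assert intermediate_cache3.data_ptr() == cache13.data_ptr()

This is only safe because of the dataflow the diff's comments describe: cache1 is fully consumed (via the activation into cache2) before the second GEMM writes cache3, while cache2's lifetime overlaps cache1's and therefore keeps its own allocation.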