168
168
169
169
bool support_simdgroup_reduction;
170
170
bool support_simdgroup_mm;
171
+
172
+ bool should_capture_next_compute;
171
173
};
172
174
173
175
// MSL code
@@ -687,6 +689,20 @@ static bool ggml_metal_graph_compute(
687
689
const int n_cb = ctx->n_cb ;
688
690
const int n_nodes_per_cb = (n_nodes + n_cb - 1 ) / n_cb;
689
691
692
+ const bool should_capture = ctx->should_capture_next_compute ;
693
+ if (should_capture) {
694
+ ctx->should_capture_next_compute = false ;
695
+
696
+ MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new ];
697
+ descriptor.captureObject = ctx->queue ;
698
+
699
+ NSError * error = nil ;
700
+ if (![[MTLCaptureManager sharedCaptureManager ] startCaptureWithDescriptor: descriptor error: &error]) {
701
+ GGML_METAL_LOG_ERROR (" %s : error: unable to start capture '%s '\n " , __func__, [[error localizedDescription ] UTF8String ]);
702
+ GGML_ASSERT (!" capture failed" );
703
+ }
704
+ }
705
+
690
706
id <MTLCommandBuffer > command_buffer_builder[n_cb];
691
707
for (int cb_idx = 0 ; cb_idx < n_cb; ++cb_idx) {
692
708
id <MTLCommandBuffer > command_buffer = [ctx->queue commandBufferWithUnretainedReferences ];
@@ -695,6 +711,7 @@ static bool ggml_metal_graph_compute(
695
711
// enqueue the command buffers in order to specify their execution order
696
712
[command_buffer enqueue ];
697
713
}
714
+
698
715
const id <MTLCommandBuffer > *command_buffers = command_buffer_builder;
699
716
700
717
dispatch_apply (n_cb, ctx->d_queue , ^(size_t iter) {
@@ -741,9 +758,9 @@ static bool ggml_metal_graph_compute(
741
758
GGML_ASSERT (!" unsupported op" );
742
759
}
743
760
744
- # ifndef GGML_METAL_NDEBUG
745
- [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (dst) encoding: NSUTF8StringEncoding]];
746
- # endif
761
+ if (should_capture) {
762
+ [encoder pushDebugGroup: [NSString stringWithCString: ggml_op_desc (dst) encoding: NSUTF8StringEncoding]];
763
+ }
747
764
748
765
const int64_t ne00 = src0 ? src0->ne [0 ] : 0 ;
749
766
const int64_t ne01 = src0 ? src0->ne [1 ] : 0 ;
@@ -2218,9 +2235,9 @@ static bool ggml_metal_graph_compute(
2218
2235
}
2219
2236
}
2220
2237
2221
- # ifndef GGML_METAL_NDEBUG
2222
- [encoder popDebugGroup ];
2223
- # endif
2238
+ if (should_capture) {
2239
+ [encoder popDebugGroup ];
2240
+ }
2224
2241
}
2225
2242
2226
2243
[encoder endEncoding ];
@@ -2242,6 +2259,10 @@ static bool ggml_metal_graph_compute(
2242
2259
}
2243
2260
}
2244
2261
2262
+ if (should_capture) {
2263
+ [[MTLCaptureManager sharedCaptureManager ] stopCapture ];
2264
+ }
2265
+
2245
2266
return true ;
2246
2267
}
2247
2268
@@ -2613,6 +2634,13 @@ bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
2613
2634
return [ctx->device supportsFamily: (MTLGPUFamilyApple1 + family - 1 )];
2614
2635
}
2615
2636
2637
// Arms a one-shot GPU capture for the given Metal backend.
//
// Sets the context's should_capture_next_compute flag. ggml_metal_graph_compute
// reads that flag at the start of a run, clears it, and wraps the run in an
// MTLCaptureManager capture session.
//
// backend: must be a Metal backend (asserted).
void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
    GGML_ASSERT(ggml_backend_is_metal(backend));

    // The backend's opaque context pointer is the Metal context.
    struct ggml_metal_context * metal_ctx = (struct ggml_metal_context *) backend->context;

    metal_ctx->should_capture_next_compute = true;
}
2643
+
2616
2644
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init (const char * params, void * user_data); // silence warning
2617
2645
2618
2646
GGML_CALL ggml_backend_t ggml_backend_reg_metal_init (const char * params, void * user_data) {
0 commit comments