46
46
#include " core/optimizer/transformer_memcpy.h"
47
47
#include " core/optimizer/transpose_optimization/ort_optimizer_utils.h"
48
48
#include " core/platform/Barrier.h"
49
- #include " core/platform/ort_mutex.h"
50
49
#include " core/platform/threadpool.h"
51
50
#ifdef _WIN32
52
51
#include " core/platform/tracing.h"
52
+ #include < Windows.h>
53
+ #include " core/platform/windows/telemetry.h"
53
54
#endif
54
55
#include " core/providers/cpu/controlflow/utils.h"
55
56
#include " core/providers/cpu/cpu_execution_provider.h"
@@ -241,6 +242,10 @@ Status GetMinimalBuildOptimizationHandling(
241
242
} // namespace
242
243
243
244
std::atomic<uint32_t > InferenceSession::global_session_id_{1 };
245
+ std::map<uint32_t , InferenceSession*> InferenceSession::active_sessions_;
246
+ #ifdef _WIN32
247
+ OrtMutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_
248
+ #endif
244
249
245
250
static Status FinalizeSessionOptions (const SessionOptions& user_provided_session_options,
246
251
const ONNX_NAMESPACE::ModelProto& model_proto,
@@ -351,17 +356,47 @@ void InferenceSession::SetLoggingManager(const SessionOptions& session_options,
351
356
void InferenceSession::ConstructorCommon (const SessionOptions& session_options,
352
357
const Environment& session_env) {
353
358
auto status = FinalizeSessionOptions (session_options, model_proto_, is_model_proto_parsed_, session_options_);
354
- // a monotonically increasing session id for use in telemetry
355
- session_id_ = global_session_id_.fetch_add (1 );
356
359
ORT_ENFORCE (status.IsOK (), " Could not finalize session options while constructing the inference session. Error Message: " ,
357
360
status.ErrorMessage ());
358
361
362
+ // a monotonically increasing session id for use in telemetry
363
+ session_id_ = global_session_id_.fetch_add (1 );
364
+
365
+ #ifdef _WIN32
366
+ std::lock_guard<OrtMutex> lock (active_sessions_mutex_);
367
+ active_sessions_[global_session_id_++] = this ;
368
+
369
+ // Register callback for ETW capture state (rundown)
370
+ WindowsTelemetry::RegisterInternalCallback (
371
+ [this ](
372
+ LPCGUID SourceId,
373
+ ULONG IsEnabled,
374
+ UCHAR Level,
375
+ ULONGLONG MatchAnyKeyword,
376
+ ULONGLONG MatchAllKeyword,
377
+ PEVENT_FILTER_DESCRIPTOR FilterData,
378
+ PVOID CallbackContext) {
379
+ (void )SourceId;
380
+ (void )Level;
381
+ (void )MatchAnyKeyword;
382
+ (void )MatchAllKeyword;
383
+ (void )FilterData;
384
+ (void )CallbackContext;
385
+
386
+ // Check if this callback is for capturing state
387
+ if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) &&
388
+ ((MatchAnyKeyword & static_cast <ULONGLONG>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0 )) {
389
+ LogAllSessions ();
390
+ }
391
+ });
392
+ #endif
393
+
359
394
SetLoggingManager (session_options, session_env);
360
395
361
396
// The call to InitLogger depends on the final state of session_options_. Hence it should be invoked
362
397
// after the invocation of FinalizeSessionOptions.
363
398
InitLogger (logging_manager_); // this sets session_logger_ so that it can be used for logging after this point.
364
- TraceSessionOptions (session_options);
399
+ TraceSessionOptions (session_options, false );
365
400
366
401
#if !defined(ORT_MINIMAL_BUILD)
367
402
// Update the number of steps for the graph transformer manager using the "finalized" session options
@@ -475,7 +510,9 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
475
510
telemetry_ = {};
476
511
}
477
512
478
- void InferenceSession::TraceSessionOptions (const SessionOptions& session_options) {
513
+ void InferenceSession::TraceSessionOptions (const SessionOptions& session_options, bool captureState) {
514
+ (void )captureState; // Otherwise Linux build error
515
+
479
516
LOGS (*session_logger_, INFO) << session_options;
480
517
481
518
#ifdef _WIN32
@@ -498,7 +535,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
498
535
TraceLoggingUInt8 (static_cast <UINT8>(session_options.graph_optimization_level ), " graph_optimization_level" ),
499
536
TraceLoggingBoolean (session_options.use_per_session_threads , " use_per_session_threads" ),
500
537
TraceLoggingBoolean (session_options.thread_pool_allow_spinning , " thread_pool_allow_spinning" ),
501
- TraceLoggingBoolean (session_options.use_deterministic_compute , " use_deterministic_compute" ));
538
+ TraceLoggingBoolean (session_options.use_deterministic_compute , " use_deterministic_compute" ),
539
+ TraceLoggingBoolean (captureState, " isCaptureState" ));
502
540
503
541
TraceLoggingWrite (
504
542
telemetry_provider_handle,
@@ -511,7 +549,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
511
549
TraceLoggingInt32 (session_options.intra_op_param .dynamic_block_base_ , " dynamic_block_base_" ),
512
550
TraceLoggingUInt32 (session_options.intra_op_param .stack_size , " stack_size" ),
513
551
TraceLoggingString (!session_options.intra_op_param .affinity_str .empty () ? session_options.intra_op_param .affinity_str .c_str () : " " , " affinity_str" ),
514
- TraceLoggingBoolean (session_options.intra_op_param .set_denormal_as_zero , " set_denormal_as_zero" ));
552
+ TraceLoggingBoolean (session_options.intra_op_param .set_denormal_as_zero , " set_denormal_as_zero" ),
553
+ TraceLoggingBoolean (captureState, " isCaptureState" ));
515
554
516
555
for (const auto & config_pair : session_options.config_options .configurations ) {
517
556
TraceLoggingWrite (
@@ -520,7 +559,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
520
559
TraceLoggingKeyword (static_cast <uint64_t >(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
521
560
TraceLoggingLevel (WINEVENT_LEVEL_INFO),
522
561
TraceLoggingString (config_pair.first .c_str (), " Key" ),
523
- TraceLoggingString (config_pair.second .c_str (), " Value" ));
562
+ TraceLoggingString (config_pair.second .c_str (), " Value" ),
563
+ TraceLoggingBoolean (captureState, " isCaptureState" ));
524
564
}
525
565
#endif
526
566
}
@@ -616,6 +656,12 @@ InferenceSession::~InferenceSession() {
616
656
}
617
657
}
618
658
659
+ // Unregister the session
660
+ #ifdef _WIN32
661
+ std::lock_guard<OrtMutex> lock (active_sessions_mutex_);
662
+ #endif
663
+ active_sessions_.erase (global_session_id_);
664
+
619
665
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
620
666
if (session_activity_started_)
621
667
TraceLoggingWriteStop (session_activity, " OrtInferenceSessionActivity" );
@@ -3070,4 +3116,14 @@ IOBinding* SessionIOBinding::Get() {
3070
3116
return binding_.get ();
3071
3117
}
3072
3118
3119
+ #ifdef _WIN32
3120
+ void InferenceSession::LogAllSessions () {
3121
+ std::lock_guard<OrtMutex> lock (active_sessions_mutex_);
3122
+ for (const auto & session_pair : active_sessions_) {
3123
+ InferenceSession* session = session_pair.second ;
3124
+ TraceSessionOptions (session->session_options_ , true );
3125
+ }
3126
+ }
3127
+ #endif
3128
+
3073
3129
} // namespace onnxruntime
0 commit comments