Skip to content

Commit 148f54c

Browse files
authored
Add capturestate / rundown ETW support logging for session and provider options (#19397)
### Description

Add capturestate / rundown ETW support logging for session and provider options.

### Motivation and Context

Follow-up to #16259 and #18882.

This is very useful when you have longer-running ONNX sessions, which will be the case for a lot of AI workloads. That means ETW tracing may start minutes or hours after a process and session have been established. When a trace is captured, you want to know the state of ONNX at that time. The state for ONNX is the session and config options, which are emitted so that they show up in the trace.

Tested with xperf and ORT:

    xperf -start ort -on 3a26b1ff-7484-7484-7484-15261f42614d
    xperf -capturestate ort 3a26b1ff-7484-7484-7484-15261f42614d   <--- Run this after the session has been up for some time
    xperf -stop ort -d .\ort.etl   <--- Trace will now also have rundown events

These events will also show up if you use WPR [CaptureStateOnSave](https://learn.microsoft.com/en-us/windows-hardware/test/wpt/capturestateonsave).
1 parent 3b1b183 commit 148f54c

File tree

7 files changed

+164
-26
lines changed

7 files changed

+164
-26
lines changed

onnxruntime/core/framework/execution_providers.h

+48-7
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
#pragma once
55

6-
// #include <map>
76
#include <memory>
87
#include <string>
98
#include <unordered_map>
@@ -14,7 +13,9 @@
1413
#include "core/common/logging/logging.h"
1514
#ifdef _WIN32
1615
#include <winmeta.h>
16+
#include <evntrace.h>
1717
#include "core/platform/tracing.h"
18+
#include "core/platform/windows/telemetry.h"
1819
#endif
1920

2021
namespace onnxruntime {
@@ -44,6 +45,49 @@ class ExecutionProviders {
4445
exec_provider_options_[provider_id] = providerOptions;
4546

4647
#ifdef _WIN32
48+
LogProviderOptions(provider_id, providerOptions, false);
49+
50+
// Register callback for ETW capture state (rundown)
51+
WindowsTelemetry::RegisterInternalCallback(
52+
[this](
53+
LPCGUID SourceId,
54+
ULONG IsEnabled,
55+
UCHAR Level,
56+
ULONGLONG MatchAnyKeyword,
57+
ULONGLONG MatchAllKeyword,
58+
PEVENT_FILTER_DESCRIPTOR FilterData,
59+
PVOID CallbackContext) {
60+
(void)SourceId;
61+
(void)Level;
62+
(void)MatchAnyKeyword;
63+
(void)MatchAllKeyword;
64+
(void)FilterData;
65+
(void)CallbackContext;
66+
67+
// Check if this callback is for capturing state
68+
if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) &&
69+
((MatchAnyKeyword & static_cast<ULONGLONG>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) {
70+
for (size_t i = 0; i < exec_providers_.size(); ++i) {
71+
const auto& provider_id = exec_provider_ids_[i];
72+
73+
auto it = exec_provider_options_.find(provider_id);
74+
if (it != exec_provider_options_.end()) {
75+
const auto& options = it->second;
76+
77+
LogProviderOptions(provider_id, options, true);
78+
}
79+
}
80+
}
81+
});
82+
#endif
83+
84+
exec_provider_ids_.push_back(provider_id);
85+
exec_providers_.push_back(p_exec_provider);
86+
return Status::OK();
87+
}
88+
89+
#ifdef _WIN32
90+
void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) {
4791
for (const auto& config_pair : providerOptions) {
4892
TraceLoggingWrite(
4993
telemetry_provider_handle,
@@ -52,14 +96,11 @@ class ExecutionProviders {
5296
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
5397
TraceLoggingString(provider_id.c_str(), "ProviderId"),
5498
TraceLoggingString(config_pair.first.c_str(), "Key"),
55-
TraceLoggingString(config_pair.second.c_str(), "Value"));
99+
TraceLoggingString(config_pair.second.c_str(), "Value"),
100+
TraceLoggingBool(captureState, "isCaptureState"));
56101
}
57-
#endif
58-
59-
exec_provider_ids_.push_back(provider_id);
60-
exec_providers_.push_back(p_exec_provider);
61-
return Status::OK();
62102
}
103+
#endif
63104

64105
const IExecutionProvider* Get(const onnxruntime::Node& node) const {
65106
return Get(node.GetExecutionProviderType());

onnxruntime/core/platform/windows/telemetry.cc

+19-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
#include "core/platform/windows/telemetry.h"
5+
#include "core/platform/ort_mutex.h"
56
#include "core/common/logging/logging.h"
67
#include "onnxruntime_config.h"
78

@@ -63,6 +64,8 @@ bool WindowsTelemetry::enabled_ = true;
6364
uint32_t WindowsTelemetry::projection_ = 0;
6465
UCHAR WindowsTelemetry::level_ = 0;
6566
UINT64 WindowsTelemetry::keyword_ = 0;
67+
std::vector<WindowsTelemetry::EtwInternalCallback> WindowsTelemetry::callbacks_;
68+
OrtMutex WindowsTelemetry::callbacks_mutex_;
6669

6770
WindowsTelemetry::WindowsTelemetry() {
6871
std::lock_guard<OrtMutex> lock(mutex_);
@@ -104,6 +107,11 @@ UINT64 WindowsTelemetry::Keyword() const {
104107
// return etw_status_;
105108
// }
106109

110+
void WindowsTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) {
111+
std::lock_guard<OrtMutex> lock(callbacks_mutex_);
112+
callbacks_.push_back(callback);
113+
}
114+
107115
void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback(
108116
_In_ LPCGUID SourceId,
109117
_In_ ULONG IsEnabled,
@@ -112,15 +120,21 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback(
112120
_In_ ULONGLONG MatchAllKeyword,
113121
_In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData,
114122
_In_opt_ PVOID CallbackContext) {
115-
(void)SourceId;
116-
(void)MatchAllKeyword;
117-
(void)FilterData;
118-
(void)CallbackContext;
119-
120123
std::lock_guard<OrtMutex> lock(provider_change_mutex_);
121124
enabled_ = (IsEnabled != 0);
122125
level_ = Level;
123126
keyword_ = MatchAnyKeyword;
127+
128+
InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext);
129+
}
130+
131+
void WindowsTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword,
132+
ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData,
133+
PVOID CallbackContext) {
134+
std::lock_guard<OrtMutex> lock(callbacks_mutex_);
135+
for (const auto& callback : callbacks_) {
136+
callback(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext);
137+
}
124138
}
125139

126140
void WindowsTelemetry::EnableTelemetryEvents() const {

onnxruntime/core/platform/windows/telemetry.h

+14-1
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,14 @@
22
// Licensed under the MIT License.
33

44
#pragma once
5+
#include <atomic>
6+
#include <vector>
7+
58
#include "core/platform/telemetry.h"
69
#include <Windows.h>
710
#include <TraceLoggingProvider.h>
811
#include "core/platform/ort_mutex.h"
912
#include "core/platform/windows/TraceLoggingConfig.h"
10-
#include <atomic>
1113

1214
namespace onnxruntime {
1315

@@ -58,16 +60,27 @@ class WindowsTelemetry : public Telemetry {
5860

5961
void LogExecutionProviderEvent(LUID* adapterLuid) const override;
6062

63+
using EtwInternalCallback = std::function<void(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level,
64+
ULONGLONG MatchAnyKeyword, ULONGLONG MatchAllKeyword,
65+
PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext)>;
66+
67+
static void RegisterInternalCallback(const EtwInternalCallback& callback);
68+
6169
private:
6270
static OrtMutex mutex_;
6371
static uint32_t global_register_count_;
6472
static bool enabled_;
6573
static uint32_t projection_;
6674

75+
static std::vector<EtwInternalCallback> callbacks_;
76+
static OrtMutex callbacks_mutex_;
6777
static OrtMutex provider_change_mutex_;
6878
static UCHAR level_;
6979
static ULONGLONG keyword_;
7080

81+
static void InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword,
82+
ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext);
83+
7184
static void NTAPI ORT_TL_EtwEnableCallback(
7285
_In_ LPCGUID SourceId,
7386
_In_ ULONG IsEnabled,

onnxruntime/core/session/inference_session.cc

+64-8
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,11 @@
4646
#include "core/optimizer/transformer_memcpy.h"
4747
#include "core/optimizer/transpose_optimization/ort_optimizer_utils.h"
4848
#include "core/platform/Barrier.h"
49-
#include "core/platform/ort_mutex.h"
5049
#include "core/platform/threadpool.h"
5150
#ifdef _WIN32
5251
#include "core/platform/tracing.h"
52+
#include <Windows.h>
53+
#include "core/platform/windows/telemetry.h"
5354
#endif
5455
#include "core/providers/cpu/controlflow/utils.h"
5556
#include "core/providers/cpu/cpu_execution_provider.h"
@@ -241,6 +242,10 @@ Status GetMinimalBuildOptimizationHandling(
241242
} // namespace
242243

243244
std::atomic<uint32_t> InferenceSession::global_session_id_{1};
245+
std::map<uint32_t, InferenceSession*> InferenceSession::active_sessions_;
246+
#ifdef _WIN32
247+
OrtMutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_
248+
#endif
244249

245250
static Status FinalizeSessionOptions(const SessionOptions& user_provided_session_options,
246251
const ONNX_NAMESPACE::ModelProto& model_proto,
@@ -351,17 +356,47 @@ void InferenceSession::SetLoggingManager(const SessionOptions& session_options,
351356
void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
352357
const Environment& session_env) {
353358
auto status = FinalizeSessionOptions(session_options, model_proto_, is_model_proto_parsed_, session_options_);
354-
// a monotonically increasing session id for use in telemetry
355-
session_id_ = global_session_id_.fetch_add(1);
356359
ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ",
357360
status.ErrorMessage());
358361

362+
// a monotonically increasing session id for use in telemetry
363+
session_id_ = global_session_id_.fetch_add(1);
364+
365+
#ifdef _WIN32
366+
std::lock_guard<OrtMutex> lock(active_sessions_mutex_);
367+
active_sessions_[global_session_id_++] = this;
368+
369+
// Register callback for ETW capture state (rundown)
370+
WindowsTelemetry::RegisterInternalCallback(
371+
[this](
372+
LPCGUID SourceId,
373+
ULONG IsEnabled,
374+
UCHAR Level,
375+
ULONGLONG MatchAnyKeyword,
376+
ULONGLONG MatchAllKeyword,
377+
PEVENT_FILTER_DESCRIPTOR FilterData,
378+
PVOID CallbackContext) {
379+
(void)SourceId;
380+
(void)Level;
381+
(void)MatchAnyKeyword;
382+
(void)MatchAllKeyword;
383+
(void)FilterData;
384+
(void)CallbackContext;
385+
386+
// Check if this callback is for capturing state
387+
if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) &&
388+
((MatchAnyKeyword & static_cast<ULONGLONG>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) {
389+
LogAllSessions();
390+
}
391+
});
392+
#endif
393+
359394
SetLoggingManager(session_options, session_env);
360395

361396
// The call to InitLogger depends on the final state of session_options_. Hence it should be invoked
362397
// after the invocation of FinalizeSessionOptions.
363398
InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point.
364-
TraceSessionOptions(session_options);
399+
TraceSessionOptions(session_options, false);
365400

366401
#if !defined(ORT_MINIMAL_BUILD)
367402
// Update the number of steps for the graph transformer manager using the "finalized" session options
@@ -475,7 +510,9 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
475510
telemetry_ = {};
476511
}
477512

478-
void InferenceSession::TraceSessionOptions(const SessionOptions& session_options) {
513+
void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState) {
514+
(void)captureState; // Otherwise Linux build error
515+
479516
LOGS(*session_logger_, INFO) << session_options;
480517

481518
#ifdef _WIN32
@@ -498,7 +535,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
498535
TraceLoggingUInt8(static_cast<UINT8>(session_options.graph_optimization_level), "graph_optimization_level"),
499536
TraceLoggingBoolean(session_options.use_per_session_threads, "use_per_session_threads"),
500537
TraceLoggingBoolean(session_options.thread_pool_allow_spinning, "thread_pool_allow_spinning"),
501-
TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute"));
538+
TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute"),
539+
TraceLoggingBoolean(captureState, "isCaptureState"));
502540

503541
TraceLoggingWrite(
504542
telemetry_provider_handle,
@@ -511,7 +549,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
511549
TraceLoggingInt32(session_options.intra_op_param.dynamic_block_base_, "dynamic_block_base_"),
512550
TraceLoggingUInt32(session_options.intra_op_param.stack_size, "stack_size"),
513551
TraceLoggingString(!session_options.intra_op_param.affinity_str.empty() ? session_options.intra_op_param.affinity_str.c_str() : "", "affinity_str"),
514-
TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero"));
552+
TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero"),
553+
TraceLoggingBoolean(captureState, "isCaptureState"));
515554

516555
for (const auto& config_pair : session_options.config_options.configurations) {
517556
TraceLoggingWrite(
@@ -520,7 +559,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
520559
TraceLoggingKeyword(static_cast<uint64_t>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
521560
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
522561
TraceLoggingString(config_pair.first.c_str(), "Key"),
523-
TraceLoggingString(config_pair.second.c_str(), "Value"));
562+
TraceLoggingString(config_pair.second.c_str(), "Value"),
563+
TraceLoggingBoolean(captureState, "isCaptureState"));
524564
}
525565
#endif
526566
}
@@ -616,6 +656,12 @@ InferenceSession::~InferenceSession() {
616656
}
617657
}
618658

659+
// Unregister the session
660+
#ifdef _WIN32
661+
std::lock_guard<OrtMutex> lock(active_sessions_mutex_);
662+
#endif
663+
active_sessions_.erase(global_session_id_);
664+
619665
#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
620666
if (session_activity_started_)
621667
TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity");
@@ -3070,4 +3116,14 @@ IOBinding* SessionIOBinding::Get() {
30703116
return binding_.get();
30713117
}
30723118

3119+
#ifdef _WIN32
3120+
void InferenceSession::LogAllSessions() {
3121+
std::lock_guard<OrtMutex> lock(active_sessions_mutex_);
3122+
for (const auto& session_pair : active_sessions_) {
3123+
InferenceSession* session = session_pair.second;
3124+
TraceSessionOptions(session->session_options_, true);
3125+
}
3126+
}
3127+
#endif
3128+
30733129
} // namespace onnxruntime

onnxruntime/core/session/inference_session.h

+12-2
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
#pragma once
55

6+
#include <map>
67
#include <optional>
78
#include <string>
89
#include <unordered_map>
@@ -21,11 +22,12 @@
2122
#include "core/framework/session_state.h"
2223
#include "core/framework/tuning_results.h"
2324
#include "core/framework/framework_provider_common.h"
25+
#include "core/framework/session_options.h"
2426
#include "core/graph/basic_types.h"
2527
#include "core/optimizer/graph_transformer_level.h"
2628
#include "core/optimizer/graph_transformer_mgr.h"
2729
#include "core/optimizer/insert_cast_transformer.h"
28-
#include "core/framework/session_options.h"
30+
#include "core/platform/ort_mutex.h"
2931
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
3032
#include "core/language_interop_ops/language_interop_ops.h"
3133
#endif
@@ -119,6 +121,10 @@ class InferenceSession {
119121
};
120122

121123
using InputOutputDefMetaMap = InlinedHashMap<std::string_view, InputOutputDefMetaData>;
124+
static std::map<uint32_t, InferenceSession*> active_sessions_;
125+
#ifdef _WIN32
126+
static OrtMutex active_sessions_mutex_; // Protects access to active_sessions_
127+
#endif
122128

123129
public:
124130
#if !defined(ORT_MINIMAL_BUILD)
@@ -642,7 +648,7 @@ class InferenceSession {
642648

643649
void InitLogger(logging::LoggingManager* logging_manager);
644650

645-
void TraceSessionOptions(const SessionOptions& session_options);
651+
void TraceSessionOptions(const SessionOptions& session_options, bool captureState);
646652

647653
[[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape,
648654
const TensorShape& expected_shape, const char* input_output_moniker) const;
@@ -679,6 +685,10 @@ class InferenceSession {
679685
*/
680686
void ShrinkMemoryArenas(gsl::span<const AllocatorPtr> arenas_to_shrink);
681687

688+
#ifdef _WIN32
689+
void LogAllSessions();
690+
#endif
691+
682692
#if !defined(ORT_MINIMAL_BUILD)
683693
virtual common::Status AddPredefinedTransformers(
684694
GraphTransformerManager& transformer_manager,

onnxruntime/core/session/provider_registration.cc

+4
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
9090
(std::string(provider_name) + " execution provider is not supported in this build. ").c_str());
9191
};
9292

93+
for (const auto& config_pair : provider_options) {
94+
ORT_THROW_IF_ERROR(options->value.config_options.AddConfigEntry((std::string(provider_name) + ":" + config_pair.first).c_str(), config_pair.second.c_str()));
95+
}
96+
9397
if (strcmp(provider_name, "DML") == 0) {
9498
#if defined(USE_DML)
9599
options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options));

0 commit comments

Comments
 (0)