Skip to content

Commit 7a84bf1

Browse files
committed
Fix delay load for WebGPU EP and DML EP
1 parent 3a0b958 commit 7a84bf1

14 files changed

+369
-49
lines changed

cmake/onnxruntime.cmake

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ if(WIN32)
7777
onnxruntime_add_shared_library(onnxruntime
7878
${SYMBOL_FILE}
7979
"${ONNXRUNTIME_ROOT}/core/dll/dllmain.cc"
80+
"${ONNXRUNTIME_ROOT}/core/dll/delay_load_hook.cc"
8081
"${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc"
8182
)
8283
elseif(onnxruntime_BUILD_APPLE_FRAMEWORK)

cmake/onnxruntime_nodejs.cmake

+4
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,8 @@ add_custom_target(nodejs_binding_wrapper ALL
103103
add_dependencies(js_common_npm_ci js_npm_ci)
104104
add_dependencies(nodejs_binding_wrapper js_common_npm_ci)
105105
add_dependencies(nodejs_binding_wrapper onnxruntime)
106+
if (WIN32 AND onnxruntime_USE_WEBGPU)
107+
add_dependencies(nodejs_binding_wrapper copy_dxil_dll)
108+
add_dependencies(nodejs_binding_wrapper dxcompiler)
109+
endif()
106110
endif()

cmake/onnxruntime_unittests.cmake

+12
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,9 @@ set (onnxruntime_global_thread_pools_test_SRC
525525
set (onnxruntime_webgpu_external_dawn_test_SRC
526526
${TEST_SRC_DIR}/webgpu/external_dawn/main.cc)
527527

528+
set (onnxruntime_webgpu_delay_load_test_SRC
529+
${TEST_SRC_DIR}/webgpu/delay_load/main.cc)
530+
528531
# tests from lowest level library up.
529532
# the order of libraries should be maintained, with higher libraries being added first in the list
530533

@@ -1864,4 +1867,13 @@ if (onnxruntime_USE_WEBGPU AND onnxruntime_USE_EXTERNAL_DAWN)
18641867
onnxruntime_add_include_to_target(onnxruntime_webgpu_external_dawn_test dawn::dawncpp_headers dawn::dawn_headers)
18651868
endif()
18661869

1870+
if (onnxruntime_USE_WEBGPU AND WIN32 AND onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_MINIMAL_BUILD)
1871+
AddTest(DYN
1872+
TARGET onnxruntime_webgpu_delay_load_test
1873+
SOURCES ${onnxruntime_webgpu_delay_load_test_SRC}
1874+
LIBS ${SYS_PATH_LIB}
1875+
DEPENDS ${all_dependencies}
1876+
)
1877+
endif()
1878+
18671879
include(onnxruntime_fuzz_test.cmake)

js/node/CMakeLists.txt

+8
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,14 @@ if (WIN32)
117117
file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/DirectML.dll
118118
DESTINATION ${dist_folder})
119119
endif ()
120+
if(USE_WEBGPU)
121+
file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/webgpu_dawn.dll
122+
DESTINATION ${dist_folder})
123+
file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/dxil.dll
124+
DESTINATION ${dist_folder})
125+
file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/dxcompiler.dll
126+
DESTINATION ${dist_folder})
127+
endif ()
120128
elseif (APPLE)
121129
file(COPY ${ONNXRUNTIME_BUILD_DIR}/libonnxruntime.dylib
122130
DESTINATION ${dist_folder} FOLLOW_SYMLINK_CHAIN)

js/node/src/directml_load_helper.cc

-37
This file was deleted.

js/node/src/directml_load_helper.h

-6
This file was deleted.

js/node/src/inference_session_wrap.cc

-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "onnxruntime_cxx_api.h"
55

66
#include "common.h"
7-
#include "directml_load_helper.h"
87
#include "inference_session_wrap.h"
98
#include "run_options_helper.h"
109
#include "session_options_helper.h"
@@ -19,9 +18,6 @@ Napi::FunctionReference& InferenceSessionWrap::GetTensorConstructor() {
1918
}
2019

2120
Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
22-
#if defined(USE_DML) && defined(_WIN32)
23-
LoadDirectMLDll(env);
24-
#endif
2521
// create ONNX runtime env
2622
Ort::InitApi();
2723
ORT_NAPI_THROW_ERROR_IF(
+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
// == workaround for delay loading of dependencies of onnxruntime.dll ==
5+
//
6+
// Problem:
7+
//
8+
// When onnxruntime.dll uses delay loading for its dependencies, the dependencies are loaded using LoadLibraryEx,
9+
// which search the directory of process (.exe) instead of this library (onnxruntime.dll). This is a problem for
10+
// usages of Node.js binding and python binding, because Windows will try to find the dependencies in the directory
11+
// of node.exe or python.exe, which is not the directory of onnxruntime.dll.
12+
//
13+
// Solution:
14+
//
15+
// By using the delay load hook `__pfnDliNotifyHook2`, we can intervene the loading procedure by loading from an
16+
// absolute path. The absolute path is constructed by appending the name of the DLL to load to the directory of
17+
// onnxruntime.dll. This way, we can ensure that the dependencies are loaded from the same directory as onnxruntime.dll.
18+
//
19+
// See also:
20+
// - https://learn.microsoft.com/en-us/cpp/build/reference/understanding-the-helper-function?view=msvc-170#structure-and-constant-definitions
21+
// - https://learn.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order#alternate-search-order-for-unpackaged-apps
22+
//
23+
// The DLL DelayLoad hook is only enabled when:
24+
// - The compiler is MSVC
25+
// - at least one of USE_WEBGPU or USE_DML is defined
26+
//
27+
#if defined(_MSC_VER) && (defined(USE_WEBGPU) || defined(USE_DML))
28+
29+
#include <Windows.h>
30+
#include <delayimp.h>
31+
#include <stdlib.h>
32+
#include <string>
33+
34+
namespace {
35+
36+
#define DEFINE_KNOWN_DLL(name) {#name ".dll", L#name L".dll"}
37+
38+
constexpr struct {
39+
const char* str;
40+
const wchar_t* wstr;
41+
} known_dlls[] = {
42+
#if defined(USE_WEBGPU)
43+
DEFINE_KNOWN_DLL(webgpu_dawn),
44+
#endif
45+
#if defined(USE_DML)
46+
DEFINE_KNOWN_DLL(DirectML),
47+
#endif
48+
};
49+
} // namespace
50+
51+
FARPROC WINAPI delay_load_hook(unsigned dliNotify, PDelayLoadInfo pdli) {
52+
if (dliNotify == dliNotePreLoadLibrary) {
53+
for (size_t i = 0; i < _countof(known_dlls); ++i) {
54+
if (_stricmp(pdli->szDll, known_dlls[i].str) == 0) {
55+
// Try to load the DLL from the same directory as onnxruntime.dll
56+
57+
// First, get the path to onnxruntime.dll
58+
DWORD pathLen = MAX_PATH;
59+
std::wstring path(pathLen, L'\0');
60+
HMODULE moduleHandle = nullptr;
61+
62+
GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
63+
reinterpret_cast<LPCWSTR>(&delay_load_hook), &moduleHandle);
64+
65+
DWORD getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t*>(path.c_str()), pathLen);
66+
while (getModuleFileNameResult == 0 || getModuleFileNameResult == pathLen) {
67+
int ret = GetLastError();
68+
if (ret == ERROR_INSUFFICIENT_BUFFER && pathLen < 32768) {
69+
pathLen *= 2;
70+
path.resize(pathLen);
71+
getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t*>(path.c_str()), pathLen);
72+
} else {
73+
// Failed to get the path to onnxruntime.dll. In this case, we will just return NULL and let the system
74+
// search for the DLL in the default search order.
75+
return NULL;
76+
}
77+
}
78+
79+
path.resize(path.rfind(L'\\') + 1);
80+
path.append(known_dlls[i].wstr);
81+
82+
return FARPROC(LoadLibraryExW(path.c_str(), NULL, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR));
83+
}
84+
}
85+
}
86+
return NULL;
87+
}
88+
89+
extern "C" const PfnDliHook __pfnDliNotifyHook2 = delay_load_hook;
90+
91+
#endif

onnxruntime/core/dll/dllmain.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#pragma GCC diagnostic pop
1414
#endif
1515

16-
// dllmain.cpp : Defines the entry point for the DLL application.
16+
// dllmain.cc : Defines the entry point for the DLL application.
1717
BOOL APIENTRY DllMain(HMODULE /*hModule*/,
1818
DWORD ul_reason_for_call,
1919
LPVOID /*lpReserved*/
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/dll_delay_load_helper.h"
5+
6+
#if defined(_WIN32) && defined(_MSC_VER) && !defined(__EMSCRIPTEN__)
7+
8+
#include <Windows.h>
9+
#include <delayimp.h>
10+
#include <stdlib.h>
11+
#include <string>
12+
#include <mutex>
13+
14+
namespace onnxruntime {
15+
namespace webgpu {
16+
17+
namespace {
18+
19+
// Get the directory of the current DLL (usually it's onnxruntime.dll).
20+
std::wstring GetCurrentDllDir() {
21+
DWORD pathLen = MAX_PATH;
22+
std::wstring path(pathLen, L'\0');
23+
HMODULE moduleHandle = nullptr;
24+
25+
GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
26+
reinterpret_cast<LPCWSTR>(&GetCurrentDllDir), &moduleHandle);
27+
28+
DWORD getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t*>(path.c_str()), pathLen);
29+
while (getModuleFileNameResult == 0 || getModuleFileNameResult == pathLen) {
30+
int ret = GetLastError();
31+
if (ret == ERROR_INSUFFICIENT_BUFFER && pathLen < 32768) {
32+
pathLen *= 2;
33+
path.resize(pathLen);
34+
getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t*>(path.c_str()), pathLen);
35+
} else {
36+
// Failed to get the path to onnxruntime.dll. Returns an empty string.
37+
return std::wstring{};
38+
}
39+
}
40+
path.resize(path.rfind(L'\\') + 1);
41+
return path;
42+
}
43+
44+
std::once_flag run_once_before_load_deps_mutex;
45+
std::once_flag run_once_after_load_deps_mutex;
46+
bool dll_dir_set = false;
47+
48+
} // namespace
49+
50+
DllDelayLoadHelper::DllDelayLoadHelper() {
51+
// Setup DLL search directory
52+
std::call_once(run_once_before_load_deps_mutex, []() {
53+
std::wstring path = GetCurrentDllDir();
54+
if (!path.empty()) {
55+
SetDllDirectoryW(path.c_str());
56+
dll_dir_set = true;
57+
}
58+
});
59+
}
60+
61+
DllDelayLoadHelper::~DllDelayLoadHelper() {
62+
// Restore DLL search directory
63+
std::call_once(run_once_after_load_deps_mutex, []() {
64+
if (dll_dir_set) {
65+
SetDllDirectoryW(NULL);
66+
}
67+
});
68+
}
69+
70+
} // namespace webgpu
71+
} // namespace onnxruntime
72+
73+
#else // defined(_WIN32) && defined(_MSC_VER) && !defined(__EMSCRIPTEN__)
74+
75+
namespace onnxruntime {
76+
namespace webgpu {
77+
78+
DllDelayLoadHelper::DllDelayLoadHelper() {
79+
}
80+
81+
DllDelayLoadHelper::~DllDelayLoadHelper() {
82+
}
83+
84+
} // namespace webgpu
85+
} // namespace onnxruntime
86+
87+
#endif
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
namespace onnxruntime {
7+
namespace webgpu {
8+
9+
// The DLL delay load helper is a RAII style guard to ensure DLL loading is done correctly.
10+
//
11+
// - On Windows, the helper sets the DLL search path to the directory of the current DLL.
12+
// - On other platforms, the helper does nothing.
13+
//
14+
struct DllDelayLoadHelper final {
15+
DllDelayLoadHelper();
16+
~DllDelayLoadHelper();
17+
};
18+
19+
} // namespace webgpu
20+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_context.cc

+5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "core/providers/webgpu/program_cache_key.h"
2020
#include "core/providers/webgpu/program_manager.h"
2121
#include "core/providers/webgpu/string_macros.h"
22+
#include "core/providers/webgpu/dll_delay_load_helper.h"
2223

2324
namespace onnxruntime {
2425
namespace webgpu {
@@ -50,6 +51,10 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
5051

5152
// Initialization.Step.2 - Create wgpu::Adapter
5253
if (adapter_ == nullptr) {
54+
// DLL delay loading happens inside wgpuRequestAdapter().
55+
// Use this helper as RAII to ensure the DLL search path is set correctly.
56+
DllDelayLoadHelper helper{};
57+
5358
wgpu::RequestAdapterOptions req_adapter_options = {};
5459
wgpu::DawnTogglesDescriptor adapter_toggles_desc = {};
5560
req_adapter_options.nextInChain = &adapter_toggles_desc;

0 commit comments

Comments
 (0)