Skip to content

Commit 5c346be

Browse files
committed
Fix delay load for WebGPU EP and DML EP
1 parent 3a0b958 commit 5c346be

14 files changed

+367
-49
lines changed

cmake/onnxruntime.cmake

+1
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ if(WIN32)
7777
onnxruntime_add_shared_library(onnxruntime
7878
${SYMBOL_FILE}
7979
"${ONNXRUNTIME_ROOT}/core/dll/dllmain.cc"
80+
"${ONNXRUNTIME_ROOT}/core/dll/delay_load_hook.cc"
8081
"${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc"
8182
)
8283
elseif(onnxruntime_BUILD_APPLE_FRAMEWORK)

cmake/onnxruntime_nodejs.cmake

+4
Original file line numberDiff line numberDiff line change
@@ -103,4 +103,8 @@ add_custom_target(nodejs_binding_wrapper ALL
103103
add_dependencies(js_common_npm_ci js_npm_ci)
104104
add_dependencies(nodejs_binding_wrapper js_common_npm_ci)
105105
add_dependencies(nodejs_binding_wrapper onnxruntime)
106+
if (WIN32 AND onnxruntime_USE_WEBGPU)
107+
add_dependencies(nodejs_binding_wrapper copy_dxil_dll)
108+
add_dependencies(nodejs_binding_wrapper dxcompiler)
109+
endif()
106110
endif()

cmake/onnxruntime_unittests.cmake

+12
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,9 @@ set (onnxruntime_global_thread_pools_test_SRC
525525
set (onnxruntime_webgpu_external_dawn_test_SRC
526526
${TEST_SRC_DIR}/webgpu/external_dawn/main.cc)
527527

528+
set (onnxruntime_webgpu_delay_load_test_SRC
529+
${TEST_SRC_DIR}/webgpu/delay_load/main.cc)
530+
528531
# tests from lowest level library up.
529532
# the order of libraries should be maintained, with higher libraries being added first in the list
530533

@@ -1864,4 +1867,13 @@ if (onnxruntime_USE_WEBGPU AND onnxruntime_USE_EXTERNAL_DAWN)
18641867
onnxruntime_add_include_to_target(onnxruntime_webgpu_external_dawn_test dawn::dawncpp_headers dawn::dawn_headers)
18651868
endif()
18661869

1870+
if (onnxruntime_USE_WEBGPU AND WIN32 AND onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_MINIMAL_BUILD)
1871+
AddTest(DYN
1872+
TARGET onnxruntime_webgpu_delay_load_test
1873+
SOURCES ${onnxruntime_webgpu_delay_load_test_SRC}
1874+
LIBS ${SYS_PATH_LIB}
1875+
DEPENDS ${all_dependencies}
1876+
)
1877+
endif()
1878+
18671879
include(onnxruntime_fuzz_test.cmake)

js/node/CMakeLists.txt

+8
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,14 @@ if (WIN32)
117117
file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/DirectML.dll
118118
DESTINATION ${dist_folder})
119119
endif ()
120+
if(USE_WEBGPU)
121+
file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/webgpu_dawn.dll
122+
DESTINATION ${dist_folder})
123+
file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/dxil.dll
124+
DESTINATION ${dist_folder})
125+
file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/dxcompiler.dll
126+
DESTINATION ${dist_folder})
127+
endif ()
120128
elseif (APPLE)
121129
file(COPY ${ONNXRUNTIME_BUILD_DIR}/libonnxruntime.dylib
122130
DESTINATION ${dist_folder} FOLLOW_SYMLINK_CHAIN)

js/node/src/directml_load_helper.cc

-37
This file was deleted.

js/node/src/directml_load_helper.h

-6
This file was deleted.

js/node/src/inference_session_wrap.cc

-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
#include "onnxruntime_cxx_api.h"
55

66
#include "common.h"
7-
#include "directml_load_helper.h"
87
#include "inference_session_wrap.h"
98
#include "run_options_helper.h"
109
#include "session_options_helper.h"
@@ -19,9 +18,6 @@ Napi::FunctionReference& InferenceSessionWrap::GetTensorConstructor() {
1918
}
2019

2120
Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) {
22-
#if defined(USE_DML) && defined(_WIN32)
23-
LoadDirectMLDll(env);
24-
#endif
2521
// create ONNX runtime env
2622
Ort::InitApi();
2723
ORT_NAPI_THROW_ERROR_IF(
+89
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
// == workaround for delay loading of dependencies of onnxruntime.dll ==
5+
//
6+
// Problem:
7+
//
8+
// When onnxruntime.dll uses delay loading for its dependencies, the dependencies are loaded using LoadLibraryEx,
9+
// which search the directory of process (.exe) instead of this library (onnxruntime.dll). This is a problem for
10+
// usages of Node.js binding and python binding, because Windows will try to find the dependencies in the directory
11+
// of node.exe or python.exe, which is not the directory of onnxruntime.dll.
12+
//
13+
// Solution:
14+
//
15+
// By using Win32 API `SetDefaultDllDirectories` and `AddDllDirectory`, we can modify the DLL search order to include
16+
// the directory of onnxruntime.dll. This will make sure the dependencies are loaded from the directory of onnxruntime.dll
17+
// when later calling LoadLibraryEx() without flags.
18+
//
19+
// See https://learn.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order#alternate-search-order-for-unpackaged-apps
20+
//
21+
// The DLL DelayLoad hook is only enabled when:
22+
// - The compiler is MSVC
23+
// - at least one of USE_WEBGPU or USE_DML is defined
24+
//
25+
#if defined(_MSC_VER) && (defined(USE_WEBGPU) || defined(USE_DML))
26+
27+
#include <Windows.h>
28+
#include <delayimp.h>
29+
#include <stdlib.h>
30+
#include <string>
31+
32+
namespace {
33+
34+
#define DEFINE_KNOWN_DLL(name) {#name ".dll", L#name L".dll"}
35+
36+
constexpr struct {
37+
const char* str;
38+
const wchar_t* wstr;
39+
} known_dlls[] = {
40+
#if defined(USE_WEBGPU)
41+
DEFINE_KNOWN_DLL(webgpu_dawn),
42+
#endif
43+
#if defined(USE_DML)
44+
DEFINE_KNOWN_DLL(DirectML),
45+
#endif
46+
};
47+
} // namespace
48+
49+
FARPROC WINAPI delay_load_hook(unsigned dliNotify, PDelayLoadInfo pdli) {
50+
if (dliNotify == dliNotePreLoadLibrary) {
51+
for (size_t i = 0; i < _countof(known_dlls); ++i) {
52+
if (_stricmp(pdli->szDll, known_dlls[i].str) == 0) {
53+
// Try to load the DLL from the same directory as onnxruntime.dll
54+
55+
// First, get the path to onnxruntime.dll
56+
DWORD pathLen = MAX_PATH;
57+
std::wstring path(pathLen, L'\0');
58+
HMODULE moduleHandle = nullptr;
59+
60+
GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
61+
reinterpret_cast<LPCWSTR>(&delay_load_hook), &moduleHandle);
62+
63+
DWORD getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t*>(path.c_str()), pathLen);
64+
while (getModuleFileNameResult == 0 || getModuleFileNameResult == pathLen) {
65+
int ret = GetLastError();
66+
if (ret == ERROR_INSUFFICIENT_BUFFER && pathLen < 32768) {
67+
pathLen *= 2;
68+
path.resize(pathLen);
69+
getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t*>(path.c_str()), pathLen);
70+
} else {
71+
// Failed to get the path to onnxruntime.dll. In this case, we will just return NULL and let the system
72+
// search for the DLL in the default search order.
73+
return NULL;
74+
}
75+
}
76+
77+
path.resize(path.rfind(L'\\') + 1);
78+
path.append(known_dlls[i].wstr);
79+
80+
return FARPROC(LoadLibraryExW(path.c_str(), NULL, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR));
81+
}
82+
}
83+
}
84+
return NULL;
85+
}
86+
87+
extern "C" const PfnDliHook __pfnDliNotifyHook2 = delay_load_hook;
88+
89+
#endif

onnxruntime/core/dll/dllmain.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
#pragma GCC diagnostic pop
1414
#endif
1515

16-
// dllmain.cpp : Defines the entry point for the DLL application.
16+
// dllmain.cc : Defines the entry point for the DLL application.
1717
BOOL APIENTRY DllMain(HMODULE /*hModule*/,
1818
DWORD ul_reason_for_call,
1919
LPVOID /*lpReserved*/
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/dll_delay_load_helper.h"
5+
6+
#if defined(_WIN32) && defined(_MSC_VER) && !defined(__EMSCRIPTEN__)
7+
8+
#include <Windows.h>
9+
#include <delayimp.h>
10+
#include <stdlib.h>
11+
#include <string>
12+
#include <mutex>
13+
14+
namespace onnxruntime {
15+
namespace webgpu {
16+
17+
namespace {
18+
19+
// Get the directory of the current DLL (usually it's onnxruntime.dll).
20+
std::wstring GetCurrentDllDir() {
21+
DWORD pathLen = MAX_PATH;
22+
std::wstring path(pathLen, L'\0');
23+
HMODULE moduleHandle = nullptr;
24+
25+
GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
26+
reinterpret_cast<LPCWSTR>(&GetCurrentDllDir), &moduleHandle);
27+
28+
DWORD getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t*>(path.c_str()), pathLen);
29+
while (getModuleFileNameResult == 0 || getModuleFileNameResult == pathLen) {
30+
int ret = GetLastError();
31+
if (ret == ERROR_INSUFFICIENT_BUFFER && pathLen < 32768) {
32+
pathLen *= 2;
33+
path.resize(pathLen);
34+
getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast<wchar_t*>(path.c_str()), pathLen);
35+
} else {
36+
// Failed to get the path to onnxruntime.dll. Returns an empty string.
37+
return std::wstring{};
38+
}
39+
}
40+
path.resize(path.rfind(L'\\') + 1);
41+
return path;
42+
}
43+
44+
std::once_flag run_once_before_load_deps_mutex;
45+
std::once_flag run_once_after_load_deps_mutex;
46+
bool dll_dir_set = false;
47+
48+
} // namespace
49+
50+
DllDelayLoadHelper::DllDelayLoadHelper() {
51+
// Setup DLL search directory
52+
std::call_once(run_once_before_load_deps_mutex, []() {
53+
std::wstring path = GetCurrentDllDir();
54+
if (!path.empty()) {
55+
SetDllDirectoryW(path.c_str());
56+
dll_dir_set = true;
57+
}
58+
});
59+
}
60+
61+
DllDelayLoadHelper::~DllDelayLoadHelper() {
62+
// Restore DLL search directory
63+
std::call_once(run_once_after_load_deps_mutex, []() {
64+
if (dll_dir_set) {
65+
SetDllDirectoryW(NULL);
66+
}
67+
});
68+
}
69+
70+
} // namespace webgpu
71+
} // namespace onnxruntime
72+
73+
#else // defined(_WIN32) && defined(_MSC_VER) && !defined(__EMSCRIPTEN__)
74+
75+
namespace onnxruntime {
76+
namespace webgpu {
77+
78+
DllDelayLoadHelper::DllDelayLoadHelper() {
79+
}
80+
81+
DllDelayLoadHelper::~DllDelayLoadHelper() {
82+
}
83+
84+
} // namespace webgpu
85+
} // namespace onnxruntime
86+
87+
#endif
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
namespace onnxruntime {
7+
namespace webgpu {
8+
9+
// The DLL delay load helper is a RAII style guard to ensure DLL loading is done correctly.
10+
//
11+
// - On Windows, the helper sets the DLL search path to the directory of the current DLL.
12+
// - On other platforms, the helper does nothing.
13+
//
14+
struct DllDelayLoadHelper final {
15+
DllDelayLoadHelper();
16+
~DllDelayLoadHelper();
17+
};
18+
19+
} // namespace webgpu
20+
} // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_context.cc

+5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "core/providers/webgpu/program_cache_key.h"
2020
#include "core/providers/webgpu/program_manager.h"
2121
#include "core/providers/webgpu/string_macros.h"
22+
#include "core/providers/webgpu/dll_delay_load_helper.h"
2223

2324
namespace onnxruntime {
2425
namespace webgpu {
@@ -50,6 +51,10 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
5051

5152
// Initialization.Step.2 - Create wgpu::Adapter
5253
if (adapter_ == nullptr) {
54+
// DLL delay loading happens inside wgpuRequestAdapter().
55+
// Use this helper as RAII to ensure the DLL search path is set correctly.
56+
DllDelayLoadHelper helper{};
57+
5358
wgpu::RequestAdapterOptions req_adapter_options = {};
5459
wgpu::DawnTogglesDescriptor adapter_toggles_desc = {};
5560
req_adapter_options.nextInChain = &adapter_toggles_desc;

0 commit comments

Comments
 (0)