Skip to content

Commit e02b783

Browse files
PatriceVignola authored and rachguo committed
Disable streams for the DML EP (#19481)
There's currently a bug in the allocation planner: when buffers are reused and more than one stream is used, it is possible (although rare) for the reference count of a buffer that is still in use to reach 0. Since DML doesn't benefit from multiple streams, disabling them is the safest option for now. This is a high-priority issue that we need to fix for 1.17.1 since it breaks Stable Diffusion. Identifying the perfect fix and fixing the underlying issue would be too risky for a patch release, especially given the limited time that we have. #19480
1 parent f5f5cc8 commit e02b783

File tree

4 files changed

+76
-11
lines changed

4 files changed

+76
-11
lines changed

cmake/adjust_global_compile_flags.cmake

+7-2
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,13 @@ if (onnxruntime_MINIMAL_BUILD)
9292
endif()
9393
endif()
9494

95-
# enable stream for all the non-minimal build
96-
if (NOT onnxruntime_MINIMAL_BUILD)
95+
# Enable stream for all the non-minimal build, except for DML. There's currently a bug
96+
# in the allocation planner when reusing buffers and more than one stream is used that
97+
# makes it possible (although rarely) to reach a reference count of 0 for a buffer that is
98+
# still being used. Since DML doesn't benefit from multiple streams, disabling it is the
99+
# safest option for now.
100+
# https://github.com/microsoft/onnxruntime/issues/19480
101+
if (NOT onnxruntime_MINIMAL_BUILD AND NOT onnxruntime_USE_DML)
97102
add_compile_definitions(ORT_ENABLE_STREAM)
98103
endif()
99104

onnxruntime/test/framework/allocation_planner_test.cc

+17-4
Original file line numberDiff line numberDiff line change
@@ -327,10 +327,23 @@ class PlannerTest : public ::testing::Test {
327327

328328
if (invoke_createPlan_explicityly) {
329329
onnxruntime::GraphViewer graph_viewer{graph_};
330-
status = SequentialPlanner::CreatePlan(nullptr, graph_viewer, outer_scope_node_args, execution_providers_,
331-
kernel_create_info_map, {}, {}, state_->GetOrtValueNameIdxMap(), test_context,
332-
MockStreamHandleRegsitry(), /* {{kCpuExecutionProvider, 1}}, {},*/
333-
ORT_TSTR(""), DefaultLoggingManager().DefaultLogger(), plan_);
330+
status = SequentialPlanner::CreatePlan(
331+
nullptr,
332+
graph_viewer,
333+
outer_scope_node_args,
334+
execution_providers_,
335+
kernel_create_info_map,
336+
{},
337+
{},
338+
state_->GetOrtValueNameIdxMap(),
339+
test_context,
340+
#ifdef ORT_ENABLE_STREAM
341+
MockStreamHandleRegsitry(),
342+
#endif
343+
/* {{kCpuExecutionProvider, 1}}, {},*/
344+
ORT_TSTR(""),
345+
DefaultLoggingManager().DefaultLogger(),
346+
plan_);
334347

335348
EXPECT_TRUE(status.IsOK()) << status.ErrorMessage();
336349
// AllocationPlanTestUtility::BasicIntegrityCheck(*plan_, name_to_arg_.size());

onnxruntime/test/framework/bfc_arena_test.cc

+2
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,7 @@ struct StreamMock : public Stream {
337337
Status CleanUpOnRunEnd() override { return Status::OK(); }
338338
};
339339

340+
#ifdef ORT_ENABLE_STREAM
340341
TEST(StreamAwareArenaTest, TwoStreamAllocation) {
341342
StreamAwareArena a(std::unique_ptr<IAllocator>(new CPUAllocator()), 1 << 30, false);
342343
CheckStats(&a, 0, 0, 0, 0);
@@ -413,6 +414,7 @@ TEST(StreamAwareArenaTest, TestSecureTheChunk) {
413414
EXPECT_TRUE(waitFunctionInvoked) << "wait function should be invoked";
414415
a.Free(p2);
415416
}
417+
#endif
416418

417419
TEST(BFCArenaTest, TestExtendStrategy) {
418420
int64_t extend_delta_bytes = 0;

onnxruntime/test/framework/execution_frame_test.cc

+50-5
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,16 @@ TEST_F(ExecutionFrameTest, TensorAllocationTest) {
7575
ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager));
7676

7777
vector<OrtValue> outputs;
78-
ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state);
78+
ExecutionFrame frame(
79+
{},
80+
{},
81+
{},
82+
outputs,
83+
{},
84+
#ifdef ORT_ENABLE_STREAM
85+
{},
86+
#endif
87+
state);
7988

8089
int start_index = frame.GetNodeOffset(node->Index());
8190
ASSERT_EQ(start_index, 0);
@@ -150,7 +159,16 @@ TEST_F(ExecutionFrameTest, OutputShapeValidationTest) {
150159
ASSERT_STATUS_OK(state.FinalizeSessionState(ORT_TSTR(""), kernel_registry_manager));
151160

152161
vector<OrtValue> outputs;
153-
ExecutionFrame frame({}, {}, {}, outputs, {}, {}, state);
162+
ExecutionFrame frame(
163+
{},
164+
{},
165+
{},
166+
outputs,
167+
{},
168+
#ifdef ORT_ENABLE_STREAM
169+
{},
170+
#endif
171+
state);
154172

155173
int start_index = frame.GetNodeOffset(node->Index());
156174
ASSERT_EQ(start_index, 0);
@@ -216,7 +234,16 @@ TEST_F(ExecutionFrameTest, FeedInDataTest) {
216234
ASSERT_TRUE(mlvalue_name_idx_map.GetIdx("Y", y_idx).IsOK());
217235

218236
vector<OrtValue> outputs;
219-
ExecutionFrame frame(AsSpan({x_idx}), AsSpan({value}), AsSpan({y_idx}), outputs, {}, {}, state);
237+
ExecutionFrame frame(
238+
AsSpan({x_idx}),
239+
AsSpan({value}),
240+
AsSpan({y_idx}),
241+
outputs,
242+
{},
243+
#ifdef ORT_ENABLE_STREAM
244+
{},
245+
#endif
246+
state);
220247

221248
OrtValue* p_ml_value = frame.GetMutableNodeInputOrOutputMLValue(0);
222249
Tensor* p_tensor_arg_0 = p_ml_value ? p_ml_value->GetMutable<Tensor>() : nullptr;
@@ -299,7 +326,16 @@ TEST_F(ExecutionFrameTest, MemPatternTest) {
299326
std::vector<float>(6, 1.0f), &v3);
300327

301328
std::vector<OrtValue> outputs;
302-
ExecutionFrame frame(AsSpan({x1_idx, x2_idx, x3_idx}), AsSpan({v1, v2, v3}), AsSpan({t3_idx}), outputs, {}, {}, state);
329+
ExecutionFrame frame(
330+
AsSpan({x1_idx, x2_idx, x3_idx}),
331+
AsSpan({v1, v2, v3}),
332+
AsSpan({t3_idx}),
333+
outputs,
334+
{},
335+
#ifdef ORT_ENABLE_STREAM
336+
{},
337+
#endif
338+
state);
303339

304340
OrtValue& mlvalue3 = *frame.GetMutableNodeInputOrOutputMLValue(3);
305341
OrtValue& mlvalue4 = *frame.GetMutableNodeInputOrOutputMLValue(4);
@@ -388,7 +424,16 @@ TEST_F(ExecutionFrameTest, MemPatternWithExternalOutputsTest) {
388424
CreateMLValue<float>(cpu_allocator, std::vector<int64_t>{2, 2}, std::vector<float>(4, 1.0f), &t_value);
389425

390426
vector<OrtValue> outputs;
391-
ExecutionFrame frame(AsSpan({x_idx}), AsSpan({x_value}), AsSpan({y_idx}), outputs, {}, {}, state);
427+
ExecutionFrame frame(
428+
AsSpan({x_idx}),
429+
AsSpan({x_value}),
430+
AsSpan({y_idx}),
431+
outputs,
432+
{},
433+
#ifdef ORT_ENABLE_STREAM
434+
{},
435+
#endif
436+
state);
392437

393438
ASSERT_FALSE(frame.GetMutableNodeInputOrOutputMLValue(t_idx)->IsTensor());
394439
ASSERT_STATUS_OK(frame.SetOutputMLValue(t_idx, t_value));

0 commit comments

Comments
 (0)