Skip to content

Commit ad63507

Browse files
PatriceVignola authored and rachguo committed
[DML EP] Fix external data unpacking (#19415)
### Description
This change (55a6694) didn't take external data into account when unpacking initializers, and therefore crashes when trying to unpack them.
1 parent 6e61306 commit ad63507

File tree

3 files changed

+20
-8
lines changed

3 files changed

+20
-8
lines changed

onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp

+10-5
Original file line numberDiff line numberDiff line change
@@ -344,20 +344,25 @@ namespace Dml::GraphDescBuilder
344344
dmlFusedNodeInputIndex < isConstGpuGraphInputCount &&
345345
isConstGpuGraphInput[dmlFusedNodeInputIndex])
346346
{
347-
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
348-
// across the graph input as well as every consumer's unique constant node. However it is currently
347+
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
348+
// across the graph input as well as every consumer's unique constant node. However it is currently
349349
// only used for small inputs.
350350
uint32_t c_maxConstNodeDataSize = 8;
351351

352-
ComPtr<OnnxTensorWrapper> constantInput = constantCpuGraphInputGetter(arg->Name());
353352

354353
auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex];
355354
std::vector<DmlBufferTensorDesc*> toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors();
356355
DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex];
356+
ComPtr<OnnxTensorWrapper> constantInput;
357357

358-
if (constantInput && tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize)
358+
if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize)
359359
{
360-
// The tensor description's size should be no larger than the constant input unless it was rounded to
360+
constantInput = constantCpuGraphInputGetter(arg->Name());
361+
}
362+
363+
if (constantInput)
364+
{
365+
// The tensor description's size should be no larger than the constant input unless it was rounded to
361366
// the required alignment.
362367
assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes);
363368
size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast<size_t>(tensorDesc->totalTensorSizeInBytes));

onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp

+9-3
Original file line numberDiff line numberDiff line change
@@ -1123,7 +1123,7 @@ namespace Windows::AI::MachineLearning::Adapter
11231123
}
11241124
ORT_CATCH_RETURN
11251125
}
1126-
1126+
11271127
template <class NodeInfoImpl_t, class Base1_t, class Base2_t>
11281128
HRESULT STDMETHODCALLTYPE OpNodeInfoWrapper<NodeInfoImpl_t, Base1_t, Base2_t>::GetConstantInputTensor(uint32_t inputIndex, IMLOperatorTensor** tensor) const noexcept
11291129
{
@@ -1168,7 +1168,7 @@ namespace Windows::AI::MachineLearning::Adapter
11681168
m_requiredConstantCpuInputs.begin(),
11691169
m_requiredConstantCpuInputs.end(),
11701170
inputIndex) != m_requiredConstantCpuInputs.end();
1171-
1171+
11721172
// This shouldn't happen since kernel creation is deferred and repeated when required constant inputs are not present.
11731173
ORT_THROW_HR_IF(E_UNEXPECTED, inputRequiredAsConstant);
11741174
}
@@ -1562,7 +1562,13 @@ namespace Windows::AI::MachineLearning::Adapter
15621562
OnnxTensorWrapper::OnnxTensorWrapper(onnx::TensorProto* impl, const onnxruntime::Path& modelPath) : m_impl(impl)
15631563
{
15641564
// The tensor may be stored as raw data or in typed fields.
1565-
if (impl->has_raw_data())
1565+
if (impl->data_location() == onnx::TensorProto_DataLocation_EXTERNAL)
1566+
{
1567+
THROW_IF_NOT_OK(onnxruntime::utils::UnpackInitializerData(*impl, modelPath, m_unpackedExternalTensor));
1568+
m_dataPtr = reinterpret_cast<std::byte*>(m_unpackedExternalTensor.data());
1569+
m_tensorByteSize = m_unpackedExternalTensor.size();
1570+
}
1571+
else if (impl->has_raw_data())
15661572
{
15671573
m_dataPtr = reinterpret_cast<std::byte*>(impl->mutable_raw_data()->data());
15681574
m_tensorByteSize = impl->raw_data().size();

onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h

+1
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ class OnnxTensorWrapper : public WRL::Base<IMLOperatorTensor>, public Closable
309309
private:
310310
size_t m_tensorByteSize = 0;
311311
std::unique_ptr<std::byte[]> m_unpackedTensor;
312+
std::vector<uint8_t> m_unpackedExternalTensor;
312313
std::byte* m_dataPtr = nullptr;
313314

314315
// Lifetime is managed by the caller and guaranteed to outlive this class

0 commit comments

Comments
 (0)