Skip to content

Commit 2e4d1b8

Browse files
author
zesongw
authored
[WebNN EP] Add support for Op MatMul of WebNN CPU backend (#19413)
Enable MatMul support for WebNN CPU backend to support more models.
1 parent 1c468a0 commit 2e4d1b8

File tree

2 files changed

+13
-3
lines changed

2 files changed

+13
-3
lines changed

Diff for: onnxruntime/core/providers/webnn/builders/helper.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
195195
{"LessOrEqual", {"lesserOrEqual", false}},
196196
{"Log", {"log", false}},
197197
{"LpPool", {"l2Pool2d", false}},
198-
{"MatMul", {"matmul", false}},
198+
{"MatMul", {"matmul", true}},
199199
{"MatMulInteger", {"matmulInteger", false}},
200200
{"Max", {"max", true}},
201201
{"MaxPool", {"maxPool2d", true}},

Diff for: onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc

+12-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class GemmOpBuilder : public BaseOpBuilder {
2929

3030
// Add operator related.
3131
Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
32-
const logging::Logger& /* logger */) const {
32+
const logging::Logger& logger) const {
3333
const auto& op_type = node.OpType();
3434
const auto& input_defs = node.InputDefs();
3535
const size_t a_idx = 0, b_idx = 1, c_idx = 2; // A*B+C
@@ -38,7 +38,17 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
3838
emscripten::val b = model_builder.GetOperand(node.InputDefs()[b_idx]->Name());
3939
emscripten::val output = emscripten::val::object();
4040
if (op_type == "MatMul") {
41-
output = model_builder.GetBuilder().call<emscripten::val>("matmul", a, b);
41+
std::vector<int64_t> a_shape;
42+
if (!GetShape(*input_defs[a_idx], a_shape, logger)) {
43+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Can not get shape of A.");
44+
}
45+
// The inputs of MatMul must be at least 3D for WebNN CPU backend. Use GEMM for 2D case.
46+
// TODO: Remove this workaround when it is fixed in Chromium.
47+
if (model_builder.GetWebnnDeviceType() == WebnnDeviceType::CPU && a_shape.size() == 2) {
48+
output = model_builder.GetBuilder().call<emscripten::val>("gemm", a, b);
49+
} else {
50+
output = model_builder.GetBuilder().call<emscripten::val>("matmul", a, b);
51+
}
4252
} else if (op_type == "MatMulInteger") {
4353
emscripten::val a_zero_point = emscripten::val::null();
4454
emscripten::val b_zero_point = emscripten::val::null();

0 commit comments

Comments
 (0)