@@ -29,7 +29,7 @@ class GemmOpBuilder : public BaseOpBuilder {
29
29
30
30
// Add operator related.
31
31
Status GemmOpBuilder::AddToModelBuilderImpl (ModelBuilder& model_builder, const Node& node,
32
- const logging::Logger& /* logger */ ) const {
32
+ const logging::Logger& logger) const {
33
33
const auto & op_type = node.OpType ();
34
34
const auto & input_defs = node.InputDefs ();
35
35
const size_t a_idx = 0 , b_idx = 1 , c_idx = 2 ; // A*B+C
@@ -38,7 +38,17 @@ Status GemmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
38
38
emscripten::val b = model_builder.GetOperand (node.InputDefs ()[b_idx]->Name ());
39
39
emscripten::val output = emscripten::val::object ();
40
40
if (op_type == " MatMul" ) {
41
- output = model_builder.GetBuilder ().call <emscripten::val>(" matmul" , a, b);
41
+ std::vector<int64_t > a_shape;
42
+ if (!GetShape (*input_defs[a_idx], a_shape, logger)) {
43
+ return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT, " Can not get shape of A." );
44
+ }
45
+ // The inputs of MatMul must be at least 3D for WebNN CPU backend. Use GEMM for 2D case.
46
+ // TODO: Remove this workaround when it is fixed in Chromium.
47
+ if (model_builder.GetWebnnDeviceType () == WebnnDeviceType::CPU && a_shape.size () == 2 ) {
48
+ output = model_builder.GetBuilder ().call <emscripten::val>(" gemm" , a, b);
49
+ } else {
50
+ output = model_builder.GetBuilder ().call <emscripten::val>(" matmul" , a, b);
51
+ }
42
52
} else if (op_type == " MatMulInteger" ) {
43
53
emscripten::val a_zero_point = emscripten::val::null ();
44
54
emscripten::val b_zero_point = emscripten::val::null ();
0 commit comments