Skip to content

Commit 00ae32d

Browse files
satyajandhyala authored and guschmue committed
[WebGPU-EP Native] Add ReduceMean (#23860)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent 88d6af5 commit 00ae32d

File tree

3 files changed

+234
-4
lines changed

3 files changed

+234
-4
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "core/providers/webgpu/reduction/reduction_ops.h"
5+
#include <sstream>
6+
#include "core/framework/data_transfer_manager.h"
7+
#include "core/providers/webgpu/data_transfer.h"
8+
#include "core/providers/webgpu/shader_helper.h"
9+
#include "core/providers/webgpu/webgpu_supported_types.h"
10+
11+
namespace onnxruntime {
12+
namespace webgpu {
13+
14+
#define REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceOp, begin, end) \
15+
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
16+
ReduceOp, \
17+
kOnnxDomain, \
18+
begin, end, \
19+
kWebGpuExecutionProvider, \
20+
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()), \
21+
ReduceOp);
22+
23+
#define REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceOp, version) \
24+
ONNX_OPERATOR_KERNEL_EX( \
25+
ReduceOp, \
26+
kOnnxDomain, \
27+
version, \
28+
kWebGpuExecutionProvider, \
29+
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedNumberTypes()).InputMemoryType(OrtMemTypeCPUInput, 1), \
30+
ReduceOp);
31+
32+
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 1, 10);
33+
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12);
34+
REGISTER_UNARY_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 13, 17);
35+
REGISTER_UNARY_ELEMENTWISE_KERNEL(ReduceMean, 18);
36+
37+
// Generates the WGSL for a reduction. Each invocation computes one output element:
// the code wraps the op-specific loop body in nested `for` loops over every reduced
// axis, while the kept axes are fixed from this invocation's output indices.
Status ReduceKernelProgram::GenerateShaderCode(ShaderHelper& shader) const {
  const auto& input = shader.AddInput("input", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  const auto& output = shader.AddOutput("output", ShaderUsage::UseUniform | ShaderUsage::UseIndicesTypeAlias | ShaderUsage::UseValueTypeAlias);
  // Empty axes with no_op_with_empty_axes == false means "reduce over all axes".
  bool reduce_on_all_axes = no_op_with_empty_axes_ == false && axes_.empty();
  // code_[0]/[1]/[2] are the op-specific loop header / body / footer snippets.
  std::string loop_header = code_[0];
  std::string loop_body = "let current_element: input_value_t = " + input.GetByIndices("input_indices") + ";\n" + code_[1];
  std::string loop_footer = code_[2];
  const auto input_rank = input.Rank();
  // i walks the input axes; l tracks the matching output axis. Reduced axes are
  // absent from the output unless keepdims_ keeps them as size-1 dims (then l still
  // advances).
  for (int i = 0, l = 0; i < input_rank; ++i) {
    if (reduce_on_all_axes || std::find(axes_.begin(), axes_.end(), i) != axes_.end()) {
      // Reduced axis: wrap the accumulated loop body in one more `for` loop.
      if (keepdims_) {
        l++;
      }
      std::stringstream ss;
      std::string index = "i" + std::to_string(i);
      ss << "for (var " << index << " : u32 = 0; " << index << " < " << input.IndicesGet("uniforms.input_shape", i) << "; " << index << "++) {\n";
      ss << input.IndicesSet("input_indices", i, index) << ";\n";
      ss << loop_body << "\n";
      ss << "}\n";
      loop_body = ss.str();
    } else {
      // Kept axis: its input index comes straight from the output indices and is
      // set once in the loop header, before any reduction loop runs.
      std::stringstream ss;
      ss << loop_header << "\n";
      std::string index = "i" + std::to_string(i);
      ss << "let " << index << " = " << output.IndicesGet("output_indices", l) << ";\n";
      ss << input.IndicesSet("input_indices", i, index) << ";\n";
      loop_header = ss.str();
      l++;
    }
  }
  // Build the "0, 0, ..., 0" initializer for input_indices. NOTE(review): for a
  // rank-0 input this still emits a single "0" — presumably input_indices_t degrades
  // to a scalar there; confirm against ShaderHelper's indices-type generation.
  std::stringstream input_indices_init_value;
  for (int i = 0; i < input_rank - 1; ++i) {
    input_indices_init_value << "0, ";
  }
  input_indices_init_value << "0";
  shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
                            << "let output_indices: output_indices_t = " << output.OffsetToIndices("global_idx") << ";\n"
                            << "var input_indices: input_indices_t = input_indices_t(" << input_indices_init_value.str() << ");\n"
                            << loop_header << loop_body << loop_footer;
  // The loop footer is required to have defined `output_value`.
  shader.MainFunctionBody() << output.SetByOffset("global_idx", "output_value");
  return Status::OK();
}
79+
80+
template <bool allow_multi_axes>
81+
Status ReduceKernel<allow_multi_axes>::ComputeInternal(ComputeContext& context) const {
82+
const auto* input_tensor = context.Input(0);
83+
InlinedVector<uint32_t> input_axes;
84+
auto rank = input_tensor->Shape().NumDimensions();
85+
auto transform_axis = [rank](int64_t axis) {
86+
if (axis < 0) {
87+
axis += rank;
88+
}
89+
if (axis < 0 || static_cast<size_t>(axis) >= rank) {
90+
ORT_THROW("Axes values must be in the range [-rank, rank-1]. Got: ", axis);
91+
}
92+
return static_cast<uint32_t>(axis);
93+
};
94+
// Check if axes input is provided and copy the axes values to input_axes
95+
if (context.InputCount() > 1) {
96+
ORT_ENFORCE(axes_.empty(), "Axes attribute may not be specified when axes input is also provided.");
97+
const Tensor* axes_tensor = context.Input<Tensor>(1);
98+
auto size = static_cast<size_t>(axes_tensor->Shape()[0]);
99+
const auto* data = axes_tensor->Data<int64_t>();
100+
input_axes.reserve(size);
101+
std::transform(data, data + size, std::back_inserter(input_axes), transform_axis);
102+
} else {
103+
input_axes.reserve(axes_.size());
104+
std::transform(axes_.begin(), axes_.end(), std::back_inserter(input_axes), transform_axis);
105+
}
106+
if (input_axes.empty()) {
107+
if (noop_with_empty_axes_ || rank == 0) {
108+
// If axes is empty and noop_with_empty_axes_ is true, it is a no-op according to the spec
109+
// If input tensor is a scalar, return the input tensor as is.
110+
// This is not correct for ReduceLogSum and ReduceSumSquare
111+
// TODO handle these cases separately.
112+
auto output = context.Output(0, input_tensor->Shape());
113+
if (output->DataRaw() != input_tensor->DataRaw()) {
114+
ORT_RETURN_IF_ERROR(Info().GetDataTransferManager().CopyTensor(*input_tensor, *output));
115+
}
116+
return Status::OK();
117+
} else {
118+
// If axes is empty and noop_with_empty_axes_ is false, it is a reduction over all axes
119+
input_axes.resize(rank);
120+
std::iota(input_axes.begin(), input_axes.end(), 0);
121+
}
122+
}
123+
const auto code = GetOpSpecificCode(input_tensor, input_axes.size());
124+
// Compute output shape
125+
std::vector<int64_t> output_shape;
126+
for (size_t i = 0; i < input_tensor->Shape().NumDimensions(); ++i) {
127+
if (std::find(input_axes.begin(), input_axes.end(), i) != input_axes.end()) {
128+
if (keepdims_) {
129+
output_shape.push_back(1);
130+
}
131+
} else {
132+
output_shape.push_back(input_tensor->Shape()[i]);
133+
}
134+
}
135+
TensorShape output_tensor_shape(output_shape);
136+
int64_t output_size = output_tensor_shape.Size();
137+
ReduceKernelProgram program("ReduceMean", keepdims_, noop_with_empty_axes_, input_axes, code);
138+
program.AddInput({input_tensor, ProgramTensorMetadataDependency::TypeAndRank})
139+
.AddOutput({context.Output(0, output_shape), ProgramTensorMetadataDependency::TypeAndRank})
140+
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
141+
.AddUniformVariables({{static_cast<uint32_t>(output_size)},
142+
{static_cast<uint32_t>(noop_with_empty_axes_ ? 1 : 0)},
143+
{input_axes},
144+
{static_cast<uint32_t>(input_axes.size())}});
145+
146+
return context.RunProgram(program);
147+
}
148+
149+
// Supplies the WGSL fragments specific to ReduceMean: accumulate a running f32 sum
// over the reduced elements, then divide by the reduced element count in the footer.
ReduceOpSpecificCode ReduceMean::GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const {
  const TensorShape& shape = input_tensor->Shape();
  size_t rank = shape.NumDimensions();
  // Footer: the element count is computed at runtime from the axes uniform so the
  // same shader works for any set of reduced axes of the same rank.
  std::string footer = "var size: u32 = 1;\n";
  footer += "for (var i: u32 = 0; i < uniforms.axes_size; i += 1) { \n";
  footer += "  let index = " + GetElementAt("uniforms.axes", "i", axes_size) + ";\n";
  footer += "  size = size * " + GetElementAt("uniforms.input_shape", "index", rank) + ";\n";
  footer += "}\n";
  footer += "let output_value = output_value_t(sum / f32(size));";
  // {loop header, loop body, loop footer}
  return ReduceOpSpecificCode{"var sum = f32(0);", "sum += f32(current_element);", footer};
}
162+
163+
// ReduceMean delegates to the shared multi-axis reduction implementation; the
// mean-specific WGSL snippets are injected via GetOpSpecificCode().
Status ReduceMean::ComputeInternal(ComputeContext& ctx) const {
  return ReduceKernel<true>::ComputeInternal(ctx);
}

}  // namespace webgpu
}  // namespace onnxruntime
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
#include "core/common/optional.h"
6+
#include "core/providers/webgpu/webgpu_supported_types.h"
7+
#include "core/providers/webgpu/webgpu_kernel.h"
8+
#include "core/providers/cpu/reduction/reduction_kernel_base.h"
9+
#include "core/providers/webgpu/program.h"
10+
#include "core/providers/webgpu/shader_helper.h"
11+
namespace onnxruntime {
12+
namespace webgpu {
13+
// reduceOpSpecificCode is a 3-element array of strings that represent the op specific code for the reduce operation.
14+
// The first element is the loop header, the second element is the loop body, and the third element is the loop footer.
15+
// The loop header is the code that is executed before the loop starts. The loop body is the code that is executed for each element in the loop.
16+
// The loop footer is the code that is executed after the loop ends.
17+
typedef std::array<std::string, 3> ReduceOpSpecificCode;
18+
class ReduceKernelProgram final : public Program<ReduceKernelProgram> {
19+
public:
20+
ReduceKernelProgram(std::string name, bool keepdims, bool no_op_with_empty_axes, const InlinedVector<uint32_t>& axes, ReduceOpSpecificCode code) : Program{name}, keepdims_(keepdims), no_op_with_empty_axes_(no_op_with_empty_axes), axes_(axes.begin(), axes.end()), code_(code) {}
21+
Status GenerateShaderCode(ShaderHelper& wgpuShaderModuleAddRef) const override;
22+
WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
23+
{"no_op_with_empty_axes", ProgramUniformVariableDataType::Uint32},
24+
{"axes", ProgramUniformVariableDataType::Uint32},
25+
{"axes_size", ProgramUniformVariableDataType::Uint32});
26+
27+
private:
28+
const bool keepdims_;
29+
const bool no_op_with_empty_axes_;
30+
InlinedVector<uint32_t> axes_;
31+
ReduceOpSpecificCode code_;
32+
};
33+
34+
template <bool allow_multi_axes = true>
35+
class ReduceKernel : public WebGpuKernel, public ReduceKernelBase<allow_multi_axes> {
36+
protected:
37+
using ReduceKernelBase<allow_multi_axes>::axes_;
38+
using ReduceKernelBase<allow_multi_axes>::noop_with_empty_axes_;
39+
using ReduceKernelBase<allow_multi_axes>::keepdims_;
40+
using ReduceKernelBase<allow_multi_axes>::select_last_index_;
41+
42+
ReduceKernel(const OpKernelInfo& info, std::string name, optional<int64_t> keepdims_override = {})
43+
: WebGpuKernel(info),
44+
ReduceKernelBase<allow_multi_axes>(info, keepdims_override),
45+
name_(name) {
46+
}
47+
Status ComputeInternal(ComputeContext& ctx) const;
48+
virtual ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const = 0;
49+
50+
private:
51+
std::string name_;
52+
};
53+
54+
// ReduceMean: arithmetic mean over the selected axes (ONNX ReduceMean, opsets 1-18+).
class ReduceMean final : public ReduceKernel<true> {
 public:
  ReduceMean(const OpKernelInfo& info) : ReduceKernel<true>(info, "ReduceMean") {}
  ReduceOpSpecificCode GetOpSpecificCode(const Tensor* input_tensor, size_t axes_size) const override;
  Status ComputeInternal(ComputeContext& ctx) const override;
};

}  // namespace webgpu
}  // namespace onnxruntime

onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -516,10 +516,10 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
516516
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMax)>,
517517
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMax)>,
518518

519-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMean)>,
520-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceMean)>,
521-
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMean)>,
522-
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMean)>,
519+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, ReduceMean)>,
520+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, ReduceMean)>,
521+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, 17, ReduceMean)>,
522+
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 18, ReduceMean)>,
523523

524524
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze)>,
525525
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze)>,

0 commit comments

Comments
 (0)