Skip to content

Commit 2e0a388

Browse files
authored
[js/webgpu] Add HardSigmoid support (#19215)
### Description This op is required in mobilenetv3-small-100. With this PR, mobilenetv3-small-100 model becomes less than 10 ms from over 100 ms on ADL.
1 parent e283cdb commit 2e0a388

File tree

6 files changed

+30
-3
lines changed

6 files changed

+30
-3
lines changed

js/web/docs/webgpu-operators.md

+1
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ Do not modify directly.*
5252
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
5353
| Greater | ai.onnx(7-8,9-12,13+) | |
5454
| GreaterOrEqual | ai.onnx(12-15,16+) | |
55+
| HardSigmoid | ai.onnx(6+) | |
5556
| If | ai.onnx(1-10,11-12,13-18,19+) | |
5657
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
5758
| LayerNormalization | ai.onnx(17+) | |

js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts

+1
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
8282
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
8383
['Greater', [binaryOps.greater]],
8484
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
85+
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
8586
['InstanceNormalization', [instanceNorm]],
8687
['LayerNormalization', [layerNorm]],
8788
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],

js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts

+20
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,26 @@ export const sigmoid = (context: ComputeContext): void => {
242242
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`));
243243
};
244244

245+
export interface HardSigmoidAttributes extends AttributeWithCacheKey {
246+
readonly alpha: number;
247+
readonly beta: number;
248+
}
249+
250+
export const parseHardSigmoidAttributes = (attributes: Record<string, unknown>): HardSigmoidAttributes =>
251+
createAttributeWithCacheKey(attributes as {
252+
alpha: number;
253+
beta: number;
254+
});
255+
256+
export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => {
257+
const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
258+
context.compute(createElementwiseProgramInfo(
259+
context.inputs[0], 'HardSigmoid',
260+
a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${
261+
attributes.beta})))`,
262+
undefined, attributes.cacheKey));
263+
};
264+
245265
export const sin = (context: ComputeContext): void => {
246266
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin'));
247267
};

js/web/test/suite-test-list.jsonc

+3-3
Original file line numberDiff line numberDiff line change
@@ -597,9 +597,9 @@
597597
// // "test_hardmax_example",
598598
// // "test_hardmax_negative_axis",
599599
// // "test_hardmax_one_hot",
600-
// // "test_hardsigmoid_default",
601-
// // "test_hardsigmoid_example",
602-
// // "test_hardsigmoid",
600+
"test_hardsigmoid_default",
601+
"test_hardsigmoid_example",
602+
"test_hardsigmoid",
603603
// // "test_hardswish_expanded",
604604
// // "test_hardswish",
605605
"test_if",

onnxruntime/core/providers/js/js_execution_provider.cc

+2
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai
9898
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Erf);
9999
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Sigmoid);
100100
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sigmoid);
101+
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, HardSigmoid);
101102
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Log);
102103
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Log);
103104

@@ -392,6 +393,7 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
392393
KERNEL_CREATE_INFO(13, Erf),
393394
KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid),
394395
KERNEL_CREATE_INFO(13, Sigmoid),
396+
KERNEL_CREATE_INFO(6, HardSigmoid),
395397
KERNEL_CREATE_INFO_VERSIONED(6, 12, Log),
396398
KERNEL_CREATE_INFO(13, Log),
397399

onnxruntime/core/providers/js/operators/unary.cc

+3
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ JSEP_KERNEL_IMPL(Sigmoid, Sigmoid)
7777
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid)
7878
JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid)
7979

80+
JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(HardSigmoid, HardSigmoid, alpha, 0.2, beta, 0.5)
81+
JSEP_ELEMENTWISE_KERNEL(HardSigmoid, 6, HardSigmoid)
82+
8083
JSEP_KERNEL_IMPL(Log, Log)
8184
JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log)
8285
JSEP_ELEMENTWISE_KERNEL(Log, 13, Log)

0 commit comments

Comments
 (0)