Skip to content

Commit b5ee4ac

Browse files
authored
[js/webgpu] support GridSample operator (#22652)
### Description <!-- Describe your changes. --> ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. -->
1 parent d9b9168 commit b5ee4ac

File tree

7 files changed

+358
-8
lines changed

7 files changed

+358
-8
lines changed

js/web/docs/webgpu-operators.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ Do not modify directly.*
5656
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
5757
| Greater | ai.onnx(7-8,9-12,13+) | |
5858
| GreaterOrEqual | ai.onnx(12-15,16+) | |
59+
| GridSample | ai.onnx(16-19); com.ms.internal.nhwc(16-19) | |
5960
| GroupQueryAttention | com.microsoft(1+) | |
6061
| HardSigmoid | ai.onnx(6+) | |
6162
| If | ai.onnx(1-10,11-12,13-18,19-20,21+) | |

js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import { gather, parseGatherAttributes } from './ops/gather';
1919
import { gatherBlockQuantized, parseGatherBlockQuantizedAttributes } from './ops/gather-block-quantized';
2020
import { gatherElements, parseGatherElementsAttributes } from './ops/gather-elements';
2121
import { gemm, parseGemmAttributes } from './ops/gemm';
22+
import { gridSample, parseGridSampleAttributes } from './ops/grid-sample';
2223
import { groupQueryAttention } from './ops/group-query-attention';
2324
import { instanceNorm } from './ops/instance-norm';
2425
import { layerNorm } from './ops/layer-norm';
@@ -104,6 +105,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
104105
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
105106
['Greater', [binaryOps.greater]],
106107
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
108+
['GridSample', [gridSample, parseGridSampleAttributes]],
107109
['GroupQueryAttention', [groupQueryAttention]],
108110
['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
109111
['InstanceNormalization', [instanceNorm]],
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
import { DataType } from '../../../wasm-common';
5+
import { TensorView } from '../../tensor-view';
6+
import { ShapeUtil } from '../../util';
7+
import { AttributeWithCacheKey, createAttributeWithCacheKey } from '../attribute-with-cache-key';
8+
import { ComputeContext, ProgramInfo, ProgramUniform } from '../types';
9+
10+
import { createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper } from './common';
11+
12+
let [idxN, idxC, idxH, idxW] = [0, 1, 2, 3]; // NCHW
13+
type Mode = 'bilinear' | 'nearest' | 'bicubic';
14+
type PaddingMode = 'zeros' | 'border' | 'reflection';
15+
type Format = 'NHWC' | 'NCHW';
16+
export interface GridSampeAttributes extends AttributeWithCacheKey {
17+
alignCorners: number;
18+
mode: Mode;
19+
paddingMode: PaddingMode;
20+
format: Format;
21+
}
22+
23+
const validateInputs = (inputs: readonly TensorView[]): void => {
24+
if (inputs[0].dims.length !== 4) {
25+
throw new Error('only 4-D tensor is supported.');
26+
}
27+
if (inputs[0].dims.length !== inputs[1].dims.length) {
28+
throw new Error('input dimensions must be equal to grid dimensions');
29+
}
30+
31+
if (inputs[0].dims.length - 2 !== inputs[1].dims[inputs[1].dims.length - 1]) {
32+
throw new Error(`last dimension of grid must be equal to ${inputs[0].dims.length - 2}`);
33+
}
34+
35+
if (inputs[0].dims[0] !== inputs[1].dims[0]) {
36+
throw new Error('grid batch size must match input batch size');
37+
}
38+
};
39+
40+
const gsGetCubicCoeffs = `
41+
fn gs_get_cubic_coeffs(x: f32) -> vec4<f32> {
42+
let cubic_alpha = -0.75f;
43+
let x_abs = abs(x);
44+
var coeffs: vec4<f32>;
45+
coeffs[0] = (((cubic_alpha * (x_abs + 1) - 5 * cubic_alpha) * (x_abs + 1) + 8 * cubic_alpha) * (x_abs + 1) - 4 * cubic_alpha);
46+
coeffs[1] = (((cubic_alpha + 2) * x_abs - (cubic_alpha + 3)) * x_abs * x_abs + 1);
47+
coeffs[2] = (((cubic_alpha + 2) * (1 - x_abs) - (cubic_alpha + 3)) * (1 - x_abs) * (1 - x_abs) + 1);
48+
coeffs[3] = (((cubic_alpha * (2 - x_abs) - 5 * cubic_alpha) * (2 - x_abs) + 8 * cubic_alpha) * (2 - x_abs) - 4 * cubic_alpha);
49+
return coeffs;
50+
}
51+
`;
52+
53+
const gsBicubicInterpolate = (dataType: string): string => `
54+
fn gs_bicubic_interpolate(p: mat4x4<${dataType}>, x: f32, y: f32) -> ${dataType} {
55+
var v: vec4<f32>;
56+
var coeffs = gs_get_cubic_coeffs(x);
57+
for (var i = 0; i < 4; i++) {
58+
v[i] = coeffs[0] * p[i][0] + coeffs[1] * p[i][1] + coeffs[2] * p[i][2] + coeffs[3] * p[i][3];
59+
}
60+
coeffs = gs_get_cubic_coeffs(y);
61+
let pixel = ${dataType}(coeffs[0] * v[0] + coeffs[1] * v[1] + coeffs[2] * v[2] + coeffs[3] * v[3]);
62+
return pixel;
63+
}
64+
`;
65+
66+
const gsDenormalize = (attributes: GridSampeAttributes): string => `
67+
fn gs_denormalize(n: f32, length: i32) -> f32 {
68+
${
69+
attributes.alignCorners === 0
70+
? `
71+
// alignCorners: false => [-1, 1] to [-0.5, length - 0.5]
72+
return ((n + 1.0) * f32(length) - 1.0) / 2.0;
73+
`
74+
: `
75+
// alignCorners: true => [-1, 1] to [0, length - 1]
76+
return (n + 1.0) / 2.0 * (f32(length - 1));
77+
`
78+
}
79+
}
80+
`;
81+
82+
const gsReflect = (attributes: GridSampeAttributes): string => `
83+
${
84+
attributes.paddingMode === 'reflection'
85+
? `
86+
fn gs_reflect(x: i32, x_min: f32, x_max: f32) -> u32 {
87+
var dx = 0.0;
88+
var fx = f32(x);
89+
let range = x_max - x_min;
90+
if (fx < x_min) {
91+
dx = x_min - fx;
92+
let n = u32(dx / range);
93+
let r = dx - f32(n) * range;
94+
if (n % 2 == 0) {
95+
fx = x_min + r;
96+
} else {
97+
fx = x_max - r;
98+
}
99+
} else if (fx > x_max) {
100+
dx = fx - x_max;
101+
let n = u32(dx / range);
102+
let r = dx - f32(n) * range;
103+
if (n % 2 == 0) {
104+
fx = x_max - r;
105+
} else {
106+
fx = x_min + r;
107+
}
108+
}
109+
return u32(fx);
110+
}`
111+
: ''
112+
}
113+
`;
114+
115+
const pixelAtGrid = (input: IndicesHelper, dataType: string, attributes: GridSampeAttributes): string =>
116+
`
117+
fn pixel_at_grid(r: i32, c: i32, H: i32, W: i32, batch: u32, channel: u32, border: vec4<f32>) -> ${dataType} {
118+
var pixel = ${dataType}(0);
119+
var indices = vec4<u32>(0);
120+
indices[${idxN}] = batch;
121+
indices[${idxC}] = channel;` +
122+
(() => {
123+
switch (attributes.paddingMode) {
124+
case 'zeros':
125+
return `
126+
if (r >= 0 && r < H && c >=0 && c < W) {
127+
indices[${idxH}] = u32(r);
128+
indices[${idxW}] = u32(c);
129+
}
130+
`;
131+
case 'border':
132+
return `
133+
indices[${idxH}] = u32(clamp(r, 0, H - 1));
134+
indices[${idxW}] = u32(clamp(c, 0, W - 1));
135+
`;
136+
case 'reflection':
137+
return `
138+
indices[${idxH}] = gs_reflect(r, border[1], border[3]);
139+
indices[${idxW}] = gs_reflect(c, border[0], border[2]);
140+
`;
141+
default:
142+
throw new Error(`padding mode ${attributes.paddingMode} is not supported`);
143+
}
144+
})() +
145+
`
146+
return ${input.getByIndices('indices')};
147+
}
148+
`;
149+
150+
const computePixel = (output: IndicesHelper, dataType: string, attributes: GridSampeAttributes): string =>
151+
(() => {
152+
switch (attributes.mode) {
153+
case 'nearest':
154+
return `
155+
let result = pixel_at_grid(i32(round(y)), i32(round(x)), H_in, W_in, indices[${idxN}], indices[${idxC}], border);
156+
`;
157+
case 'bilinear':
158+
return `
159+
let x1 = i32(floor(x));
160+
let y1 = i32(floor(y));
161+
let x2 = x1 + 1;
162+
let y2 = y1 + 1;
163+
164+
let p11 = pixel_at_grid(y1, x1, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
165+
let p12 = pixel_at_grid(y1, x2, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
166+
let p21 = pixel_at_grid(y2, x1, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
167+
let p22 = pixel_at_grid(y2, x2, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
168+
169+
let dx2 = ${dataType}(f32(x2) - x);
170+
let dx1 = ${dataType}(x - f32(x1));
171+
let dy2 = ${dataType}(f32(y2) - y);
172+
let dy1 = ${dataType}(y - f32(y1));
173+
let result = dy2 * (dx2 * p11 + dx1 * p12) + dy1 * (dx2 * p21 + dx1 * p22);
174+
`;
175+
case 'bicubic':
176+
return `
177+
let x0 = i32(floor(x)) - 1;
178+
let y0 = i32(floor(y)) - 1;
179+
var p: mat4x4<${dataType}>;
180+
for (var h = 0; h < 4; h++) {
181+
for (var w = 0; w < 4; w++) {
182+
p[h][w] = pixel_at_grid(h + y0, w + x0, H_in, W_in, indices[${idxN}], indices[${idxC}], border);
183+
}
184+
}
185+
186+
let dx = x - f32(x0 + 1);
187+
let dy = y - f32(y0 + 1);
188+
let result = gs_bicubic_interpolate(p, dx, dy);
189+
`;
190+
default:
191+
throw new Error(`mode ${attributes.mode} is not supported`);
192+
}
193+
})() + `${output.setByOffset('global_idx', 'result')}`;
194+
195+
const createGridSampleProgramInfo = (inputs: readonly TensorView[], attributes: GridSampeAttributes): ProgramInfo => {
196+
const x = inputVariable('x', inputs[0].dataType, inputs[0].dims.length);
197+
// discard last dimension for using vec2 to access grid data
198+
const gridShape = [inputs[1].dims[0], inputs[1].dims[1], inputs[1].dims[2]];
199+
const grid = inputVariable('grid', inputs[1].dataType, gridShape.length, 2);
200+
let outputShape = [inputs[0].dims[0], inputs[0].dims[1], inputs[1].dims[1], inputs[1].dims[2]];
201+
if (attributes.format === 'NHWC') {
202+
outputShape = [inputs[0].dims[0], inputs[1].dims[1], inputs[1].dims[2], inputs[0].dims[3]];
203+
[idxN, idxC, idxH, idxW] = [0, 3, 1, 2];
204+
}
205+
const output = outputVariable('output', inputs[0].dataType, outputShape.length);
206+
const dataType = x.type.value;
207+
const outputSize = ShapeUtil.size(outputShape);
208+
209+
const programUniforms: ProgramUniform[] = [
210+
{ type: DataType.uint32, data: outputSize },
211+
...createTensorShapeVariables(inputs[0].dims, gridShape, outputShape),
212+
];
213+
214+
const getShaderSource = (shaderHelper: ShaderHelper) => `
215+
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(x, grid, output)}
216+
${gsGetCubicCoeffs}
217+
${gsBicubicInterpolate(dataType)}
218+
${gsDenormalize(attributes)}
219+
${gsReflect(attributes)}
220+
${pixelAtGrid(x, dataType, attributes)}
221+
222+
${shaderHelper.mainStart()}
223+
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
224+
let H_in = i32(uniforms.x_shape[${idxH}]);
225+
let W_in = i32(uniforms.x_shape[${idxW}]);
226+
227+
${
228+
attributes.alignCorners === 0
229+
? `
230+
let x_min = -0.5;
231+
let x_max = f32(W_in) - 0.5;
232+
let y_min = -0.5;
233+
let y_max = f32(H_in) - 0.5;
234+
`
235+
: `
236+
let x_min = 0.0;
237+
let x_max = f32(W_in) - 1.0;
238+
let y_min = 0.0;
239+
let y_max = f32(H_in) - 1.0;
240+
`
241+
};
242+
let border = vec4<f32>(x_min, y_min, x_max, y_max);
243+
244+
let indices = ${output.offsetToIndices('global_idx')};
245+
var grid_indices = vec3<u32>(indices[${idxN}], indices[${idxH}], indices[${idxW}]);
246+
let nxy = ${grid.getByIndices('grid_indices')};
247+
var x = gs_denormalize(f32(nxy[0]), W_in);
248+
var y = gs_denormalize(f32(nxy[1]), H_in);
249+
250+
${computePixel(output, dataType, attributes)}
251+
}`;
252+
253+
return {
254+
name: 'GridSample',
255+
shaderCache: { hint: `${attributes.cacheKey}`, inputDependencies: ['type', 'type'] },
256+
getRunData: (inputs) => {
257+
const outputSize = ShapeUtil.size(outputShape);
258+
return {
259+
outputs: [{ dims: outputShape, dataType: inputs[0].dataType }],
260+
dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */) },
261+
programUniforms,
262+
};
263+
},
264+
getShaderSource,
265+
};
266+
};
267+
268+
export const gridSample = (context: ComputeContext, attributes: GridSampeAttributes): void => {
269+
validateInputs(context.inputs);
270+
context.compute(createGridSampleProgramInfo(context.inputs, attributes));
271+
};
272+
273+
export const parseGridSampleAttributes = (attributes: Record<string, unknown>): GridSampeAttributes =>
274+
createAttributeWithCacheKey({
275+
alignCorners: attributes.align_corners as number,
276+
mode: attributes.mode as Mode,
277+
paddingMode: attributes.padding_mode as PaddingMode,
278+
format: attributes.format as Format,
279+
});

js/web/test/suite-test-list.jsonc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -570,14 +570,14 @@
570570
"test_greater_equal_expanded",
571571
"test_greater_equal",
572572
"test_greater",
573-
// // "test_gridsample_aligncorners_true",
574-
// // "test_gridsample_bicubic",
575-
// // "test_gridsample_bilinear",
576-
// // "test_gridsample_border_padding",
577-
// // "test_gridsample_nearest",
578-
// // "test_gridsample_reflection_padding",
579-
// // "test_gridsample_zeros_padding",
580-
// // "test_gridsample",
573+
"test_gridsample_aligncorners_true",
574+
"test_gridsample_bicubic",
575+
"test_gridsample_bilinear",
576+
"test_gridsample_border_padding",
577+
"test_gridsample_nearest",
578+
"test_gridsample_reflection_padding",
579+
"test_gridsample_zeros_padding",
580+
"test_gridsample",
581581
// // "test_gru_batchwise",
582582
// // "test_gru_defaults",
583583
// // "test_gru_seq_length",

onnxruntime/core/providers/js/js_execution_provider.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2
400400
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear);
401401
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear);
402402

403+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, 19, GridSample);
404+
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 16, 19, GridSample);
405+
403406
std::unique_ptr<KernelRegistry> RegisterKernels() {
404407
auto kernel_registry = std::make_unique<onnxruntime::KernelRegistry>();
405408

@@ -728,6 +731,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
728731
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, uint8_t, DequantizeLinear)>,
729732
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int8_t, DequantizeLinear)>,
730733
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 21, int32_t, DequantizeLinear)>,
734+
735+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, 19, GridSample)>,
736+
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 16, 19, GridSample)>,
731737
};
732738

733739
for (auto& function_table_entry : function_table) {
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#include "grid_sample.h"
5+
6+
namespace onnxruntime {
7+
namespace js {
8+
9+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
10+
GridSample,
11+
kMSInternalNHWCDomain,
12+
16, 19,
13+
kJsExecutionProvider,
14+
KernelDefBuilder()
15+
.TypeConstraint("T1", JsepSupportedDataTypes())
16+
.TypeConstraint("T2", JsepSupportedFloatTypes()),
17+
GridSample<true>);
18+
19+
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
20+
GridSample,
21+
kOnnxDomain,
22+
16, 19,
23+
kJsExecutionProvider,
24+
KernelDefBuilder()
25+
.TypeConstraint("T1", JsepSupportedDataTypes())
26+
.TypeConstraint("T2", JsepSupportedFloatTypes()),
27+
GridSample<false>);
28+
29+
} // namespace js
30+
} // namespace onnxruntime

0 commit comments

Comments
 (0)