Skip to content

Commit 4645726

Browse files
authored
fix for webgl lrn (microsoft#15236)
fix issue that resulted in wrong results for lrn on webgpu
1 parent 9f942e1 commit 4645726

File tree

1 file changed

+4
-15
lines changed
  • js/web/lib/onnxjs/backends/webgl/ops

1 file changed

+4
-15
lines changed

js/web/lib/onnxjs/backends/webgl/ops/lrn.ts

+4-15
Original file line numberDiff line numberDiff line change
@@ -42,24 +42,14 @@ const lrnProgramMetadata = {
4242
inputTypes: [TextureType.unpacked]
4343
};
4444

45-
function getOutputExpression(attributes: LrnAttributes): string {
46-
let expression = `float(${attributes.bias}) + float(${attributes.alpha}) * square_sum`;
47-
if (attributes.beta === 0.5) {
48-
expression = `inversesqrt(${expression})`;
49-
} else if (attributes.beta === 1.0) {
50-
expression = `1.0/(${expression})`;
51-
} else {
52-
expression = `exp(log(${expression})) * float(-${attributes.beta})`;
53-
}
54-
return `x * ${expression}`;
55-
}
56-
5745
function createLrnProgramInfo(inputs: Tensor[], attributes: LrnAttributes): ProgramInfo {
5846
const C = inputs[0].dims[1];
59-
6047
const rank = inputs[0].dims.length;
6148
const from = -Math.floor((attributes.size - 1) / 2);
6249
const to = Math.ceil((attributes.size - 1) / 2);
50+
const alpha = `float(${attributes.alpha}) / float(${attributes.size})`;
51+
const bias = `float(${attributes.bias})`;
52+
const beta = `float(${attributes.beta})`;
6353

6454
const shaderSource = `
6555
float process(int indices[${rank}]) {
@@ -75,8 +65,7 @@ function createLrnProgramInfo(inputs: Tensor[], attributes: LrnAttributes): Prog
7565
square_sum += j * j;
7666
}
7767
}
78-
79-
return ${getOutputExpression(attributes)};
68+
return x / pow(${bias} + ${alpha} * square_sum, ${beta});
8069
}`;
8170
return {
8271
...lrnProgramMetadata,

0 commit comments

Comments
 (0)