Skip to content

Commit 7e0d424

Browse files
authored
accumulate in fp32 for Reduce* (#19868)
1 parent 28ad6c3 commit 7e0d424

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

js/web/lib/wasm/jsep/webgpu/ops/reduce-shared.ts

+5 -5
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ export const createReduceSharedProgramInfo =
131131
const workgroupSize = 32;
132132

133133
const sharedMemorySnippet = `
134-
var<workgroup> aBestValues : array<${output.type.storage}, ${workgroupSize}>;
134+
var<workgroup> aBestValues : array<f32, ${workgroupSize}>;
135135
`;
136136

137137
const getShaderSource = (shaderHelper: ShaderHelper) => `
@@ -145,10 +145,10 @@ export const createReduceSharedProgramInfo =
145145
let outputIndex = global_idx / ${workgroupSize};
146146
let offset = outputIndex * uniforms.reduceSize;
147147
148-
var bestValue = ${output.type.storage}(${reduceInitValues[reduceType]});
148+
var bestValue = f32(${reduceInitValues[reduceType]});
149149
let Length = uniforms.reduceSize;
150150
for (var k = local_idx; k < Length; k = k + ${workgroupSize}) {
151-
let candidate = ${output.type.storage}(${input.getByOffset('offset + k')});
151+
let candidate = f32(${input.getByOffset('offset + k')});
152152
bestValue = ${reduceOps[reduceType]};
153153
}
154154
aBestValues[local_idx] = bestValue;
@@ -172,8 +172,8 @@ export const createReduceSharedProgramInfo =
172172
output.setByOffset(
173173
'outputIndex',
174174
`${
175-
reduceType === 'mean' ? `bestValue / ${output.type.storage}(uniforms.reduceSize)` :
176-
`${reduceOutputValues[reduceType]}`}`)};
175+
reduceType === 'mean' ? `${output.type.storage}(bestValue / f32(uniforms.reduceSize))` :
176+
`${output.type.storage}(${reduceOutputValues[reduceType]})`}`)};
177177
}
178178
}`;
179179

0 commit comments

Comments
 (0)