Skip to content

Commit 9e69606

Browse files
authored
fix f16 for attention, enable slice and flatten for more types (#19262)
1 parent e96a038 commit 9e69606

File tree

3 files changed

+9
-13
lines changed

3 files changed

+9
-13
lines changed

js/web/lib/wasm/jsep/webgpu/ops/attention.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
297297
298298
if (sum == 0) {
299299
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {
300-
x[offset + i] = ${fillVector('f32', components, 'uniforms.d_inv')};
300+
x[offset + i] = ${fillVector(elemValueType, components, 'uniforms.d_inv')};
301301
}
302302
} else {
303303
for (var i: u32 = 0; i < uniforms.elements_per_wg && i + localOffset < uniforms.d_comp; i++) {

onnxruntime/core/providers/js/operators/flatten.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
1313
kJsExecutionProvider,
1414
(*KernelDefBuilder::Create())
1515
.Alias(0, 0)
16-
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
16+
.TypeConstraint("T", JsepSupportedFloatTypes()),
1717
Flatten);
1818

1919
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@@ -23,7 +23,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
2323
kJsExecutionProvider,
2424
(*KernelDefBuilder::Create())
2525
.Alias(0, 0)
26-
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
26+
.TypeConstraint("T", JsepSupportedFloatTypes()),
2727
Flatten);
2828

2929
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@@ -33,7 +33,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
3333
kJsExecutionProvider,
3434
(*KernelDefBuilder::Create())
3535
.Alias(0, 0)
36-
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
36+
.TypeConstraint("T", JsepSupportedFloatTypes()),
3737
Flatten);
3838

3939
ONNX_OPERATOR_KERNEL_EX(
@@ -43,7 +43,7 @@ ONNX_OPERATOR_KERNEL_EX(
4343
kJsExecutionProvider,
4444
(*KernelDefBuilder::Create())
4545
.Alias(0, 0)
46-
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
46+
.TypeConstraint("T", JsepSupportedFloatTypes()),
4747
Flatten);
4848

4949
} // namespace js

onnxruntime/core/providers/js/operators/slice.cc

+4-8
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
1212
1, 9,
1313
kJsExecutionProvider,
1414
(*KernelDefBuilder::Create())
15-
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
16-
DataTypeImpl::GetTensorType<int32_t>()}),
15+
.TypeConstraint("T", JsepSupportedDataTypes()),
1716
Slice_1);
1817

1918
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@@ -26,8 +25,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
2625
.InputMemoryType(OrtMemTypeCPU, 2)
2726
.InputMemoryType(OrtMemTypeCPU, 3)
2827
.InputMemoryType(OrtMemTypeCPU, 4)
29-
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
30-
DataTypeImpl::GetTensorType<int32_t>()}),
28+
.TypeConstraint("T", JsepSupportedDataTypes()),
3129
Slice);
3230

3331
ONNX_OPERATOR_VERSIONED_KERNEL_EX(
@@ -40,8 +38,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(
4038
.InputMemoryType(OrtMemTypeCPU, 2)
4139
.InputMemoryType(OrtMemTypeCPU, 3)
4240
.InputMemoryType(OrtMemTypeCPU, 4)
43-
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
44-
DataTypeImpl::GetTensorType<int32_t>()}),
41+
.TypeConstraint("T", JsepSupportedDataTypes()),
4542
Slice);
4643

4744
ONNX_OPERATOR_KERNEL_EX(
@@ -54,8 +51,7 @@ ONNX_OPERATOR_KERNEL_EX(
5451
.InputMemoryType(OrtMemTypeCPU, 2)
5552
.InputMemoryType(OrtMemTypeCPU, 3)
5653
.InputMemoryType(OrtMemTypeCPU, 4)
57-
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
58-
DataTypeImpl::GetTensorType<int32_t>()}),
54+
.TypeConstraint("T", JsepSupportedDataTypes()),
5955
Slice);
6056

6157
} // namespace js

0 commit comments

Comments
 (0)