
Commit 3059be4

tianleiwu authored and ankitm3k committed
[CUDA/ROCm] Conditionally support ArgMax and ArgMin for opset 12 and above (microsoft#22713)
### Description

Based on microsoft#9700, extended to cover `ArgMin` as well. This pull request introduces several enhancements and fixes for the `ArgMax` and `ArgMin` operators in the CUDA execution provider. The changes ensure proper handling of these operators across opset versions and improve kernel registration and fallback mechanisms.

Key changes include:

#### Enhancements to `ArgMax` and `ArgMin` operators

* Added new kernel class registrations for `ArgMax` and `ArgMin` for different data types and opset versions in `onnxruntime/core/providers/cuda/cuda_execution_provider.cc`.
* Introduced the `ArgMaxOrArgMinNeedFallbackToCPU` function to fall back to the CPU when the `select_last_index` attribute is set to 1, since CUDA does not support this attribute.

#### Macro and kernel registration improvements

* Replaced the `REGISTER_KERNEL_UNTIL_VERSIONED_TYPED` macro with `REGISTER_KERNEL_VERSIONED_RANGE_TYPED` and `REGISTER_KERNEL_VERSIONED_SINCE_TYPED` macros for better version handling.
* Updated the kernel registrations for `ArgMax` and `ArgMin` to use the new macros, ensuring proper version handling and support for the different data types.

#### Safety checks

* Added safety checks in the `ArgMax` and `ArgMin` kernel classes to ensure `select_last_index` is not set to 1, as that value is not supported on CUDA.

#### Testing enhancements

* Added new tests for the `ArgMax` and `ArgMin` operators to verify behavior when `select_last_index` is set to 0, ensuring compatibility with both the CPU and CUDA execution providers.

### Motivation and Context

Improve CUDA kernel coverage for the stable diffusion model and hence improve its performance on CUDA.
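For context, `select_last_index` (added to `ArgMax`/`ArgMin` in opset 12) controls which index is returned when the extreme value occurs more than once. A minimal pure-Python sketch of the two behaviors (illustrative only, not ONNX Runtime code):

```python
def argmax(values, select_last_index=False):
    """Return the index of the largest value in a list.

    With select_last_index=False (the ONNX default), the first
    occurrence wins; with True, the last occurrence wins. The CUDA EP
    only supports the default, so nodes setting the attribute to 1
    are assigned to the CPU EP instead.
    """
    best = max(values)
    if select_last_index:
        # Scan from the right to find the last occurrence.
        return len(values) - 1 - values[::-1].index(best)
    return values.index(best)

data = [3, 7, 1, 7, 2]
print(argmax(data))                          # 1: first occurrence of 7
print(argmax(data, select_last_index=True))  # 3: last occurrence of 7
```

This mirrors why the attribute matters only when duplicates exist: for distinct values, both settings return the same index.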
1 parent 7881d16 commit 3059be4

File tree

7 files changed: +240 −34 lines changed

Diff for: docs/OperatorKernels.md (+6 −2)

```diff
@@ -554,8 +554,12 @@ Do not modify directly.*
 |||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
 |Affine|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
 |And|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|7+|**T** = tensor(bool)<br/> **T1** = tensor(bool)|
-|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
-|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
+|ArgMax|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||12|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
+|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|||12|**T** = tensor(double), tensor(float), tensor(float16)|
+|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
 |AveragePool|*in* X:**T**<br> *out* Y:**T**|11+|**T** = tensor(double), tensor(float), tensor(float16)|
 |||10|**T** = tensor(double), tensor(float), tensor(float16)|
 |||[7, 9]|**T** = tensor(double), tensor(float), tensor(float16)|
```

Diff for: onnxruntime/core/providers/cuda/cuda_execution_provider.cc (+60 −3)

```diff
@@ -963,6 +963,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout);
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum);
 
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin);
+class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin);
+
 // OpSet 13
 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add);
@@ -1199,6 +1206,13 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kO
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear);
 class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear);
 
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin);
+
 // OpSet 14
 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu);
@@ -1640,6 +1654,9 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMin)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMin)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMin)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax)>,
      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, float, ReduceL1)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, double, ReduceL1)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 17, MLFloat16, ReduceL1)>,
@@ -1822,9 +1839,6 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
                                                             19, IsInf)>,
 
       // opset 11
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, float, ArgMax)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, double, ArgMax)>,
-      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 11, MLFloat16, ArgMax)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, Compress)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Concat)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 11, 12, Flatten)>,
@@ -1916,6 +1930,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, Dropout)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, Einsum)>,
 
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMax)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMax)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMax)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, float, ArgMin)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, double, ArgMin)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 12, 12, MLFloat16, ArgMin)>,
+
       // OpSet 13
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 14, Pow)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 13, int32_t, Add)>,
@@ -2150,6 +2171,13 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, int8_t, DequantizeLinear)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, 18, uint8_t, DequantizeLinear)>,
 
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMax)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMax)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMax)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, ArgMin)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, ArgMin)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, ArgMin)>,
+
       // OpSet 14
       BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, CumSum)>,
       BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu)>,
@@ -2566,6 +2594,32 @@ static bool CastNeedFallbackToCPU(const onnxruntime::Node& node) {
   return false;
 }
 
+static bool ArgMaxOrArgMinNeedFallbackToCPU(const onnxruntime::Node& node) {
+  // Opset 12 introduced the attribute "select_last_index"
+  if (node.SinceVersion() >= 12) {
+    const auto& node_attributes = node.GetAttributes();
+
+    for (auto& attr : node_attributes) {
+      auto& attr_name = attr.first;
+      auto& attr_value = attr.second;
+
+      // CuDNN doesn't support picking the last index in case of encountering
+      // duplicate max values.
+      // CuDNN's API doc doesn't mention what happens in case duplicates are encountered,
+      // but based on testing, the results seem to indicate a "stable" implementation
+      // (i.e.) relative ordering is preserved which is the expected behavior when the
+      // attribute takes on the default value (most common use-case for this operator).
+      if ("select_last_index" == attr_name) {
+        if (attr_value.i() != 0) {
+          return true;
+        }
+      }
+    }
+  }
+
+  return false;
+}
+
 std::unique_ptr<onnxruntime::IDataTransfer> CUDAExecutionProvider::GetDataTransfer() const {
   return std::make_unique<onnxruntime::GPUDataTransfer>();
 }
@@ -2615,6 +2669,9 @@ CUDAExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph,
     } else if ("ConvTranspose" == node.OpType()) {
       not_supported = ConvTransposeNeedFallbackToCPU(node, logger, graph, IsNHWCPreferred());
       force_inside = !not_supported;
+    } else if ("ArgMax" == node.OpType() || "ArgMin" == node.OpType()) {
+      not_supported = ArgMaxOrArgMinNeedFallbackToCPU(node);
+      force_inside = !not_supported;
     } else if ("Cast" == node.OpType()) {
       not_supported = CastNeedFallbackToCPU(node);
       // cast is not compute heavy, and may be placed outside
```
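The placement logic above can be summarized in a short sketch. The following hypothetical pure-Python analogue of `ArgMaxOrArgMinNeedFallbackToCPU` (the node representation here is invented for illustration) shows the decision the CUDA EP makes in `GetCapability`:

```python
def needs_cpu_fallback(since_version, attributes):
    """Mirror of the CUDA EP check: ArgMax/ArgMin nodes from opset 12
    onward may carry a select_last_index attribute; a non-zero value
    is unsupported on CUDA, so such nodes go to the CPU EP instead."""
    if since_version >= 12:
        if attributes.get("select_last_index", 0) != 0:
            return True
    return False

# Opset 11 nodes have no such attribute: always eligible for CUDA.
print(needs_cpu_fallback(11, {}))                        # False
# Opset 13 with the default value stays on CUDA ...
print(needs_cpu_fallback(13, {"select_last_index": 0}))  # False
# ... but select_last_index == 1 forces the CPU fallback.
print(needs_cpu_fallback(13, {"select_last_index": 1}))  # True
```

When this returns false, the node is forced inside the CUDA partition (`force_inside = !not_supported` in the C++ code above).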

Diff for: onnxruntime/core/providers/cuda/reduction/reduction_ops.cc (+16 −12)

```diff
@@ -16,17 +16,17 @@ using namespace onnxruntime::common;
 namespace onnxruntime {
 namespace cuda {
 
-#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end)                                \
+#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end)                         \
   ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(                                                 \
       name,                                                                                \
       kOnnxDomain,                                                                         \
-      1, end,                                                                              \
+      begin, end,                                                                          \
       T,                                                                                   \
       kCudaExecutionProvider,                                                              \
       (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
       name<T>);
 
-#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version)                                 \
+#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version)                            \
   ONNX_OPERATOR_TYPED_KERNEL_EX(                                                           \
       name,                                                                                \
       kOnnxDomain,                                                                         \
@@ -37,8 +37,13 @@ namespace cuda {
       name<T>);
 
 #define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
-  REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last)                      \
-  REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur)
+  REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last)                   \
+  REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur)
+
+#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T)        \
+  REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11)  \
+  REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \
+  REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13)
 
 // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
 template <bool allow_multi_axes>
@@ -829,14 +834,13 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, CUDNN_REDUCE_TENSOR_NO
 
 }  // namespace ReductionOps
 
-// CUDA ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet
-REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, MLFloat16, 11)
-REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, float, 11)
-REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, double, 11)
+REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, MLFloat16)
+REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, float)
+REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, double)
 
-REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, MLFloat16, 11)
-REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, float, 11)
-REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, double, 11)
+REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, MLFloat16)
+REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, float)
+REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, double)
 
 REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18)
 REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18)
```
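The `REGISTER_KERNEL_ARGMIN_OR_ARGMAX` macro above registers three kernels per data type, one per opset range. A small illustrative Python sketch (the table and function names are invented, not ONNX Runtime code) of how a node's opset version maps onto these registration buckets:

```python
# Hypothetical illustration of the three ArgMax/ArgMin registration
# buckets created by REGISTER_KERNEL_ARGMIN_OR_ARGMAX:
#   [1, 11]  -> versioned kernel (no select_last_index attribute)
#   [12, 12] -> versioned kernel (attribute exists; only 0 supported)
#   [13, +)  -> "since" kernel covering all later opsets
BUCKETS = [
    (1, 11, "REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11)"),
    (12, 12, "REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12)"),
    (13, None, "REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13)"),
]

def bucket_for(opset):
    """Return the registration macro serving a given opset version."""
    for lo, hi, macro in BUCKETS:
        if opset >= lo and (hi is None or opset <= hi):
            return macro
    raise ValueError(f"unsupported opset {opset}")

print(bucket_for(9))   # falls in the [1, 11] range
print(bucket_for(12))  # the single-version [12, 12] range
print(bucket_for(18))  # handled by the 13+ "since" kernel
```

Splitting [12, 12] out of the 13+ range mirrors the ONNX version history: opset 12 added `select_last_index`, and opset 13 made `data` accept all tensor types, so each range needs its own registration.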

Diff for: onnxruntime/core/providers/cuda/reduction/reduction_ops.h (+18 −2)

```diff
@@ -88,7 +88,15 @@ class ReduceKernel : public CudaKernel, public ReduceKernelBase<allow_multi_axes
 template <typename T>
 class ArgMax final : public ReduceKernel<false> {
  public:
-  ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
+  ArgMax(const OpKernelInfo& info) : ReduceKernel<false>(info) {
+    // The following is just a safety check.
+    // The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMax
+    // nodes with select_last_index == 1 to the CUDA EP.
+    int64_t select_last_index = 0;
+    if (info.GetAttr<int64_t>("select_last_index", &select_last_index).IsOK()) {
+      ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA");
+    }
+  }
 
   Status ComputeInternal(OpKernelContext* ctx) const override {
     return ComputeImpl<T, CUDNN_REDUCE_TENSOR_FLATTENED_INDICES>(ctx, CUDNN_REDUCE_TENSOR_MAX);
@@ -98,7 +106,15 @@ class ArgMax final : public ReduceKernel<false> {
 template <typename T>
 class ArgMin final : public ReduceKernel<false> {
  public:
-  ArgMin(const OpKernelInfo& info) : ReduceKernel<false>(info) {}
+  ArgMin(const OpKernelInfo& info) : ReduceKernel<false>(info) {
+    // The following is just a safety check.
+    // The logic in ArgMaxOrArgMinNeedFallbackToCPU() makes sure to not assign ArgMin
+    // nodes with select_last_index == 1 to the CUDA EP.
+    int64_t select_last_index = 0;
+    if (info.GetAttr<int64_t>("select_last_index", &select_last_index).IsOK()) {
+      ORT_ENFORCE(select_last_index == 0, "select_last_index as 1 is not supported on CUDA");
+    }
+  }
 
   Status ComputeInternal(OpKernelContext* ctx) const override {
     return ComputeImpl<T, CUDNN_REDUCE_TENSOR_FLATTENED_INDICES>(ctx, CUDNN_REDUCE_TENSOR_MIN);
```
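The constructor change above is a defense-in-depth pattern: the EP's placement check should already keep unsupported nodes away, but the kernel still enforces the invariant at construction time. A hypothetical Python sketch of the same pattern (class and exception names invented for illustration):

```python
class UnsupportedAttributeError(Exception):
    """Raised when a kernel is constructed with an unsupported attribute."""

class CudaArgMaxSketch:
    """Illustrative stand-in for the CUDA ArgMax kernel: the EP's
    placement check should never hand it select_last_index == 1, but
    the constructor still verifies the invariant defensively."""

    def __init__(self, attributes):
        select_last_index = attributes.get("select_last_index", 0)
        if select_last_index != 0:
            raise UnsupportedAttributeError(
                "select_last_index as 1 is not supported on CUDA")

CudaArgMaxSketch({})                        # attribute absent: fine
CudaArgMaxSketch({"select_last_index": 0})  # explicit default: fine
try:
    CudaArgMaxSketch({"select_last_index": 1})
except UnsupportedAttributeError as e:
    print(e)  # the defensive check fires
```

Failing loudly here is preferable to silently computing a wrong index if a future refactor of the placement logic ever let such a node through.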

Diff for: onnxruntime/core/providers/rocm/reduction/reduction_ops.cc (+16 −12)

```diff
@@ -16,17 +16,17 @@ using namespace onnxruntime::common;
 namespace onnxruntime {
 namespace rocm {
 
-#define REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, end)                                \
+#define REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, begin, end)                         \
   ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(                                                 \
       name,                                                                                \
       kOnnxDomain,                                                                         \
-      1, end,                                                                              \
+      begin, end,                                                                          \
       T,                                                                                   \
       kRocmExecutionProvider,                                                              \
       (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
       name<T>);
 
-#define REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, version)                                 \
+#define REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, version)                            \
   ONNX_OPERATOR_TYPED_KERNEL_EX(                                                           \
       name,                                                                                \
       kOnnxDomain,                                                                         \
@@ -37,8 +37,13 @@ namespace rocm {
       name<T>);
 
 #define REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(name, T, last, cur) \
-  REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(name, T, last)                      \
-  REGISTER_KERNEL_TYPED_AXES_INPUT(name, T, cur)
+  REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, last)                   \
+  REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, cur)
+
+#define REGISTER_KERNEL_ARGMIN_OR_ARGMAX(name, T)        \
+  REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 1, 11)  \
+  REGISTER_KERNEL_VERSIONED_RANGE_TYPED(name, T, 12, 12) \
+  REGISTER_KERNEL_VERSIONED_SINCE_TYPED(name, T, 13)
 
 // TODO ReduceKernel::ReduceKernelShared() is still used by some other training classes though it's not used here - this should be refactored.
 template <bool allow_multi_axes>
@@ -830,14 +835,13 @@ template std::unique_ptr<Tensor> ReduceCompute<MLFloat16, MIOPEN_REDUCE_TENSOR_N
 
 }  // namespace ReductionOps
 
-// ROCM ArgMax/ArgMin doesn't have OpSet12+ implementation (with select_last_index attr) yet
-REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, MLFloat16, 11)
-REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, float, 11)
-// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMax, double, 11)
+REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, MLFloat16)
+REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, float)
+// REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMax, double)
 
-REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, MLFloat16, 11)
-REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, float, 11)
-// REGISTER_KERNEL_UNTIL_VERSIONED_TYPED(ArgMin, double, 11)
+REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, MLFloat16)
+REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, float)
+// REGISTER_KERNEL_ARGMIN_OR_ARGMAX(ArgMin, double)
 
 REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, MLFloat16, 17, 18)
 REGISTER_KERNEL_TYPED_AXES_INPUT_WITH_VERSIONED(ReduceMax, float, 17, 18)
```

0 commit comments