Skip to content

Commit 6883fdc

Browse files
dtang317nums11
authored andcommitted
Fix GRU tests (microsoft#22716)
### Description Many GRU tests were being skipped due to an error in MLOperatorAuthorImpl.cpp. The issue was caused by activation function names not being capitalized (e.g., ‘sigmoid’), while The AttrValue was using mixed cases (e.g., ‘Sigmoid’, ‘LeakyRelu’), which resulted in an ‘unsupported activation function’ error in DMLOperatorRecurrentNeuralNetwork.cpp. This PR fixes the issue by making the DML EP activation function name case-insensitive, and capitalizing the activation function names in the tests. ref PR: microsoft#15914 ref bug: https://dev.azure.com/microsoft/OS/_workitems/edit/44571772 ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> --------- Co-authored-by: nums11 <[email protected]>
1 parent 63533ea commit 6883fdc

File tree

2 files changed

+45
-179
lines changed

2 files changed

+45
-179
lines changed

onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRecurrentNeuralNetwork.cpp

+25-19
Original file line numberDiff line numberDiff line change
@@ -127,51 +127,51 @@ class DmlOperatorRecurrentBase: public DmlOperator, public RecurrentHelper
127127
DML_OPERATOR_DESC& desc = descs[i];
128128
ActivationOperatorDescUnion& activationDesc = m_activationDescs[i];
129129
desc.Desc = &activationDesc;
130-
131-
if (activationName == AttrValue::ActivationRelu)
130+
131+
if (CompareActivationName(activationName, AttrValue::ActivationRelu))
132132
{
133133
desc.Type = DML_OPERATOR_ACTIVATION_RELU;
134-
}
135-
else if (activationName == AttrValue::ActivationLeakyRelu)
134+
}
135+
else if (CompareActivationName(activationName, AttrValue::ActivationLeakyRelu))
136136
{
137137
desc.Type = DML_OPERATOR_ACTIVATION_LEAKY_RELU;
138138
activationDesc.leakyRelu.Alpha = NextAlpha(desc.Type);
139139
}
140-
else if (activationName == AttrValue::ActivationThresholdedRelu)
140+
else if (CompareActivationName(activationName, AttrValue::ActivationThresholdedRelu))
141141
{
142142
desc.Type = DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU;
143143
activationDesc.thresholdedRelu.Alpha = NextAlpha(desc.Type);
144-
}
145-
else if (activationName == AttrValue::ActivationTanh)
144+
}
145+
else if (CompareActivationName(activationName, AttrValue::ActivationTanh))
146146
{
147147
desc.Type = DML_OPERATOR_ACTIVATION_TANH;
148-
}
149-
else if (activationName == AttrValue::ActivationScaledTanh)
148+
}
149+
else if (CompareActivationName(activationName, AttrValue::ActivationScaledTanh))
150150
{
151151
desc.Type = DML_OPERATOR_ACTIVATION_SCALED_TANH;
152152
activationDesc.scaledTanh.Alpha = NextAlpha(desc.Type);
153153
activationDesc.scaledTanh.Beta = NextBeta(desc.Type);
154-
}
155-
else if (activationName == AttrValue::ActivationSigmoid)
154+
}
155+
else if (CompareActivationName(activationName, AttrValue::ActivationSigmoid))
156156
{
157157
desc.Type = DML_OPERATOR_ACTIVATION_SIGMOID;
158-
}
159-
else if (activationName == AttrValue::ActivationSigmoidHard)
158+
}
159+
else if (CompareActivationName(activationName, AttrValue::ActivationSigmoidHard))
160160
{
161161
desc.Type = DML_OPERATOR_ACTIVATION_HARD_SIGMOID;
162162
activationDesc.hardSigmoid.Alpha = NextAlpha(desc.Type);
163163
activationDesc.hardSigmoid.Beta = NextBeta(desc.Type);
164-
}
165-
else if (activationName == AttrValue::ActivationElu)
164+
}
165+
else if (CompareActivationName(activationName, AttrValue::ActivationElu))
166166
{
167167
desc.Type = DML_OPERATOR_ACTIVATION_ELU;
168168
activationDesc.elu.Alpha = NextAlpha(desc.Type);
169-
}
170-
else if (activationName == AttrValue::ActivationSoftsign)
169+
}
170+
else if (CompareActivationName(activationName, AttrValue::ActivationSoftsign))
171171
{
172172
desc.Type = DML_OPERATOR_ACTIVATION_SOFTSIGN;
173-
}
174-
else if (activationName == AttrValue::ActivationSoftplus)
173+
}
174+
else if (CompareActivationName(activationName, AttrValue::ActivationSoftplus))
175175
{
176176
desc.Type = DML_OPERATOR_ACTIVATION_SOFTPLUS;
177177
}
@@ -182,6 +182,12 @@ class DmlOperatorRecurrentBase: public DmlOperator, public RecurrentHelper
182182
}
183183
}
184184

185+
bool CompareActivationName(std::string_view activationName, std::string_view attrValue)
186+
{
187+
auto comparer = [](char a, char b) {return std::tolower(a) == std::tolower(b);};
188+
return std::equal(activationName.begin(), activationName.end(), attrValue.begin(), attrValue.end(), comparer);
189+
}
190+
185191
void Compute(const MLOperatorKernelContext& kernelContext) override
186192
{
187193
// Assume that enough GPU work has been queued up after the RNN operator that it is worth

0 commit comments

Comments
 (0)