Skip to content

Commit 42399df

Browse files
authored
Fix a potential race in the CUDA TopK kernel (#19917)
### Description If the `K` value is flowing through as a tensor, we are updating a mutable member of the `TopK` class and basing the compute off that - which is likely to cause data race issues with concurrent Run() calls and `K` value changes. ### Motivation and Context Fix potential race in CUDA TopK kernel
1 parent bcf47d3 commit 42399df

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

Diff for: onnxruntime/core/providers/cuda/math/topk.cc

+17-7
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
5656
info.GetAttrOrDefault<int64_t>("largest", &largest_, 1);
5757
info.GetAttrOrDefault<int64_t>("sorted", &sorted_, 1);
5858
if (!inputk) {
59-
info.GetAttrOrDefault<int64_t>("k", &K_, 0);
59+
info.GetAttrOrDefault<int64_t>("k", &attr_k_, 0);
6060
}
6161
}
6262

@@ -67,7 +67,7 @@ TopK<inputk>::TopK(const OpKernelInfo& info) : CudaKernel(info) {
6767
static_cast<int64_t*>(tensor_I->MutableDataRaw()), \
6868
elem_nums_cuda, \
6969
elem_nums.size(), \
70-
axis, K_, largest_, sorted_, N, dimension)
70+
axis, k_value, largest_, sorted_, N, dimension)
7171

7272
template <bool inputk>
7373
Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
@@ -77,19 +77,29 @@ Status TopK<inputk>::ComputeInternal(OpKernelContext* ctx) const {
7777
int32_t axis = static_cast<int32_t>(axis_ < 0 ? rank + axis_ : axis_);
7878
ORT_ENFORCE(axis > -1 && axis < rank);
7979

80+
int64_t k_value = 0;
8081
if (inputk) {
8182
auto tensor_K = ctx->Input<Tensor>(1);
8283
ORT_ENFORCE(nullptr != tensor_K);
83-
K_ = *tensor_K->Data<int64_t>();
84-
ORT_ENFORCE(K_ >= 0 && K_ <= tensor_X->Shape().GetDims()[axis]);
84+
k_value = *tensor_K->Data<int64_t>();
85+
} else { // from attribute
86+
k_value = attr_k_;
8587
}
8688

87-
auto output_shape = tensor_X->Shape();
88-
output_shape[axis] = K_;
89+
// Now that we know the value of 'K' and the input shape,
90+
// make a final validation before going to the implementation
91+
const auto& input_shape = tensor_X->Shape();
92+
if ((k_value < 0) || (k_value > input_shape.GetDims()[axis])) {
93+
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Value of K outside range. K value: ", k_value,
94+
". Input shape: ", input_shape, " . Axis: ", axis);
95+
}
96+
97+
auto output_shape = input_shape;
98+
output_shape[axis] = k_value;
8999
auto tensor_V = ctx->Output(0, output_shape);
90100
auto tensor_I = ctx->Output(1, output_shape);
91101

92-
if (0 == K_) {
102+
if (output_shape.Size() == 0) { // Bail out early if the output is going to be empty
93103
return Status::OK();
94104
}
95105

Diff for: onnxruntime/core/providers/cuda/math/topk.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class TopK final : public CudaKernel {
1717
int64_t axis_;
1818
int64_t largest_;
1919
int64_t sorted_;
20-
mutable int64_t K_;
20+
int64_t attr_k_;
2121
};
2222
} // namespace cuda
2323
} // namespace onnxruntime

0 commit comments

Comments
 (0)