Skip to content

Commit 0d3cc99

Browse files
authored
[SYCL][CUDA] Fix context scope in kernel launch (#4606)
The `guessLocalWorkSize` function uses the CUDA API so it needs an active context, and there was no active `ScopedContext` when it was called which may cause issue.
1 parent ce7725d commit 0d3cc99

File tree

1 file changed

+48
-47
lines changed

1 file changed

+48
-47
lines changed

sycl/plugins/cuda/pi_cuda.cpp

+48-47
Original file line numberDiff line numberDiff line change
@@ -2577,61 +2577,62 @@ pi_result cuda_piEnqueueKernelLaunch(
25772577
size_t maxThreadsPerBlock[3] = {};
25782578
bool providedLocalWorkGroupSize = (local_work_size != nullptr);
25792579
pi_uint32 local_size = kernel->get_local_size();
2580+
pi_result retError = PI_SUCCESS;
25802581

2581-
{
2582-
size_t *reqdThreadsPerBlock = kernel->reqdThreadsPerBlock_;
2583-
maxWorkGroupSize = command_queue->device_->get_max_work_group_size();
2584-
command_queue->device_->get_max_work_item_sizes(sizeof(maxThreadsPerBlock),
2585-
maxThreadsPerBlock);
2586-
2587-
if (providedLocalWorkGroupSize) {
2588-
auto isValid = [&](int dim) {
2589-
if (reqdThreadsPerBlock[dim] != 0 &&
2590-
local_work_size[dim] != reqdThreadsPerBlock[dim])
2591-
return PI_INVALID_WORK_GROUP_SIZE;
2592-
2593-
if (local_work_size[dim] > maxThreadsPerBlock[dim])
2594-
return PI_INVALID_WORK_ITEM_SIZE;
2595-
// Checks that local work sizes are a divisor of the global work sizes
2596-
// which includes that the local work sizes are neither larger than the
2597-
// global work sizes and not 0.
2598-
if (0u == local_work_size[dim])
2599-
return PI_INVALID_WORK_GROUP_SIZE;
2600-
if (0u != (global_work_size[dim] % local_work_size[dim]))
2601-
return PI_INVALID_WORK_GROUP_SIZE;
2602-
threadsPerBlock[dim] = static_cast<int>(local_work_size[dim]);
2603-
return PI_SUCCESS;
2604-
};
2605-
2606-
for (size_t dim = 0; dim < work_dim; dim++) {
2607-
auto err = isValid(dim);
2608-
if (err != PI_SUCCESS)
2609-
return err;
2582+
try {
2583+
// Set the active context here as guessLocalWorkSize needs an active context
2584+
ScopedContext active(command_queue->get_context());
2585+
{
2586+
size_t *reqdThreadsPerBlock = kernel->reqdThreadsPerBlock_;
2587+
maxWorkGroupSize = command_queue->device_->get_max_work_group_size();
2588+
command_queue->device_->get_max_work_item_sizes(
2589+
sizeof(maxThreadsPerBlock), maxThreadsPerBlock);
2590+
2591+
if (providedLocalWorkGroupSize) {
2592+
auto isValid = [&](int dim) {
2593+
if (reqdThreadsPerBlock[dim] != 0 &&
2594+
local_work_size[dim] != reqdThreadsPerBlock[dim])
2595+
return PI_INVALID_WORK_GROUP_SIZE;
2596+
2597+
if (local_work_size[dim] > maxThreadsPerBlock[dim])
2598+
return PI_INVALID_WORK_ITEM_SIZE;
2599+
// Checks that local work sizes are a divisor of the global work sizes
2600+
// which includes that the local work sizes are neither larger than
2601+
// the global work sizes and not 0.
2602+
if (0u == local_work_size[dim])
2603+
return PI_INVALID_WORK_GROUP_SIZE;
2604+
if (0u != (global_work_size[dim] % local_work_size[dim]))
2605+
return PI_INVALID_WORK_GROUP_SIZE;
2606+
threadsPerBlock[dim] = static_cast<int>(local_work_size[dim]);
2607+
return PI_SUCCESS;
2608+
};
2609+
2610+
for (size_t dim = 0; dim < work_dim; dim++) {
2611+
auto err = isValid(dim);
2612+
if (err != PI_SUCCESS)
2613+
return err;
2614+
}
2615+
} else {
2616+
guessLocalWorkSize(threadsPerBlock, global_work_size,
2617+
maxThreadsPerBlock, kernel, local_size);
26102618
}
2611-
} else {
2612-
guessLocalWorkSize(threadsPerBlock, global_work_size, maxThreadsPerBlock,
2613-
kernel, local_size);
26142619
}
2615-
}
26162620

2617-
if (maxWorkGroupSize <
2618-
size_t(threadsPerBlock[0] * threadsPerBlock[1] * threadsPerBlock[2])) {
2619-
return PI_INVALID_WORK_GROUP_SIZE;
2620-
}
2621+
if (maxWorkGroupSize <
2622+
size_t(threadsPerBlock[0] * threadsPerBlock[1] * threadsPerBlock[2])) {
2623+
return PI_INVALID_WORK_GROUP_SIZE;
2624+
}
26212625

2622-
int blocksPerGrid[3] = {1, 1, 1};
2626+
int blocksPerGrid[3] = {1, 1, 1};
26232627

2624-
for (size_t i = 0; i < work_dim; i++) {
2625-
blocksPerGrid[i] =
2626-
static_cast<int>(global_work_size[i] + threadsPerBlock[i] - 1) /
2627-
threadsPerBlock[i];
2628-
}
2628+
for (size_t i = 0; i < work_dim; i++) {
2629+
blocksPerGrid[i] =
2630+
static_cast<int>(global_work_size[i] + threadsPerBlock[i] - 1) /
2631+
threadsPerBlock[i];
2632+
}
26292633

2630-
pi_result retError = PI_SUCCESS;
2631-
std::unique_ptr<_pi_event> retImplEv{nullptr};
2634+
std::unique_ptr<_pi_event> retImplEv{nullptr};
26322635

2633-
try {
2634-
ScopedContext active(command_queue->get_context());
26352636
CUstream cuStream = command_queue->get();
26362637
CUfunction cuFunc = kernel->get();
26372638

0 commit comments

Comments
 (0)