@@ -2577,61 +2577,62 @@ pi_result cuda_piEnqueueKernelLaunch(
2577
2577
size_t maxThreadsPerBlock[3 ] = {};
2578
2578
bool providedLocalWorkGroupSize = (local_work_size != nullptr );
2579
2579
pi_uint32 local_size = kernel->get_local_size ();
2580
+ pi_result retError = PI_SUCCESS;
2580
2581
2581
- {
2582
- size_t *reqdThreadsPerBlock = kernel->reqdThreadsPerBlock_ ;
2583
- maxWorkGroupSize = command_queue->device_ ->get_max_work_group_size ();
2584
- command_queue->device_ ->get_max_work_item_sizes (sizeof (maxThreadsPerBlock),
2585
- maxThreadsPerBlock);
2586
-
2587
- if (providedLocalWorkGroupSize) {
2588
- auto isValid = [&](int dim) {
2589
- if (reqdThreadsPerBlock[dim] != 0 &&
2590
- local_work_size[dim] != reqdThreadsPerBlock[dim])
2591
- return PI_INVALID_WORK_GROUP_SIZE;
2592
-
2593
- if (local_work_size[dim] > maxThreadsPerBlock[dim])
2594
- return PI_INVALID_WORK_ITEM_SIZE;
2595
- // Checks that local work sizes are a divisor of the global work sizes
2596
- // which includes that the local work sizes are neither larger than the
2597
- // global work sizes and not 0.
2598
- if (0u == local_work_size[dim])
2599
- return PI_INVALID_WORK_GROUP_SIZE;
2600
- if (0u != (global_work_size[dim] % local_work_size[dim]))
2601
- return PI_INVALID_WORK_GROUP_SIZE;
2602
- threadsPerBlock[dim] = static_cast <int >(local_work_size[dim]);
2603
- return PI_SUCCESS;
2604
- };
2605
-
2606
- for (size_t dim = 0 ; dim < work_dim; dim++) {
2607
- auto err = isValid (dim);
2608
- if (err != PI_SUCCESS)
2609
- return err;
2582
+ try {
2583
+ // Set the active context here as guessLocalWorkSize needs an active context
2584
+ ScopedContext active (command_queue->get_context ());
2585
+ {
2586
+ size_t *reqdThreadsPerBlock = kernel->reqdThreadsPerBlock_ ;
2587
+ maxWorkGroupSize = command_queue->device_ ->get_max_work_group_size ();
2588
+ command_queue->device_ ->get_max_work_item_sizes (
2589
+ sizeof (maxThreadsPerBlock), maxThreadsPerBlock);
2590
+
2591
+ if (providedLocalWorkGroupSize) {
2592
+ auto isValid = [&](int dim) {
2593
+ if (reqdThreadsPerBlock[dim] != 0 &&
2594
+ local_work_size[dim] != reqdThreadsPerBlock[dim])
2595
+ return PI_INVALID_WORK_GROUP_SIZE;
2596
+
2597
+ if (local_work_size[dim] > maxThreadsPerBlock[dim])
2598
+ return PI_INVALID_WORK_ITEM_SIZE;
2599
+ // Checks that local work sizes are a divisor of the global work sizes
2600
+ // which includes that the local work sizes are neither larger than
2601
+ // the global work sizes and not 0.
2602
+ if (0u == local_work_size[dim])
2603
+ return PI_INVALID_WORK_GROUP_SIZE;
2604
+ if (0u != (global_work_size[dim] % local_work_size[dim]))
2605
+ return PI_INVALID_WORK_GROUP_SIZE;
2606
+ threadsPerBlock[dim] = static_cast <int >(local_work_size[dim]);
2607
+ return PI_SUCCESS;
2608
+ };
2609
+
2610
+ for (size_t dim = 0 ; dim < work_dim; dim++) {
2611
+ auto err = isValid (dim);
2612
+ if (err != PI_SUCCESS)
2613
+ return err;
2614
+ }
2615
+ } else {
2616
+ guessLocalWorkSize (threadsPerBlock, global_work_size,
2617
+ maxThreadsPerBlock, kernel, local_size);
2610
2618
}
2611
- } else {
2612
- guessLocalWorkSize (threadsPerBlock, global_work_size, maxThreadsPerBlock,
2613
- kernel, local_size);
2614
2619
}
2615
- }
2616
2620
2617
- if (maxWorkGroupSize <
2618
- size_t (threadsPerBlock[0 ] * threadsPerBlock[1 ] * threadsPerBlock[2 ])) {
2619
- return PI_INVALID_WORK_GROUP_SIZE;
2620
- }
2621
+ if (maxWorkGroupSize <
2622
+ size_t (threadsPerBlock[0 ] * threadsPerBlock[1 ] * threadsPerBlock[2 ])) {
2623
+ return PI_INVALID_WORK_GROUP_SIZE;
2624
+ }
2621
2625
2622
- int blocksPerGrid[3 ] = {1 , 1 , 1 };
2626
+ int blocksPerGrid[3 ] = {1 , 1 , 1 };
2623
2627
2624
- for (size_t i = 0 ; i < work_dim; i++) {
2625
- blocksPerGrid[i] =
2626
- static_cast <int >(global_work_size[i] + threadsPerBlock[i] - 1 ) /
2627
- threadsPerBlock[i];
2628
- }
2628
+ for (size_t i = 0 ; i < work_dim; i++) {
2629
+ blocksPerGrid[i] =
2630
+ static_cast <int >(global_work_size[i] + threadsPerBlock[i] - 1 ) /
2631
+ threadsPerBlock[i];
2632
+ }
2629
2633
2630
- pi_result retError = PI_SUCCESS;
2631
- std::unique_ptr<_pi_event> retImplEv{nullptr };
2634
+ std::unique_ptr<_pi_event> retImplEv{nullptr };
2632
2635
2633
- try {
2634
- ScopedContext active (command_queue->get_context ());
2635
2636
CUstream cuStream = command_queue->get ();
2636
2637
CUfunction cuFunc = kernel->get ();
2637
2638
0 commit comments