@@ -155,13 +155,14 @@ struct urEnqueueKernelLaunchIncrementTest
155
155
156
156
using Param = uur::BoolTestParam;
157
157
158
- using urMultiQueueLaunchMemcpyTest<numOps, Param>::context;
159
158
using urMultiQueueLaunchMemcpyTest<numOps, Param>::queues;
160
- using urMultiQueueLaunchMemcpyTest<numOps, Param>::devices;
161
159
using urMultiQueueLaunchMemcpyTest<numOps, Param>::kernels;
162
160
using urMultiQueueLaunchMemcpyTest<numOps, Param>::SharedMem;
163
161
164
162
void SetUp () override {
163
+ // We actually need a single device used multiple times for this test, as
164
+ // opposed to utilizing all available devices for the platform.
165
+ this ->trueMultiDevice = false ;
165
166
UUR_RETURN_ON_FATAL_FAILURE (
166
167
urMultiQueueLaunchMemcpyTest<numOps, Param>::
167
168
SetUp ()); // Use single device, duplicated numOps times
@@ -180,9 +181,6 @@ UUR_PLATFORM_TEST_SUITE_WITH_PARAM(
180
181
181
182
TEST_P (urEnqueueKernelLaunchIncrementTest, Success) {
182
183
UUR_KNOWN_FAILURE_ON (uur::LevelZeroV2{});
183
- if (devices.size () > 1 ) {
184
- UUR_KNOWN_FAILURE_ON (uur::CUDA{});
185
- }
186
184
187
185
constexpr size_t global_offset = 0 ;
188
186
constexpr size_t n_dimensions = 1 ;
@@ -347,9 +345,28 @@ TEST_P(urEnqueueKernelLaunchIncrementMultiDeviceTest, Success) {
347
345
}
348
346
}
349
347
350
- using urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest =
351
- urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<
352
- std::tuple<uur::BoolTestParam, uur::BoolTestParam>>;
348
+ struct urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest // Fixture for the multi-threaded launch test; decodes its two bool parameters once in SetUp() and stores them as members.
349
+ : urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<
350
+ std::tuple<uur::BoolTestParam, uur::BoolTestParam>> {
351
+ using Param = std::tuple<uur::BoolTestParam, uur::BoolTestParam>; // Element 0 -> useEvents, element 1 -> queuePerThread (see SetUp below).
352
+
353
+ using urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<Param>::devices;
354
+ using urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<Param>::queues;
355
+ using urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<Param>::kernels;
356
+ using urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<
357
+ Param>::SharedMem;
358
+
359
+ void SetUp () override { // Assigns the flags and trueMultiDevice first, then delegates to the base fixture's SetUp().
360
+ useEvents = std::get<0 >(getParam ()).value ;
361
+ queuePerThread = std::get<1 >(getParam ()).value ;
362
+ // With !queuePerThread this becomes a test on a single device (the test body then submits every thread's work to queues.back()), so a true multi-device setup is only requested when each thread gets its own queue.
363
+ this ->trueMultiDevice = queuePerThread;
364
+ urEnqueueKernelLaunchIncrementMultiDeviceTestWithParam<Param>::SetUp (); // NOTE(review): presumably the base SetUp() reads trueMultiDevice, which is why it is assigned before this call - confirm against the base fixture.
365
+ }
366
+
367
+ bool useEvents; // First test parameter; has no in-class initializer, assigned in SetUp().
368
+ bool queuePerThread; // Second test parameter; has no in-class initializer, assigned in SetUp() and mirrored into trueMultiDevice.
369
+ };
353
370
354
371
UUR_PLATFORM_TEST_SUITE_WITH_PARAM (
355
372
urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest,
@@ -359,14 +376,7 @@ UUR_PLATFORM_TEST_SUITE_WITH_PARAM(
359
376
printParams<urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest>);
360
377
361
378
// Enqueue kernelLaunch concurrently from multiple threads
362
- // With !queuePerThread this becomes a test on a single device
363
379
TEST_P (urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest, Success) {
364
- if (devices.size () > 1 ) {
365
- UUR_KNOWN_FAILURE_ON (uur::CUDA{});
366
- }
367
- auto useEvents = std::get<0 >(getParam ()).value ;
368
- auto queuePerThread = std::get<1 >(getParam ()).value ;
369
-
370
380
if (!queuePerThread) {
371
381
UUR_KNOWN_FAILURE_ON (uur::LevelZero{}, uur::LevelZeroV2{});
372
382
}
@@ -377,11 +387,11 @@ TEST_P(urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest, Success) {
377
387
static constexpr size_t numOpsPerThread = 6 ;
378
388
379
389
for (size_t i = 0 ; i < numThreads; i++) {
380
- threads.emplace_back ([this , i, queuePerThread, useEvents ]() {
390
+ threads.emplace_back ([this , i]() {
381
391
constexpr size_t global_offset = 0 ;
382
392
constexpr size_t n_dimensions = 1 ;
383
393
384
- auto queue = queuePerThread ? queues[i] : queues.back ();
394
+ auto queue = this -> queuePerThread ? queues[i] : queues.back ();
385
395
auto kernel = kernels[i];
386
396
auto sharedPtr = SharedMem[i];
387
397
@@ -391,7 +401,7 @@ TEST_P(urEnqueueKernelLaunchIncrementMultiDeviceMultiThreadTest, Success) {
391
401
ur_event_handle_t *lastEvent = nullptr ;
392
402
ur_event_handle_t *signalEvent = nullptr ;
393
403
394
- if (useEvents) {
404
+ if (this -> useEvents ) {
395
405
waitNum = j > 0 ? 1 : 0 ;
396
406
lastEvent = j > 0 ? Events[j - 1 ].ptr () : nullptr ;
397
407
signalEvent = Events[j].ptr ();
0 commit comments