@@ -300,3 +300,60 @@ TEST_F(urDevicePartitionTest, SuccessSubSet) {
300
300
}
301
301
}
302
302
}
303
+
304
+ using urDevicePartitionByCountsTestWithParam =
305
+ urDevicePartitionTestWithParam<std::vector<size_t >>;
306
+ TEST_P (urDevicePartitionByCountsTestWithParam, CountsOrdering) {
307
+ ur_device_handle_t device = devices[0 ];
308
+
309
+ if (!uur::hasDevicePartitionSupport (device,
310
+ UR_DEVICE_PARTITION_BY_COUNTS)) {
311
+ GTEST_SKIP () << " Device \' " << device
312
+ << " \' does not support partitioning by counts\n " ;
313
+ }
314
+
315
+ auto requested_counts = GetParam ();
316
+
317
+ std::vector<ur_device_partition_property_t > property_list;
318
+ for (size_t i = 0 ; i < requested_counts.size (); ++i) {
319
+ property_list.push_back (
320
+ uur::makePartitionByCountsDesc (requested_counts[i]));
321
+ }
322
+
323
+ ur_device_partition_properties_t properties{
324
+ UR_STRUCTURE_TYPE_DEVICE_PARTITION_PROPERTIES, nullptr ,
325
+ property_list.data (), property_list.size ()};
326
+
327
+ uint32_t num_sub_devices = 0 ;
328
+ urDevicePartition (device, &properties, 0 , nullptr , &num_sub_devices);
329
+
330
+ std::vector<ur_device_handle_t > sub_devices (num_sub_devices);
331
+ urDevicePartition (device, &properties, num_sub_devices, sub_devices.data (),
332
+ nullptr );
333
+
334
+ std::vector<size_t > actual_counts;
335
+ for (const auto &sub_device : sub_devices) {
336
+ uint32_t n_compute_units = 0 ;
337
+ getNumberComputeUnits (sub_device, n_compute_units);
338
+ actual_counts.push_back (n_compute_units);
339
+ urDeviceRelease (sub_device);
340
+ }
341
+
342
+ ASSERT_EQ (requested_counts, actual_counts);
343
+ }
344
+
345
+ INSTANTIATE_TEST_SUITE_P (
346
+ , urDevicePartitionByCountsTestWithParam,
347
+ ::testing::Values (std::vector<size_t >{2 , 4 }, std::vector<size_t >{1 , 4 },
348
+ std::vector<size_t >{2 , 3 }, std::vector<size_t >{3 , 2 },
349
+ std::vector<size_t >{3 , 1 }),
350
+ [](const ::testing::TestParamInfo<std::vector<size_t >> &info) {
351
+ std::stringstream ss;
352
+ for (size_t i = 0 ; i < info.param .size (); ++i) {
353
+ if (i > 0 ) {
354
+ ss << " _" ;
355
+ }
356
+ ss << info.param [i];
357
+ }
358
+ return ss.str ();
359
+ });
0 commit comments