@@ -351,28 +351,18 @@ template <typename T> inline bool isPowerOf2(const T &Value) {
351
351
static inline void roundToHighestFactorOfGlobalSizeIn3d (
352
352
size_t *ThreadsPerBlock, const size_t *GlobalSize,
353
353
const size_t *MaxBlockDim, const size_t MaxBlockSize,
354
- const size_t WorkDim ) {
354
+ const size_t ) {
355
355
ThreadsPerBlock[0 ] = std::min (GlobalSize[0 ], MaxBlockDim[0 ]);
356
- // Make the X dim a factor of 2
357
- do {
358
- roundToHighestFactorOfGlobalSize (ThreadsPerBlock[0 ], GlobalSize[0 ]);
359
- } while (WorkDim == 3 && !isPowerOf2 (ThreadsPerBlock[0 ]) &&
360
- ThreadsPerBlock[0 ] > 32 && --ThreadsPerBlock[0 ]);
356
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[0 ], GlobalSize[0 ]);
361
357
362
358
ThreadsPerBlock[1 ] =
363
359
std::min (GlobalSize[1 ],
364
360
std::min (MaxBlockSize / ThreadsPerBlock[0 ], MaxBlockDim[1 ]));
365
- do {
366
- roundToHighestFactorOfGlobalSize (ThreadsPerBlock[1 ], GlobalSize[1 ]);
367
- } while (WorkDim == 2 && !isPowerOf2 (ThreadsPerBlock[1 ]) &&
368
- ThreadsPerBlock[1 ] > 32 && --ThreadsPerBlock[1 ]);
361
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[1 ], GlobalSize[1 ]);
369
362
370
363
ThreadsPerBlock[2 ] = std::min (
371
364
GlobalSize[2 ],
372
365
std::min (MaxBlockSize / (ThreadsPerBlock[1 ] * ThreadsPerBlock[0 ]),
373
366
MaxBlockDim[2 ]));
374
- do {
375
- roundToHighestFactorOfGlobalSize (ThreadsPerBlock[2 ], GlobalSize[2 ]);
376
- } while (WorkDim == 1 && !isPowerOf2 (ThreadsPerBlock[2 ]) &&
377
- ThreadsPerBlock[2 ] > 32 && --ThreadsPerBlock[2 ]);
367
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[2 ], GlobalSize[2 ]);
378
368
}
0 commit comments