@@ -347,23 +347,32 @@ template <typename T> inline bool isPowerOf2(const T &Value) {
// dims == 1)
// In: MaxBlockDim - The max size of block in 3d
// In: MaxBlockSize - The max total size of block in all dimensions
// In: WorkDim - The number of work dimensions (1, 2 or 3)
static inline void roundToHighestFactorOfGlobalSizeIn3d (
351
352
size_t *ThreadsPerBlock, const size_t *GlobalSize,
352
- const size_t *MaxBlockDim, const size_t MaxBlockSize) {
353
+ const size_t *MaxBlockDim, const size_t MaxBlockSize,
354
+ const size_t WorkDim) {
353
355
ThreadsPerBlock[0 ] = std::min (GlobalSize[0 ], MaxBlockDim[0 ]);
354
356
// Make the X dim a factor of 2
355
357
do {
356
358
roundToHighestFactorOfGlobalSize (ThreadsPerBlock[0 ], GlobalSize[0 ]);
357
- } while (! isPowerOf2 (ThreadsPerBlock[ 0 ]) && ThreadsPerBlock[0 ] > 32 &&
358
- --ThreadsPerBlock[0 ]);
359
+ } while (WorkDim == 3 && ! isPowerOf2 ( ThreadsPerBlock[0 ]) &&
360
+ ThreadsPerBlock[ 0 ] > 32 && --ThreadsPerBlock[0 ]);
359
361
360
362
ThreadsPerBlock[1 ] =
361
363
std::min (GlobalSize[1 ],
362
364
std::min (MaxBlockSize / ThreadsPerBlock[0 ], MaxBlockDim[1 ]));
363
- roundToHighestFactorOfGlobalSize (ThreadsPerBlock[1 ], GlobalSize[1 ]);
365
+ do {
366
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[1 ], GlobalSize[1 ]);
367
+ } while (WorkDim == 2 && !isPowerOf2 (ThreadsPerBlock[1 ]) &&
368
+ ThreadsPerBlock[1 ] > 32 && --ThreadsPerBlock[1 ]);
364
369
365
370
ThreadsPerBlock[2 ] = std::min (
366
- GlobalSize[2 ], MaxBlockSize / (ThreadsPerBlock[1 ] * ThreadsPerBlock[0 ]));
367
- roundToHighestFactorOfGlobalSize (ThreadsPerBlock[2 ], GlobalSize[2 ]);
368
-
371
+ GlobalSize[2 ],
372
+ std::min (MaxBlockSize / (ThreadsPerBlock[1 ] * ThreadsPerBlock[0 ]),
373
+ MaxBlockDim[2 ]));
374
+ do {
375
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[2 ], GlobalSize[2 ]);
376
+ } while (WorkDim == 1 && !isPowerOf2 (ThreadsPerBlock[2 ]) &&
377
+ ThreadsPerBlock[2 ] > 32 && --ThreadsPerBlock[2 ]);
369
378
}
0 commit comments