18
18
19
19
#include < cmath>
20
20
#include < cuda.h>
21
+ #include < ur/ur.hpp>
21
22
22
23
ur_result_t enqueueEventsWait (ur_queue_handle_t CommandQueue, CUstream Stream,
23
24
uint32_t NumEventsInWaitList,
@@ -144,8 +145,6 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
144
145
assert (ThreadsPerBlock != nullptr );
145
146
assert (GlobalWorkSize != nullptr );
146
147
assert (Kernel != nullptr );
147
- int MinGrid, MaxBlockSize;
148
- size_t MaxBlockDim[3 ];
149
148
150
149
// The below assumes a three dimensional range but this is not guaranteed by
151
150
// UR.
@@ -154,39 +153,31 @@ void guessLocalWorkSize(ur_device_handle_t Device, size_t *ThreadsPerBlock,
154
153
GlobalSizeNormalized[i] = GlobalWorkSize[i];
155
154
}
156
155
156
+ size_t MaxBlockDim[3 ];
157
+ MaxBlockDim[0 ] = MaxThreadsPerBlock[0 ];
157
158
MaxBlockDim[1 ] = Device->getMaxBlockDimY ();
158
159
MaxBlockDim[2 ] = Device->getMaxBlockDimZ ();
159
160
161
+ int MinGrid, MaxBlockSize;
160
162
UR_CHECK_ERROR (cuOccupancyMaxPotentialBlockSize (
161
163
&MinGrid, &MaxBlockSize, Kernel->get (), NULL , Kernel->getLocalSize (),
162
164
MaxThreadsPerBlock[0 ]));
163
165
164
- // Helper lambda to make sure each x, y, z dim divide the global dimension.
165
- // Can optionally specify that we want the wg size to be a power of 2 in a
166
- // given dimension, which is useful for the X dim for performance reasons.
167
- static auto roundToHighestFactorOfGlobalSize =
168
- [](size_t &ThreadsPerBlockInDim, const size_t GlobalWorkSizeInDim,
169
- bool MakePowerOfTwo) {
170
- auto IsPowerOf2 = [](size_t Value) -> bool {
171
- return Value && !(Value & (Value - 1 ));
172
- };
173
- while (GlobalWorkSizeInDim % ThreadsPerBlockInDim ||
174
- (MakePowerOfTwo && !IsPowerOf2 (ThreadsPerBlockInDim)))
175
- --ThreadsPerBlockInDim;
176
- };
177
-
178
166
ThreadsPerBlock[2 ] = std::min (GlobalSizeNormalized[2 ], MaxBlockDim[2 ]);
179
- roundToHighestFactorOfGlobalSize (ThreadsPerBlock[2 ], GlobalWorkSize[2 ],
180
- false );
167
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[2 ], GlobalWorkSize[2 ]);
168
+
181
169
ThreadsPerBlock[1 ] =
182
170
std::min (GlobalSizeNormalized[1 ],
183
171
std::min (MaxBlockSize / ThreadsPerBlock[2 ], MaxBlockDim[1 ]));
184
- roundToHighestFactorOfGlobalSize (ThreadsPerBlock[1 ], GlobalWorkSize[1 ],
185
- false );
172
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[1 ], GlobalWorkSize[1 ]);
173
+
186
174
MaxBlockDim[0 ] = MaxBlockSize / (ThreadsPerBlock[1 ] * ThreadsPerBlock[2 ]);
187
175
ThreadsPerBlock[0 ] = std::min (
188
176
MaxThreadsPerBlock[0 ], std::min (GlobalSizeNormalized[0 ], MaxBlockDim[0 ]));
189
- roundToHighestFactorOfGlobalSize (ThreadsPerBlock[0 ], GlobalWorkSize[0 ], true );
177
+ // Make the X dim a power of 2
178
+ do {
179
+ roundToHighestFactorOfGlobalSize (ThreadsPerBlock[0 ], GlobalWorkSize[0 ]);
180
+ } while (!isPowerOf2 (ThreadsPerBlock[0 ]));
190
181
}
191
182
192
183
// Helper to verify out-of-registers case (exceeded block max registers).
0 commit comments