Skip to content

Commit bb589ca

Browse files
authored
Merge pull request oneapi-src#1319 from Bensuo/maxime/cuda-large-fill-pattern
[EXP][CMDBUF] Improve CUDA Fill op implementation
2 parents ec634ff + ef72b3f commit bb589ca

File tree

1 file changed

+61
-16
lines changed

1 file changed

+61
-16
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 61 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,6 @@ static ur_result_t enqueueCommandBufferFillHelper(
170170

171171
try {
172172
const size_t N = Size / PatternSize;
173-
auto Value = *static_cast<const uint32_t *>(Pattern);
174173
auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
175174
? *static_cast<CUdeviceptr *>(DstDevice)
176175
: (CUdeviceptr)DstDevice;
@@ -183,9 +182,27 @@ static ur_result_t enqueueCommandBufferFillHelper(
183182
NodeParams.elementSize = PatternSize;
184183
NodeParams.height = N;
185184
NodeParams.pitch = PatternSize;
186-
NodeParams.value = Value;
187185
NodeParams.width = 1;
188186

187+
// pattern size in bytes
188+
switch (PatternSize) {
189+
case 1: {
190+
auto Value = *static_cast<const uint8_t *>(Pattern);
191+
NodeParams.value = Value;
192+
break;
193+
}
194+
case 2: {
195+
auto Value = *static_cast<const uint16_t *>(Pattern);
196+
NodeParams.value = Value;
197+
break;
198+
}
199+
case 4: {
200+
auto Value = *static_cast<const uint32_t *>(Pattern);
201+
NodeParams.value = Value;
202+
break;
203+
}
204+
}
205+
189206
UR_CHECK_ERROR(cuGraphAddMemsetNode(
190207
&GraphNode, CommandBuffer->CudaGraph, DepsList.data(),
191208
DepsList.size(), &NodeParams, CommandBuffer->Device->getContext()));
@@ -198,29 +215,54 @@ static ur_result_t enqueueCommandBufferFillHelper(
198215
// CUDA has no memset functions that allow setting values more than 4
199216
// bytes. UR API lets you pass an arbitrary "pattern" to the buffer
200217
// fill, which can be more than 4 bytes. We must break up the pattern
201-
// into 4 byte values, and set the buffer using multiple strided calls.
202-
// This means that one cuGraphAddMemsetNode call is made for every 4 bytes
203-
// in the pattern.
218+
// into 1 byte values, and set the buffer using multiple strided calls.
219+
// This means that one cuGraphAddMemsetNode call is made for every 1
220+
// bytes in the pattern.
221+
222+
size_t NumberOfSteps = PatternSize / sizeof(uint8_t);
204223

205-
size_t NumberOfSteps = PatternSize / sizeof(uint32_t);
224+
// Shared pointer that will point to the last node created
225+
std::shared_ptr<CUgraphNode> GraphNodePtr;
226+
// Create a new node
227+
CUgraphNode GraphNodeFirst;
228+
// Update NodeParam
229+
CUDA_MEMSET_NODE_PARAMS NodeParamsStepFirst = {};
230+
NodeParamsStepFirst.dst = DstPtr;
231+
NodeParamsStepFirst.elementSize = sizeof(uint32_t);
232+
NodeParamsStepFirst.height = Size / sizeof(uint32_t);
233+
NodeParamsStepFirst.pitch = sizeof(uint32_t);
234+
NodeParamsStepFirst.value = *static_cast<const uint32_t *>(Pattern);
235+
NodeParamsStepFirst.width = 1;
206236

207-
// we walk up the pattern in 4-byte steps, and call cuMemset for each
208-
// 4-byte chunk of the pattern.
209-
for (auto Step = 0u; Step < NumberOfSteps; ++Step) {
237+
UR_CHECK_ERROR(cuGraphAddMemsetNode(
238+
&GraphNodeFirst, CommandBuffer->CudaGraph, DepsList.data(),
239+
DepsList.size(), &NodeParamsStepFirst,
240+
CommandBuffer->Device->getContext()));
241+
242+
// Get sync point and register the cuNode with it.
243+
*SyncPoint = CommandBuffer->addSyncPoint(
244+
std::make_shared<CUgraphNode>(GraphNodeFirst));
245+
246+
DepsList.clear();
247+
DepsList.push_back(GraphNodeFirst);
248+
249+
// we walk up the pattern in 1-byte steps, and call cuMemset for each
250+
// 1-byte chunk of the pattern.
251+
for (auto Step = 4u; Step < NumberOfSteps; ++Step) {
210252
// take 4 bytes of the pattern
211-
auto Value = *(static_cast<const uint32_t *>(Pattern) + Step);
253+
auto Value = *(static_cast<const uint8_t *>(Pattern) + Step);
212254

213255
// offset the pointer to the part of the buffer we want to write to
214-
auto OffsetPtr = DstPtr + (Step * sizeof(uint32_t));
256+
auto OffsetPtr = DstPtr + (Step * sizeof(uint8_t));
215257

216258
// Create a new node
217259
CUgraphNode GraphNode;
218260
// Update NodeParam
219261
CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
220262
NodeParamsStep.dst = (CUdeviceptr)OffsetPtr;
221-
NodeParamsStep.elementSize = 4;
222-
NodeParamsStep.height = N;
223-
NodeParamsStep.pitch = PatternSize;
263+
NodeParamsStep.elementSize = sizeof(uint8_t);
264+
NodeParamsStep.height = Size / NumberOfSteps;
265+
NodeParamsStep.pitch = NumberOfSteps * sizeof(uint8_t);
224266
NodeParamsStep.value = Value;
225267
NodeParamsStep.width = 1;
226268

@@ -229,9 +271,12 @@ static ur_result_t enqueueCommandBufferFillHelper(
229271
DepsList.size(), &NodeParamsStep,
230272
CommandBuffer->Device->getContext()));
231273

274+
GraphNodePtr = std::make_shared<CUgraphNode>(GraphNode);
232275
// Get sync point and register the cuNode with it.
233-
*SyncPoint = CommandBuffer->addSyncPoint(
234-
std::make_shared<CUgraphNode>(GraphNode));
276+
*SyncPoint = CommandBuffer->addSyncPoint(GraphNodePtr);
277+
278+
DepsList.clear();
279+
DepsList.push_back(*GraphNodePtr.get());
235280
}
236281
}
237282
} catch (ur_result_t Err) {

0 commit comments

Comments
 (0)