@@ -170,7 +170,6 @@ static ur_result_t enqueueCommandBufferFillHelper(
170
170
171
171
try {
172
172
const size_t N = Size / PatternSize;
173
- auto Value = *static_cast <const uint32_t *>(Pattern);
174
173
auto DstPtr = DstType == CU_MEMORYTYPE_DEVICE
175
174
? *static_cast <CUdeviceptr *>(DstDevice)
176
175
: (CUdeviceptr)DstDevice;
@@ -183,9 +182,27 @@ static ur_result_t enqueueCommandBufferFillHelper(
183
182
NodeParams.elementSize = PatternSize;
184
183
NodeParams.height = N;
185
184
NodeParams.pitch = PatternSize;
186
- NodeParams.value = Value;
187
185
NodeParams.width = 1 ;
188
186
187
+ // pattern size in bytes
188
+ switch (PatternSize) {
189
+ case 1 : {
190
+ auto Value = *static_cast <const uint8_t *>(Pattern);
191
+ NodeParams.value = Value;
192
+ break ;
193
+ }
194
+ case 2 : {
195
+ auto Value = *static_cast <const uint16_t *>(Pattern);
196
+ NodeParams.value = Value;
197
+ break ;
198
+ }
199
+ case 4 : {
200
+ auto Value = *static_cast <const uint32_t *>(Pattern);
201
+ NodeParams.value = Value;
202
+ break ;
203
+ }
204
+ }
205
+
189
206
UR_CHECK_ERROR (cuGraphAddMemsetNode (
190
207
&GraphNode, CommandBuffer->CudaGraph , DepsList.data (),
191
208
DepsList.size (), &NodeParams, CommandBuffer->Device ->getContext ()));
@@ -198,29 +215,54 @@ static ur_result_t enqueueCommandBufferFillHelper(
198
215
// CUDA has no memset functions that allow setting values more than 4
199
216
// bytes. UR API lets you pass an arbitrary "pattern" to the buffer
200
217
// fill, which can be more than 4 bytes. We must break up the pattern
201
- // into 4 byte values, and set the buffer using multiple strided calls.
202
- // This means that one cuGraphAddMemsetNode call is made for every 4 bytes
203
- // in the pattern.
218
+ // into 1 byte values, and set the buffer using multiple strided calls.
219
+ // This means that one cuGraphAddMemsetNode call is made for every 1
220
+ // byte in the pattern after the first 4, which share one 4-byte memset.
221
+
222
+ size_t NumberOfSteps = PatternSize / sizeof (uint8_t );
204
223
205
- size_t NumberOfSteps = PatternSize / sizeof (uint32_t );
224
+ // Shared pointer that will point to the last node created
225
+ std::shared_ptr<CUgraphNode> GraphNodePtr;
226
+ // Create a new node
227
+ CUgraphNode GraphNodeFirst;
228
+ // Update NodeParam
229
+ CUDA_MEMSET_NODE_PARAMS NodeParamsStepFirst = {};
230
+ NodeParamsStepFirst.dst = DstPtr;
231
+ NodeParamsStepFirst.elementSize = sizeof (uint32_t );
232
+ NodeParamsStepFirst.height = Size / sizeof (uint32_t );
233
+ NodeParamsStepFirst.pitch = sizeof (uint32_t );
234
+ NodeParamsStepFirst.value = *static_cast <const uint32_t *>(Pattern);
235
+ NodeParamsStepFirst.width = 1 ;
206
236
207
- // we walk up the pattern in 4-byte steps, and call cuMemset for each
208
- // 4-byte chunk of the pattern.
209
- for (auto Step = 0u ; Step < NumberOfSteps; ++Step) {
237
+ UR_CHECK_ERROR (cuGraphAddMemsetNode (
238
+ &GraphNodeFirst, CommandBuffer->CudaGraph , DepsList.data (),
239
+ DepsList.size (), &NodeParamsStepFirst,
240
+ CommandBuffer->Device ->getContext ()));
241
+
242
+ // Get sync point and register the cuNode with it.
243
+ *SyncPoint = CommandBuffer->addSyncPoint (
244
+ std::make_shared<CUgraphNode>(GraphNodeFirst));
245
+
246
+ DepsList.clear ();
247
+ DepsList.push_back (GraphNodeFirst);
248
+
249
+ // we walk up the pattern in 1-byte steps from byte 4, adding a memset node for each
250
+ // 1-byte chunk of the pattern.
251
+ for (auto Step = 4u ; Step < NumberOfSteps; ++Step) {
210
252
// take 1 byte of the pattern
211
- auto Value = *(static_cast <const uint32_t *>(Pattern) + Step);
253
+ auto Value = *(static_cast <const uint8_t *>(Pattern) + Step);
212
254
213
255
// offset the pointer to the part of the buffer we want to write to
214
- auto OffsetPtr = DstPtr + (Step * sizeof (uint32_t ));
256
+ auto OffsetPtr = DstPtr + (Step * sizeof (uint8_t ));
215
257
216
258
// Create a new node
217
259
CUgraphNode GraphNode;
218
260
// Update NodeParam
219
261
CUDA_MEMSET_NODE_PARAMS NodeParamsStep = {};
220
262
NodeParamsStep.dst = (CUdeviceptr)OffsetPtr;
221
- NodeParamsStep.elementSize = 4 ;
222
- NodeParamsStep.height = N ;
223
- NodeParamsStep.pitch = PatternSize ;
263
+ NodeParamsStep.elementSize = sizeof ( uint8_t ) ;
264
+ NodeParamsStep.height = Size / NumberOfSteps ;
265
+ NodeParamsStep.pitch = NumberOfSteps * sizeof ( uint8_t ) ;
224
266
NodeParamsStep.value = Value;
225
267
NodeParamsStep.width = 1 ;
226
268
@@ -229,9 +271,12 @@ static ur_result_t enqueueCommandBufferFillHelper(
229
271
DepsList.size (), &NodeParamsStep,
230
272
CommandBuffer->Device ->getContext ()));
231
273
274
+ GraphNodePtr = std::make_shared<CUgraphNode>(GraphNode);
232
275
// Get sync point and register the cuNode with it.
233
- *SyncPoint = CommandBuffer->addSyncPoint (
234
- std::make_shared<CUgraphNode>(GraphNode));
276
+ *SyncPoint = CommandBuffer->addSyncPoint (GraphNodePtr);
277
+
278
+ DepsList.clear ();
279
+ DepsList.push_back (*GraphNodePtr.get ());
235
280
}
236
281
}
237
282
} catch (ur_result_t Err) {
0 commit comments