@@ -48,9 +48,9 @@ commandHandleReleaseInternal(ur_exp_command_buffer_command_handle_t Command) {
48
48
49
49
// Constructor for the command-buffer handle.
// Stores hContext, hDevice and the IsUpdatable flag, starts with null
// HIPGraph and HIPGraphExec handles, and sets both the internal and the
// external reference count to 1. The body then calls urContextRetain and
// urDeviceRetain on the handles it was given (presumably so they outlive
// this object - confirm against the matching release path).
// NOTE(review): in this hunk the removed ('-') and added ('+') initializer
// lines list the same members in the same order; only the line wrapping
// differs.
ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_ (
50
50
ur_context_handle_t hContext, ur_device_handle_t hDevice, bool IsUpdatable)
51
- : Context(hContext), Device(hDevice), IsUpdatable(IsUpdatable),
52
- HIPGraph{nullptr }, HIPGraphExec{nullptr }, RefCountInternal{ 1 },
53
- RefCountExternal{1 } {
51
+ : Context(hContext), Device(hDevice),
52
+ IsUpdatable(IsUpdatable), HIPGraph{nullptr }, HIPGraphExec{nullptr },
53
+ RefCountInternal{ 1 }, RefCountExternal{1 } {
54
54
urContextRetain (hContext);
55
55
urDeviceRetain (hDevice);
56
56
}
@@ -155,7 +155,6 @@ static ur_result_t enqueueCommandBufferFillHelper(
155
155
156
156
try {
157
157
const size_t N = Size / PatternSize;
158
- auto Value = *static_cast <const uint32_t *>(Pattern);
159
158
auto DstPtr = DstType == hipMemoryTypeDevice
160
159
? *static_cast <hipDeviceptr_t *>(DstDevice)
161
160
: DstDevice;
@@ -168,9 +167,27 @@ static ur_result_t enqueueCommandBufferFillHelper(
168
167
NodeParams.elementSize = PatternSize;
169
168
NodeParams.height = N;
170
169
NodeParams.pitch = PatternSize;
171
- NodeParams.value = Value;
172
170
NodeParams.width = 1 ;
173
171
172
+ // pattern size in bytes
173
+ switch (PatternSize) {
174
+ case 1 : {
175
+ auto Value = *static_cast <const uint8_t *>(Pattern);
176
+ NodeParams.value = Value;
177
+ break ;
178
+ }
179
+ case 2 : {
180
+ auto Value = *static_cast <const uint16_t *>(Pattern);
181
+ NodeParams.value = Value;
182
+ break ;
183
+ }
184
+ case 4 : {
185
+ auto Value = *static_cast <const uint32_t *>(Pattern);
186
+ NodeParams.value = Value;
187
+ break ;
188
+ }
189
+ }
190
+
174
191
UR_CHECK_ERROR (hipGraphAddMemsetNode (&GraphNode, CommandBuffer->HIPGraph ,
175
192
DepsList.data (), DepsList.size (),
176
193
&NodeParams));
@@ -187,15 +204,10 @@ static ur_result_t enqueueCommandBufferFillHelper(
187
204
// This means that one hipGraphAddMemsetNode call is made for each
188
205
// byte in the pattern.
189
206
190
- // List to handle inter-node dependencies
191
- std::vector<hipGraphNode_t> HIPNodesList = {};
192
- // List shared pointer that will point to the last node created
193
- std::shared_ptr<hipGraphNode_t> GraphNodePtr;
194
-
195
207
size_t NumberOfSteps = PatternSize / sizeof (uint8_t );
196
208
197
- // take 4 bytes of the pattern
198
- auto ValueFirst = *( static_cast < const uint32_t *>(Pattern)) ;
209
+ // Shared pointer that will point to the last node created
210
+ std::shared_ptr<hipGraphNode_t> GraphNodePtr ;
199
211
200
212
// Create a new node
201
213
hipGraphNode_t GraphNodeFirst;
@@ -205,7 +217,7 @@ static ur_result_t enqueueCommandBufferFillHelper(
205
217
NodeParamsStepFirst.elementSize = 4 ;
206
218
NodeParamsStepFirst.height = Size / sizeof (uint32_t );
207
219
NodeParamsStepFirst.pitch = 4 ;
208
- NodeParamsStepFirst.value = ValueFirst ;
220
+ NodeParamsStepFirst.value = *( static_cast < const uint32_t *>(Pattern)) ;
209
221
NodeParamsStepFirst.width = 1 ;
210
222
211
223
UR_CHECK_ERROR (hipGraphAddMemsetNode (
@@ -216,7 +228,8 @@ static ur_result_t enqueueCommandBufferFillHelper(
216
228
*SyncPoint = CommandBuffer->addSyncPoint (
217
229
std::make_shared<hipGraphNode_t>(GraphNodeFirst));
218
230
219
- HIPNodesList.push_back (GraphNodeFirst);
231
+ DepsList.clear ();
232
+ DepsList.push_back (GraphNodeFirst);
220
233
221
234
// We walk up the pattern in 1-byte steps and add a memset node for each
222
235
// 1-byte chunk of the pattern.
@@ -233,22 +246,22 @@ static ur_result_t enqueueCommandBufferFillHelper(
233
246
// Update NodeParam
234
247
hipMemsetParams NodeParamsStep = {};
235
248
NodeParamsStep.dst = reinterpret_cast <void *>(OffsetPtr);
236
- NodeParamsStep.elementSize = 1 ;
249
+ NodeParamsStep.elementSize = sizeof ( uint8_t ) ;
237
250
NodeParamsStep.height = Size / NumberOfSteps;
238
251
NodeParamsStep.pitch = NumberOfSteps * sizeof (uint8_t );
239
252
NodeParamsStep.value = Value;
240
253
NodeParamsStep.width = 1 ;
241
254
242
255
UR_CHECK_ERROR (hipGraphAddMemsetNode (
243
- &GraphNode, CommandBuffer->HIPGraph , HIPNodesList .data (),
244
- HIPNodesList .size (), &NodeParamsStep));
256
+ &GraphNode, CommandBuffer->HIPGraph , DepsList .data (),
257
+ DepsList .size (), &NodeParamsStep));
245
258
246
259
GraphNodePtr = std::make_shared<hipGraphNode_t>(GraphNode);
247
260
// Get sync point and register the node with it.
248
261
*SyncPoint = CommandBuffer->addSyncPoint (GraphNodePtr);
249
262
250
- HIPNodesList .clear ();
251
- HIPNodesList .push_back (*GraphNodePtr.get ());
263
+ DepsList .clear ();
264
+ DepsList .push_back (*GraphNodePtr.get ());
252
265
}
253
266
}
254
267
} catch (ur_result_t Err) {
0 commit comments