Skip to content

Commit c9d1431

Browse files
committed
Improve fill op implementation
Match the CUDA change from #1319 in HIP.
1 parent 25b0843 commit c9d1431

File tree

2 files changed

+33
-21
lines changed

2 files changed

+33
-21
lines changed

source/adapters/hip/command_buffer.cpp

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ commandHandleReleaseInternal(ur_exp_command_buffer_command_handle_t Command) {
4848

4949
ur_exp_command_buffer_handle_t_::ur_exp_command_buffer_handle_t_(
5050
ur_context_handle_t hContext, ur_device_handle_t hDevice, bool IsUpdatable)
51-
: Context(hContext), Device(hDevice), IsUpdatable(IsUpdatable),
52-
HIPGraph{nullptr}, HIPGraphExec{nullptr}, RefCountInternal{1},
53-
RefCountExternal{1} {
51+
: Context(hContext), Device(hDevice),
52+
IsUpdatable(IsUpdatable), HIPGraph{nullptr}, HIPGraphExec{nullptr},
53+
RefCountInternal{1}, RefCountExternal{1} {
5454
urContextRetain(hContext);
5555
urDeviceRetain(hDevice);
5656
}
@@ -155,7 +155,6 @@ static ur_result_t enqueueCommandBufferFillHelper(
155155

156156
try {
157157
const size_t N = Size / PatternSize;
158-
auto Value = *static_cast<const uint32_t *>(Pattern);
159158
auto DstPtr = DstType == hipMemoryTypeDevice
160159
? *static_cast<hipDeviceptr_t *>(DstDevice)
161160
: DstDevice;
@@ -168,9 +167,27 @@ static ur_result_t enqueueCommandBufferFillHelper(
168167
NodeParams.elementSize = PatternSize;
169168
NodeParams.height = N;
170169
NodeParams.pitch = PatternSize;
171-
NodeParams.value = Value;
172170
NodeParams.width = 1;
173171

172+
// pattern size in bytes
173+
switch (PatternSize) {
174+
case 1: {
175+
auto Value = *static_cast<const uint8_t *>(Pattern);
176+
NodeParams.value = Value;
177+
break;
178+
}
179+
case 2: {
180+
auto Value = *static_cast<const uint16_t *>(Pattern);
181+
NodeParams.value = Value;
182+
break;
183+
}
184+
case 4: {
185+
auto Value = *static_cast<const uint32_t *>(Pattern);
186+
NodeParams.value = Value;
187+
break;
188+
}
189+
}
190+
174191
UR_CHECK_ERROR(hipGraphAddMemsetNode(&GraphNode, CommandBuffer->HIPGraph,
175192
DepsList.data(), DepsList.size(),
176193
&NodeParams));
@@ -187,15 +204,10 @@ static ur_result_t enqueueCommandBufferFillHelper(
187204
// This means that one hipGraphAddMemsetNode call is made for every 1
188205
// bytes in the pattern.
189206

190-
// List to handle inter-node dependencies
191-
std::vector<hipGraphNode_t> HIPNodesList = {};
192-
// List shared pointer that will point to the last node created
193-
std::shared_ptr<hipGraphNode_t> GraphNodePtr;
194-
195207
size_t NumberOfSteps = PatternSize / sizeof(uint8_t);
196208

197-
// take 4 bytes of the pattern
198-
auto ValueFirst = *(static_cast<const uint32_t *>(Pattern));
209+
// Shared pointer that will point to the last node created
210+
std::shared_ptr<hipGraphNode_t> GraphNodePtr;
199211

200212
// Create a new node
201213
hipGraphNode_t GraphNodeFirst;
@@ -205,7 +217,7 @@ static ur_result_t enqueueCommandBufferFillHelper(
205217
NodeParamsStepFirst.elementSize = 4;
206218
NodeParamsStepFirst.height = Size / sizeof(uint32_t);
207219
NodeParamsStepFirst.pitch = 4;
208-
NodeParamsStepFirst.value = ValueFirst;
220+
NodeParamsStepFirst.value = *(static_cast<const uint32_t *>(Pattern));
209221
NodeParamsStepFirst.width = 1;
210222

211223
UR_CHECK_ERROR(hipGraphAddMemsetNode(
@@ -216,7 +228,8 @@ static ur_result_t enqueueCommandBufferFillHelper(
216228
*SyncPoint = CommandBuffer->addSyncPoint(
217229
std::make_shared<hipGraphNode_t>(GraphNodeFirst));
218230

219-
HIPNodesList.push_back(GraphNodeFirst);
231+
DepsList.clear();
232+
DepsList.push_back(GraphNodeFirst);
220233

221234
// we walk up the pattern in 1-byte steps, and add Memset node for each
222235
// 1-byte chunk of the pattern.
@@ -233,22 +246,22 @@ static ur_result_t enqueueCommandBufferFillHelper(
233246
// Update NodeParam
234247
hipMemsetParams NodeParamsStep = {};
235248
NodeParamsStep.dst = reinterpret_cast<void *>(OffsetPtr);
236-
NodeParamsStep.elementSize = 1;
249+
NodeParamsStep.elementSize = sizeof(uint8_t);
237250
NodeParamsStep.height = Size / NumberOfSteps;
238251
NodeParamsStep.pitch = NumberOfSteps * sizeof(uint8_t);
239252
NodeParamsStep.value = Value;
240253
NodeParamsStep.width = 1;
241254

242255
UR_CHECK_ERROR(hipGraphAddMemsetNode(
243-
&GraphNode, CommandBuffer->HIPGraph, HIPNodesList.data(),
244-
HIPNodesList.size(), &NodeParamsStep));
256+
&GraphNode, CommandBuffer->HIPGraph, DepsList.data(),
257+
DepsList.size(), &NodeParamsStep));
245258

246259
GraphNodePtr = std::make_shared<hipGraphNode_t>(GraphNode);
247260
// Get sync point and register the node with it.
248261
*SyncPoint = CommandBuffer->addSyncPoint(GraphNodePtr);
249262

250-
HIPNodesList.clear();
251-
HIPNodesList.push_back(*GraphNodePtr.get());
263+
DepsList.clear();
264+
DepsList.push_back(*GraphNodePtr.get());
252265
}
253266
}
254267
} catch (ur_result_t Err) {

source/adapters/hip/kernel.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,7 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
277277
ur_result_t Result = UR_RESULT_SUCCESS;
278278
try {
279279
auto Device = hKernel->getProgram()->getDevice();
280-
hKernel->Args.addMemObjArg(argIndex, hArgValue,
281-
Properties ? Properties->memoryAccess : 0);
280+
hKernel->Args.addMemObjArg(argIndex, hArgValue, Properties->memoryAccess);
282281
if (hArgValue->isImage()) {
283282
auto array = std::get<SurfaceMem>(hArgValue->Mem).getArray(Device);
284283
hipArray_Format Format{};

0 commit comments

Comments
 (0)