Skip to content

Commit 25b0843

Browse files
committed
Introduce kernel command update support
Use the executable graph update functionality from HIP graph to implement UR kernel command update. Updates also required to the CTS tests to account for a different number of accessor arguments required.
1 parent b685812 commit 25b0843

File tree

8 files changed

+476
-115
lines changed

8 files changed

+476
-115
lines changed

source/adapters/hip/command_buffer.cpp

Lines changed: 291 additions & 56 deletions
Large diffs are not rendered by default.

source/adapters/hip/command_buffer.hpp

Lines changed: 101 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -175,44 +175,137 @@ static inline const char *getUrResultString(ur_result_t Result) {
175175
fprintf(stderr, "UR <--- %s(%s)\n", #Call, getUrResultString(Result)); \
176176
}
177177

178+
// Handle to a kernel command.
179+
//
180+
// Struct that stores all the information related to a kernel command in a
181+
// command-buffer, such that the command can be recreated. When handles can
182+
// be returned from other command types this struct will need refactored.
183+
struct ur_exp_command_buffer_command_handle_t_ {
184+
ur_exp_command_buffer_command_handle_t_(
185+
ur_exp_command_buffer_handle_t CommandBuffer, ur_kernel_handle_t Kernel,
186+
std::shared_ptr<hipGraphNode_t> &&Node, hipKernelNodeParams Params,
187+
uint32_t WorkDim, const size_t *GlobalWorkOffsetPtr,
188+
const size_t *GlobalWorkSizePtr, const size_t *LocalWorkSizePtr);
189+
190+
void setGlobalOffset(const size_t *GlobalWorkOffsetPtr) {
191+
const size_t CopySize = sizeof(size_t) * WorkDim;
192+
std::memcpy(GlobalWorkOffset, GlobalWorkOffsetPtr, CopySize);
193+
if (WorkDim < 3) {
194+
const size_t ZeroSize = sizeof(size_t) * (3 - WorkDim);
195+
std::memset(GlobalWorkOffset + WorkDim, 0, ZeroSize);
196+
}
197+
}
198+
199+
void setGlobalSize(const size_t *GlobalWorkSizePtr) {
200+
const size_t CopySize = sizeof(size_t) * WorkDim;
201+
std::memcpy(GlobalWorkSize, GlobalWorkSizePtr, CopySize);
202+
if (WorkDim < 3) {
203+
const size_t ZeroSize = sizeof(size_t) * (3 - WorkDim);
204+
std::memset(GlobalWorkSize + WorkDim, 0, ZeroSize);
205+
}
206+
}
207+
208+
void setLocalSize(const size_t *LocalWorkSizePtr) {
209+
const size_t CopySize = sizeof(size_t) * WorkDim;
210+
std::memcpy(LocalWorkSize, LocalWorkSizePtr, CopySize);
211+
if (WorkDim < 3) {
212+
const size_t ZeroSize = sizeof(size_t) * (3 - WorkDim);
213+
std::memset(LocalWorkSize + WorkDim, 0, ZeroSize);
214+
}
215+
}
216+
217+
uint32_t incrementInternalReferenceCount() noexcept {
218+
return ++RefCountInternal;
219+
}
220+
uint32_t decrementInternalReferenceCount() noexcept {
221+
return --RefCountInternal;
222+
}
223+
224+
uint32_t incrementExternalReferenceCount() noexcept {
225+
return ++RefCountExternal;
226+
}
227+
uint32_t decrementExternalReferenceCount() noexcept {
228+
return --RefCountExternal;
229+
}
230+
uint32_t getExternalReferenceCount() const noexcept {
231+
return RefCountExternal;
232+
}
233+
234+
ur_exp_command_buffer_handle_t CommandBuffer;
235+
ur_kernel_handle_t Kernel;
236+
std::shared_ptr<hipGraphNode_t> Node;
237+
hipKernelNodeParams Params;
238+
239+
uint32_t WorkDim;
240+
size_t GlobalWorkOffset[3];
241+
size_t GlobalWorkSize[3];
242+
size_t LocalWorkSize[3];
243+
244+
private:
245+
std::atomic_uint32_t RefCountInternal;
246+
std::atomic_uint32_t RefCountExternal;
247+
};
248+
178249
struct ur_exp_command_buffer_handle_t_ {
179250

180251
ur_exp_command_buffer_handle_t_(ur_context_handle_t hContext,
181-
ur_device_handle_t hDevice);
252+
ur_device_handle_t hDevice, bool IsUpdatable);
182253

183254
~ur_exp_command_buffer_handle_t_();
184255

185-
void RegisterSyncPoint(ur_exp_command_buffer_sync_point_t SyncPoint,
256+
void registerSyncPoint(ur_exp_command_buffer_sync_point_t SyncPoint,
186257
std::shared_ptr<hipGraphNode_t> HIPNode) {
187258
SyncPoints[SyncPoint] = HIPNode;
188259
NextSyncPoint++;
189260
}
190261

191-
ur_exp_command_buffer_sync_point_t GetNextSyncPoint() const {
262+
ur_exp_command_buffer_sync_point_t getNextSyncPoint() const {
192263
return NextSyncPoint;
193264
}
194265

195266
// Helper to register next sync point
196267
// @param HIPNode Node to register as next sync point
197268
// @return Pointer to the sync that registers the Node
198269
ur_exp_command_buffer_sync_point_t
199-
AddSyncPoint(std::shared_ptr<hipGraphNode_t> HIPNode) {
270+
addSyncPoint(std::shared_ptr<hipGraphNode_t> HIPNode) {
200271
ur_exp_command_buffer_sync_point_t SyncPoint = NextSyncPoint;
201-
RegisterSyncPoint(SyncPoint, HIPNode);
272+
registerSyncPoint(SyncPoint, HIPNode);
202273
return SyncPoint;
203274
}
275+
uint32_t incrementInternalReferenceCount() noexcept {
276+
return ++RefCountInternal;
277+
}
278+
uint32_t decrementInternalReferenceCount() noexcept {
279+
return --RefCountInternal;
280+
}
281+
uint32_t getInternalReferenceCount() const noexcept {
282+
return RefCountInternal;
283+
}
284+
285+
uint32_t incrementExternalReferenceCount() noexcept {
286+
return ++RefCountExternal;
287+
}
288+
uint32_t decrementExternalReferenceCount() noexcept {
289+
return --RefCountExternal;
290+
}
291+
uint32_t getExternalReferenceCount() const noexcept {
292+
return RefCountExternal;
293+
}
204294

205295
// UR context associated with this command-buffer
206296
ur_context_handle_t Context;
207297
// Device associated with this command buffer
208298
ur_device_handle_t Device;
299+
// Whether commands in the command-buffer can be updated
300+
bool IsUpdatable;
209301
// HIP Graph handle
210302
hipGraph_t HIPGraph;
211303
// HIP Graph Exec handle
212304
hipGraphExec_t HIPGraphExec;
213305
// Atomic variable counting the number of reference to this command_buffer
214306
// using std::atomic prevents data race when incrementing/decrementing.
215-
std::atomic_uint32_t RefCount;
307+
std::atomic_uint32_t RefCountInternal;
308+
std::atomic_uint32_t RefCountExternal;
216309

217310
// Map of sync_points to ur_events
218311
std::unordered_map<ur_exp_command_buffer_sync_point_t,
@@ -222,9 +315,6 @@ struct ur_exp_command_buffer_handle_t_ {
222315
// is not enough)
223316
ur_exp_command_buffer_sync_point_t NextSyncPoint;
224317

225-
// Used when retaining an object.
226-
uint32_t incrementReferenceCount() noexcept { return ++RefCount; }
227-
// Used when releasing an object.
228-
uint32_t decrementReferenceCount() noexcept { return --RefCount; }
229-
uint32_t getReferenceCount() const noexcept { return RefCount; }
318+
// Handles to individual commands in the command-buffer
319+
std::vector<ur_exp_command_buffer_command_handle_t> CommandHandles;
230320
};

source/adapters/hip/device.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -844,9 +844,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
844844
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
845845

846846
case UR_DEVICE_INFO_COMMAND_BUFFER_SUPPORT_EXP:
847-
return ReturnValue(true);
848847
case UR_DEVICE_INFO_COMMAND_BUFFER_UPDATE_SUPPORT_EXP:
849-
return ReturnValue(false);
848+
return ReturnValue(true);
850849

851850
default:
852851
break;

source/adapters/hip/enqueue.cpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,22 +1750,19 @@ setKernelParams(const ur_device_handle_t Device, const uint32_t WorkDim,
17501750
const size_t *LocalWorkSize, ur_kernel_handle_t &Kernel,
17511751
hipFunction_t &HIPFunc, size_t (&ThreadsPerBlock)[3],
17521752
size_t (&BlocksPerGrid)[3]) {
1753+
size_t MaxWorkGroupSize = 0;
17531754
ur_result_t Result = UR_RESULT_SUCCESS;
1754-
size_t MaxWorkGroupSize = 0u;
1755-
bool ProvidedLocalWorkGroupSize = LocalWorkSize != nullptr;
1756-
17571755
try {
17581756
ScopedContext Active(Device);
17591757
{
17601758
size_t MaxThreadsPerBlock[3] = {
1761-
hQueue->Device->getMaxBlockDimX(),
1762-
hQueue->Device->getMaxBlockDimY(),
1763-
hQueue->Device->getMaxBlockDimZ()};
1764-
1765-
MaxWorkGroupSize = hQueue->Device->getMaxWorkGroupSize();
1759+
static_cast<size_t>(Device->getMaxBlockDimX()),
1760+
static_cast<size_t>(Device->getMaxBlockDimY()),
1761+
static_cast<size_t>(Device->getMaxBlockDimZ())};
17661762

1763+
MaxWorkGroupSize = Device->getMaxWorkGroupSize();
17671764

1768-
if (ProvidedLocalWorkGroupSize) {
1765+
if (LocalWorkSize != nullptr) {
17691766
auto isValid = [&](int dim) {
17701767
UR_ASSERT(LocalWorkSize[dim] <= MaxThreadsPerBlock[dim],
17711768
UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE);
@@ -1825,7 +1822,7 @@ setKernelParams(const ur_device_handle_t Device, const uint32_t WorkDim,
18251822
: (LocalMemSzPtrPI ? LocalMemSzPtrPI : nullptr);
18261823

18271824
if (LocalMemSzPtr) {
1828-
int DeviceMaxLocalMem = Dev->getDeviceMaxLocalMem();
1825+
int DeviceMaxLocalMem = Device->getDeviceMaxLocalMem();
18291826
static const int EnvVal = std::atoi(LocalMemSzPtr);
18301827
if (EnvVal <= 0 || EnvVal > DeviceMaxLocalMem) {
18311828
setErrorMessage(LocalMemSzPtrUR ? "Invalid value specified for "

source/adapters/hip/kernel.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,8 @@ urKernelSetArgMemObj(ur_kernel_handle_t hKernel, uint32_t argIndex,
277277
ur_result_t Result = UR_RESULT_SUCCESS;
278278
try {
279279
auto Device = hKernel->getProgram()->getDevice();
280-
hKernel->Args.addMemObjArg(argIndex, hArgValue, Properties->memoryAccess);
280+
hKernel->Args.addMemObjArg(argIndex, hArgValue,
281+
Properties ? Properties->memoryAccess : 0);
281282
if (hArgValue->isImage()) {
282283
auto array = std::get<SurfaceMem>(hArgValue->Mem).getArray(Device);
283284
hipArray_Format Format{};

test/conformance/exp_command_buffer/buffer_fill_kernel_update.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,12 @@ TEST_P(BufferFillCommandTest, UpdateParameters) {
8282
};
8383

8484
// Set argument index 2 as new value to fill (index 1 is buffer accessor)
85+
const uint32_t arg_index = (backend == UR_PLATFORM_BACKEND_HIP) ? 4 : 2;
8586
uint32_t new_val = 33;
8687
ur_exp_command_buffer_update_value_arg_desc_t new_input_desc = {
8788
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
8889
nullptr, // pNext
89-
2, // argIndex
90+
arg_index, // argIndex
9091
sizeof(new_val), // argSize
9192
nullptr, // pProperties
9293
&new_val, // hArgValue
@@ -217,10 +218,11 @@ TEST_P(BufferFillCommandTest, SeparateUpdateCalls) {
217218
&output_update_desc));
218219

219220
uint32_t new_val = 33;
221+
const uint32_t arg_index = (backend == UR_PLATFORM_BACKEND_HIP) ? 4 : 2;
220222
ur_exp_command_buffer_update_value_arg_desc_t new_input_desc = {
221223
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
222224
nullptr, // pNext
223-
2, // argIndex
225+
arg_index, // argIndex
224226
sizeof(new_val), // argSize
225227
nullptr, // pProperties
226228
&new_val, // hArgValue
@@ -280,11 +282,12 @@ TEST_P(BufferFillCommandTest, OverrideUpdate) {
280282
ASSERT_SUCCESS(urQueueFinish(queue));
281283
ValidateBuffer(buffer, sizeof(val) * global_size, val);
282284

285+
const uint32_t arg_index = (backend == UR_PLATFORM_BACKEND_HIP) ? 4 : 2;
283286
uint32_t first_val = 33;
284287
ur_exp_command_buffer_update_value_arg_desc_t first_input_desc = {
285288
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
286289
nullptr, // pNext
287-
2, // argIndex
290+
arg_index, // argIndex
288291
sizeof(first_val), // argSize
289292
nullptr, // pProperties
290293
&first_val, // hArgValue
@@ -313,7 +316,7 @@ TEST_P(BufferFillCommandTest, OverrideUpdate) {
313316
ur_exp_command_buffer_update_value_arg_desc_t second_input_desc = {
314317
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
315318
nullptr, // pNext
316-
2, // argIndex
319+
arg_index, // argIndex
317320
sizeof(second_val), // argSize
318321
nullptr, // pProperties
319322
&second_val, // hArgValue
@@ -356,11 +359,12 @@ TEST_P(BufferFillCommandTest, OverrideArgList) {
356359
ValidateBuffer(buffer, sizeof(val) * global_size, val);
357360

358361
ur_exp_command_buffer_update_value_arg_desc_t input_descs[2];
362+
const uint32_t arg_index = (backend == UR_PLATFORM_BACKEND_HIP) ? 4 : 2;
359363
uint32_t first_val = 33;
360364
input_descs[0] = {
361365
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
362366
nullptr, // pNext
363-
2, // argIndex
367+
arg_index, // argIndex
364368
sizeof(first_val), // argSize
365369
nullptr, // pProperties
366370
&first_val, // hArgValue
@@ -370,7 +374,7 @@ TEST_P(BufferFillCommandTest, OverrideArgList) {
370374
input_descs[1] = {
371375
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
372376
nullptr, // pNext
373-
2, // argIndex
377+
arg_index, // argIndex
374378
sizeof(second_val), // argSize
375379
nullptr, // pProperties
376380
&second_val, // hArgValue

0 commit comments

Comments
 (0)