Skip to content

Commit b841691

Browse files
authored
Merge pull request #2559 from Bensuo/fix_kernel_arg_indices
[CUDA][HIP] Fix kernel arguments being overwritten when added out of order
2 parents c685944 + 9de10cd commit b841691

File tree

9 files changed

+420
-79
lines changed

9 files changed

+420
-79
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,7 +523,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
523523
ThreadsPerBlock, BlocksPerGrid));
524524

525525
// Set node param structure with the kernel related data
526-
auto &ArgIndices = hKernel->getArgIndices();
526+
auto &ArgPointers = hKernel->getArgPointers();
527527
CUDA_KERNEL_NODE_PARAMS NodeParams = {};
528528
NodeParams.func = CuFunc;
529529
NodeParams.gridDimX = BlocksPerGrid[0];
@@ -533,7 +533,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
533533
NodeParams.blockDimY = ThreadsPerBlock[1];
534534
NodeParams.blockDimZ = ThreadsPerBlock[2];
535535
NodeParams.sharedMemBytes = LocalSize;
536-
NodeParams.kernelParams = const_cast<void **>(ArgIndices.data());
536+
NodeParams.kernelParams = const_cast<void **>(ArgPointers.data());
537537

538538
// Create and add an new kernel node to the Cuda graph
539539
UR_CHECK_ERROR(cuGraphAddKernelNode(&GraphNode, hCommandBuffer->CudaGraph,
@@ -1398,7 +1398,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
13981398
Params.blockDimZ = ThreadsPerBlock[2];
13991399
Params.sharedMemBytes = KernelCommandHandle->Kernel->getLocalSize();
14001400
Params.kernelParams =
1401-
const_cast<void **>(KernelCommandHandle->Kernel->getArgIndices().data());
1401+
const_cast<void **>(KernelCommandHandle->Kernel->getArgPointers().data());
14021402

14031403
CUgraphNode Node = KernelCommandHandle->Node;
14041404
CUgraphExec CudaGraphExec = CommandBuffer->CudaGraphExec;

source/adapters/cuda/enqueue.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -492,11 +492,11 @@ enqueueKernelLaunch(ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel,
492492
UR_CHECK_ERROR(RetImplEvent->start());
493493
}
494494

495-
auto &ArgIndices = hKernel->getArgIndices();
495+
auto &ArgPointers = hKernel->getArgPointers();
496496
UR_CHECK_ERROR(cuLaunchKernel(
497497
CuFunc, BlocksPerGrid[0], BlocksPerGrid[1], BlocksPerGrid[2],
498498
ThreadsPerBlock[0], ThreadsPerBlock[1], ThreadsPerBlock[2], LocalSize,
499-
CuStream, const_cast<void **>(ArgIndices.data()), nullptr));
499+
CuStream, const_cast<void **>(ArgPointers.data()), nullptr));
500500

501501
if (phEvent) {
502502
UR_CHECK_ERROR(RetImplEvent->record());
@@ -680,7 +680,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
680680
UR_CHECK_ERROR(RetImplEvent->start());
681681
}
682682

683-
auto &ArgIndices = hKernel->getArgIndices();
683+
auto &ArgPointers = hKernel->getArgPointers();
684684

685685
CUlaunchConfig launch_config;
686686
launch_config.gridDimX = BlocksPerGrid[0];
@@ -696,7 +696,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
696696
launch_config.numAttrs = launch_attribute.size();
697697

698698
UR_CHECK_ERROR(cuLaunchKernelEx(&launch_config, CuFunc,
699-
const_cast<void **>(ArgIndices.data()),
699+
const_cast<void **>(ArgPointers.data()),
700700
nullptr));
701701

702702
if (phEvent) {

source/adapters/cuda/kernel.hpp

Lines changed: 36 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,10 @@ struct ur_kernel_handle_t_ {
6666
args_t Storage;
6767
/// Aligned size of each parameter, including padding.
6868
args_size_t ParamSizes;
69-
/// Byte offset into /p Storage allocation for each parameter.
70-
args_index_t Indices;
69+
/// Byte offset into /p Storage allocation for each argument.
70+
args_index_t ArgPointers;
71+
/// Position in the Storage array where the next argument should added.
72+
size_t InsertPos = 0;
7173
/// Aligned size in bytes for each local memory parameter after padding has
7274
/// been added. Zero if the argument at the index isn't a local memory
7375
/// argument.
@@ -90,33 +92,43 @@ struct ur_kernel_handle_t_ {
9092
std::uint32_t ImplicitOffsetArgs[3] = {0, 0, 0};
9193

9294
arguments() {
93-
// Place the implicit offset index at the end of the indicies collection
94-
Indices.emplace_back(&ImplicitOffsetArgs);
95+
// Place the implicit offset index at the end of the ArgPointers
96+
// collection.
97+
ArgPointers.emplace_back(&ImplicitOffsetArgs);
9598
}
9699

97100
/// Add an argument to the kernel.
98101
/// If the argument existed before, it is replaced.
99102
/// Otherwise, it is added.
100103
/// Gaps are filled with empty arguments.
101-
/// Implicit offset argument is kept at the back of the indices collection.
104+
/// Implicit offset argument is kept at the back of the ArgPointers
105+
/// collection.
102106
void addArg(size_t Index, size_t Size, const void *Arg,
103107
size_t LocalSize = 0) {
104-
if (Index + 2 > Indices.size()) {
108+
// Expand storage to accommodate this Index if needed.
109+
if (Index + 2 > ArgPointers.size()) {
105110
// Move implicit offset argument index with the end
106-
Indices.resize(Index + 2, Indices.back());
111+
ArgPointers.resize(Index + 2, ArgPointers.back());
107112
// Ensure enough space for the new argument
108113
ParamSizes.resize(Index + 1);
109114
AlignedLocalMemSize.resize(Index + 1);
110115
OriginalLocalMemSize.resize(Index + 1);
111116
}
112-
ParamSizes[Index] = Size;
113-
// calculate the insertion point on the array
114-
size_t InsertPos = std::accumulate(std::begin(ParamSizes),
115-
std::begin(ParamSizes) + Index, 0);
116-
// Update the stored value for the argument
117-
std::memcpy(&Storage[InsertPos], Arg, Size);
118-
Indices[Index] = &Storage[InsertPos];
119-
AlignedLocalMemSize[Index] = LocalSize;
117+
118+
// Copy new argument to storage if it hasn't been added before.
119+
if (ParamSizes[Index] == 0) {
120+
ParamSizes[Index] = Size;
121+
std::memcpy(&Storage[InsertPos], Arg, Size);
122+
ArgPointers[Index] = &Storage[InsertPos];
123+
AlignedLocalMemSize[Index] = LocalSize;
124+
InsertPos += Size;
125+
}
126+
// Otherwise, update the existing argument.
127+
else {
128+
std::memcpy(ArgPointers[Index], Arg, Size);
129+
AlignedLocalMemSize[Index] = LocalSize;
130+
assert(Size == ParamSizes[Index]);
131+
}
120132
}
121133

122134
/// Returns the padded size and offset of a local memory argument.
@@ -128,7 +140,7 @@ struct ur_kernel_handle_t_ {
128140
std::pair<size_t, size_t> calcAlignedLocalArgument(size_t Index,
129141
size_t Size) {
130142
// Store the unpadded size of the local argument
131-
if (Index + 2 > Indices.size()) {
143+
if (Index + 2 > ArgPointers.size()) {
132144
AlignedLocalMemSize.resize(Index + 1);
133145
OriginalLocalMemSize.resize(Index + 1);
134146
}
@@ -158,10 +170,11 @@ struct ur_kernel_handle_t_ {
158170
return std::make_pair(AlignedLocalSize, AlignedLocalOffset);
159171
}
160172

161-
// Iterate over all existing local argument which follows StartIndex
173+
// Iterate over each existing local argument which follows StartIndex
162174
// index, update the offset and pointer into the kernel local memory.
163175
void updateLocalArgOffset(size_t StartIndex) {
164-
const size_t NumArgs = Indices.size() - 1; // Accounts for implicit arg
176+
const size_t NumArgs =
177+
ArgPointers.size() - 1; // Accounts for implicit arg
165178
for (auto SuccIndex = StartIndex; SuccIndex < NumArgs; SuccIndex++) {
166179
const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
167180
if (OriginalLocalSize == 0) {
@@ -177,10 +190,7 @@ struct ur_kernel_handle_t_ {
177190
AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
178191

179192
// Store new offset into local data
180-
const size_t InsertPos =
181-
std::accumulate(std::begin(ParamSizes),
182-
std::begin(ParamSizes) + SuccIndex, size_t{0});
183-
std::memcpy(&Storage[InsertPos], &SuccAlignedLocalOffset,
193+
std::memcpy(ArgPointers[SuccIndex], &SuccAlignedLocalOffset,
184194
sizeof(size_t));
185195
}
186196
}
@@ -228,7 +238,7 @@ struct ur_kernel_handle_t_ {
228238
std::memcpy(ImplicitOffsetArgs, ImplicitOffset, Size);
229239
}
230240

231-
const args_index_t &getIndices() const noexcept { return Indices; }
241+
const args_index_t &getArgPointers() const noexcept { return ArgPointers; }
232242

233243
uint32_t getLocalSize() const {
234244
return std::accumulate(std::begin(AlignedLocalMemSize),
@@ -299,7 +309,7 @@ struct ur_kernel_handle_t_ {
299309
/// real one required by the kernel, since this cannot be queried from
300310
/// the CUDA Driver API
301311
uint32_t getNumArgs() const noexcept {
302-
return static_cast<uint32_t>(Args.Indices.size() - 1);
312+
return static_cast<uint32_t>(Args.ArgPointers.size() - 1);
303313
}
304314

305315
void setKernelArg(int Index, size_t Size, const void *Arg) {
@@ -314,8 +324,8 @@ struct ur_kernel_handle_t_ {
314324
return Args.setImplicitOffset(Size, ImplicitOffset);
315325
}
316326

317-
const arguments::args_index_t &getArgIndices() const {
318-
return Args.getIndices();
327+
const arguments::args_index_t &getArgPointers() const {
328+
return Args.getArgPointers();
319329
}
320330

321331
void setWorkGroupMemory(size_t MemSize) { Args.setWorkGroupMemory(MemSize); }

source/adapters/hip/command_buffer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
378378
pLocalWorkSize, hKernel, HIPFunc, ThreadsPerBlock, BlocksPerGrid));
379379

380380
// Set node param structure with the kernel related data
381-
auto &ArgIndices = hKernel->getArgIndices();
381+
auto &ArgPointers = hKernel->getArgPointers();
382382
hipKernelNodeParams NodeParams;
383383
NodeParams.func = HIPFunc;
384384
NodeParams.gridDim.x = BlocksPerGrid[0];
@@ -388,7 +388,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
388388
NodeParams.blockDim.y = ThreadsPerBlock[1];
389389
NodeParams.blockDim.z = ThreadsPerBlock[2];
390390
NodeParams.sharedMemBytes = LocalSize;
391-
NodeParams.kernelParams = const_cast<void **>(ArgIndices.data());
391+
NodeParams.kernelParams = const_cast<void **>(ArgPointers.data());
392392
NodeParams.extra = nullptr;
393393

394394
// Create and add an new kernel node to the HIP graph
@@ -1098,7 +1098,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
10981098
Params.blockDim.z = ThreadsPerBlock[2];
10991099
Params.sharedMemBytes = hCommand->Kernel->getLocalSize();
11001100
Params.kernelParams =
1101-
const_cast<void **>(hCommand->Kernel->getArgIndices().data());
1101+
const_cast<void **>(hCommand->Kernel->getArgPointers().data());
11021102

11031103
hipGraphNode_t Node = hCommand->Node;
11041104
hipGraphExec_t HipGraphExec = CommandBuffer->HIPGraphExec;

source/adapters/hip/enqueue.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
308308
}
309309
}
310310

311-
auto ArgIndices = hKernel->getArgIndices();
311+
auto ArgPointers = hKernel->getArgPointers();
312312

313313
// If migration of mem across buffer is needed, an event must be associated
314314
// with this command, implicitly if phEvent is nullptr
@@ -322,7 +322,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
322322
UR_CHECK_ERROR(hipModuleLaunchKernel(
323323
HIPFunc, BlocksPerGrid[0], BlocksPerGrid[1], BlocksPerGrid[2],
324324
ThreadsPerBlock[0], ThreadsPerBlock[1], ThreadsPerBlock[2],
325-
hKernel->getLocalSize(), HIPStream, ArgIndices.data(), nullptr));
325+
hKernel->getLocalSize(), HIPStream, ArgPointers.data(), nullptr));
326326

327327
if (phEvent) {
328328
UR_CHECK_ERROR(RetImplEvent->record());

0 commit comments

Comments
 (0)